Skip to content

Commit 284716a

Browse files
XuehaiPanpytorchmergebot
authored andcommitted
[pytree] add treespec_{leaf,tuple,dict} functions for args_spec modification (pytorch#160843)
The goal of this PR is to provide a standard way to create simple treespec instances and hide the implementation details of the `PyTreeSpec` class. Changes: 1. Add function `treespec_leaf()` to replace `LeafSpec()`. 2. Add function `treespec_tuple(...)` and `treespec_dict(...)` to create treespec for `tuple` / `dict` which is used for `*args` / `**kwargs`. This avoids direct modification to `treespec` instances that rely on the implementation details of the `PyTreeSpec` class. 3. Change `len(spec.children_specs)` to `spec.num_children`. 4. Change `isinstance(spec, LeafSpec)` to `spec.is_leaf()`. ------ Pull Request resolved: pytorch#160843 Approved by: https://github.com/mlazos
1 parent 8b18864 commit 284716a

File tree

22 files changed

+379
-158
lines changed

22 files changed

+379
-158
lines changed

test/export/test_export.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,13 +91,13 @@
9191
from torch.testing._internal.triton_utils import requires_cuda_and_triton, requires_gpu
9292
from torch.testing._internal.two_tensor import TwoTensor
9393
from torch.utils._pytree import (
94-
LeafSpec,
9594
register_constant,
9695
tree_flatten,
9796
tree_map,
9897
tree_unflatten,
9998
TreeSpec,
10099
treespec_dumps,
100+
treespec_leaf,
101101
treespec_loads,
102102
)
103103

@@ -7791,7 +7791,7 @@ class MyDataClass:
77917791

77927792
dt = MyDataClass(x=3, y=4)
77937793
flat, spec = tree_flatten(dt)
7794-
self.assertTrue(spec, LeafSpec())
7794+
self.assertTrue(spec, treespec_leaf())
77957795
self.assertTrue(len(flat) == 1)
77967796

77977797
torch.export.register_dataclass(
@@ -7802,7 +7802,9 @@ class MyDataClass:
78027802
flat, spec = tree_flatten(dt)
78037803
self.assertEqual(
78047804
spec,
7805-
TreeSpec(MyDataClass, [["x", "y"], ["z"]], [LeafSpec(), LeafSpec()]),
7805+
TreeSpec(
7806+
MyDataClass, [["x", "y"], ["z"]], [treespec_leaf(), treespec_leaf()]
7807+
),
78067808
)
78077809
self.assertEqual(flat, [3, 4])
78087810

@@ -7835,7 +7837,7 @@ class MyOtherDataClass: # the pytree registration don't allow registering the s
78357837
TreeSpec(
78367838
MyOtherDataClass,
78377839
[["x", "y", "z"], []],
7838-
[LeafSpec(), LeafSpec(), LeafSpec()],
7840+
[treespec_leaf(), treespec_leaf(), treespec_leaf()],
78397841
),
78407842
)
78417843
self.assertEqual(flat, [3, 4, None])

test/test_pytree.py

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,6 @@ class TestEnum(enum.Enum):
6565
A = auto()
6666

6767

68-
python_leafspec = python_pytree.LeafSpec()
69-
70-
7168
class TestGenericPytree(TestCase):
7269
def test_aligned_public_apis(self):
7370
public_apis = python_pytree.__all__
@@ -197,7 +194,7 @@ def test_flatten_unflatten_leaf(self, pytree):
197194
def run_test_with_leaf(leaf):
198195
values, treespec = pytree.tree_flatten(leaf)
199196
self.assertEqual(values, [leaf])
200-
self.assertEqual(treespec, pytree.LeafSpec())
197+
self.assertEqual(treespec, pytree.treespec_leaf())
201198

202199
unflattened = pytree.tree_unflatten(values, treespec)
203200
self.assertEqual(unflattened, leaf)
@@ -215,7 +212,7 @@ def run_test_with_leaf(leaf):
215212
(
216213
python_pytree,
217214
lambda tup: python_pytree.TreeSpec(
218-
tuple, None, [python_leafspec for _ in tup]
215+
tuple, None, [python_pytree.treespec_leaf() for _ in tup]
219216
),
220217
),
221218
name="python",
@@ -250,7 +247,7 @@ def run_test(tup):
250247
(
251248
python_pytree,
252249
lambda lst: python_pytree.TreeSpec(
253-
list, None, [python_leafspec for _ in lst]
250+
list, None, [python_pytree.treespec_leaf() for _ in lst]
254251
),
255252
),
256253
name="python",
@@ -286,7 +283,7 @@ def run_test(lst):
286283
lambda dct: python_pytree.TreeSpec(
287284
dict,
288285
list(dct.keys()),
289-
[python_leafspec for _ in dct.values()],
286+
[python_pytree.treespec_leaf() for _ in dct.values()],
290287
),
291288
),
292289
name="python",
@@ -327,7 +324,7 @@ def run_test(dct):
327324
lambda odict: python_pytree.TreeSpec(
328325
OrderedDict,
329326
list(odict.keys()),
330-
[python_leafspec for _ in odict.values()],
327+
[python_pytree.treespec_leaf() for _ in odict.values()],
331328
),
332329
),
333330
name="python",
@@ -371,7 +368,7 @@ def run_test(odict):
371368
lambda ddct: python_pytree.TreeSpec(
372369
defaultdict,
373370
[ddct.default_factory, list(ddct.keys())],
374-
[python_leafspec for _ in ddct.values()],
371+
[python_pytree.treespec_leaf() for _ in ddct.values()],
375372
),
376373
),
377374
name="python",
@@ -413,7 +410,7 @@ def run_test(ddct):
413410
(
414411
python_pytree,
415412
lambda deq: python_pytree.TreeSpec(
416-
deque, deq.maxlen, [python_leafspec for _ in deq]
413+
deque, deq.maxlen, [python_pytree.treespec_leaf() for _ in deq]
417414
),
418415
),
419416
name="python",
@@ -453,7 +450,7 @@ def test_flatten_unflatten_namedtuple(self, pytree):
453450
def run_test(tup):
454451
if pytree is python_pytree:
455452
expected_spec = python_pytree.TreeSpec(
456-
namedtuple, Point, [python_leafspec for _ in tup]
453+
namedtuple, Point, [python_pytree.treespec_leaf() for _ in tup]
457454
)
458455
else:
459456
expected_spec = cxx_pytree.tree_structure(Point(0, 1))
@@ -848,16 +845,16 @@ def test_import_pytree_doesnt_import_optree(self):
848845

849846
def test_treespec_equality(self):
850847
self.assertEqual(
851-
python_pytree.LeafSpec(),
852-
python_pytree.LeafSpec(),
848+
python_pytree.treespec_leaf(),
849+
python_pytree.treespec_leaf(),
853850
)
854851
self.assertEqual(
855852
python_pytree.TreeSpec(list, None, []),
856853
python_pytree.TreeSpec(list, None, []),
857854
)
858855
self.assertEqual(
859-
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
860-
python_pytree.TreeSpec(list, None, [python_pytree.LeafSpec()]),
856+
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
857+
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
861858
)
862859
self.assertFalse(
863860
python_pytree.TreeSpec(tuple, None, [])
@@ -892,24 +889,32 @@ def test_treespec_repr(self):
892889
# python_pytree.tree_structure({})
893890
python_pytree.TreeSpec(dict, [], []),
894891
# python_pytree.tree_structure([0])
895-
python_pytree.TreeSpec(list, None, [python_leafspec]),
892+
python_pytree.TreeSpec(list, None, [python_pytree.treespec_leaf()]),
896893
# python_pytree.tree_structure([0, 1])
897894
python_pytree.TreeSpec(
898895
list,
899896
None,
900-
[python_leafspec, python_leafspec],
897+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
901898
),
902899
# python_pytree.tree_structure((0, 1, 2))
903900
python_pytree.TreeSpec(
904901
tuple,
905902
None,
906-
[python_leafspec, python_leafspec, python_leafspec],
903+
[
904+
python_pytree.treespec_leaf(),
905+
python_pytree.treespec_leaf(),
906+
python_pytree.treespec_leaf(),
907+
],
907908
),
908909
# python_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
909910
python_pytree.TreeSpec(
910911
dict,
911912
["a", "b", "c"],
912-
[python_leafspec, python_leafspec, python_leafspec],
913+
[
914+
python_pytree.treespec_leaf(),
915+
python_pytree.treespec_leaf(),
916+
python_pytree.treespec_leaf(),
917+
],
913918
),
914919
# python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
915920
python_pytree.TreeSpec(
@@ -919,13 +924,17 @@ def test_treespec_repr(self):
919924
python_pytree.TreeSpec(
920925
tuple,
921926
None,
922-
[python_leafspec, python_leafspec],
927+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
923928
),
924-
python_leafspec,
929+
python_pytree.treespec_leaf(),
925930
python_pytree.TreeSpec(
926931
dict,
927932
["a", "b", "c"],
928-
[python_leafspec, python_leafspec, python_leafspec],
933+
[
934+
python_pytree.treespec_leaf(),
935+
python_pytree.treespec_leaf(),
936+
python_pytree.treespec_leaf(),
937+
],
929938
),
930939
],
931940
),
@@ -938,12 +947,15 @@ def test_treespec_repr(self):
938947
tuple,
939948
None,
940949
[
941-
python_leafspec,
942-
python_leafspec,
950+
python_pytree.treespec_leaf(),
951+
python_pytree.treespec_leaf(),
943952
python_pytree.TreeSpec(
944953
list,
945954
None,
946-
[python_leafspec, python_leafspec],
955+
[
956+
python_pytree.treespec_leaf(),
957+
python_pytree.treespec_leaf(),
958+
],
947959
),
948960
],
949961
),
@@ -957,12 +969,12 @@ def test_treespec_repr(self):
957969
python_pytree.TreeSpec(
958970
list,
959971
None,
960-
[python_leafspec, python_leafspec],
972+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
961973
),
962974
python_pytree.TreeSpec(
963975
list,
964976
None,
965-
[python_leafspec, python_leafspec],
977+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
966978
),
967979
python_pytree.TreeSpec(dict, [], []),
968980
],
@@ -991,7 +1003,7 @@ def test_pytree_serialize_defaultdict_enum(self):
9911003
list,
9921004
None,
9931005
[
994-
python_leafspec,
1006+
python_pytree.treespec_leaf(),
9951007
],
9961008
),
9971009
],
@@ -1000,7 +1012,7 @@ def test_pytree_serialize_defaultdict_enum(self):
10001012
self.assertIsInstance(serialized_spec, str)
10011013

10021014
def test_pytree_serialize_enum(self):
1003-
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_leafspec])
1015+
spec = python_pytree.TreeSpec(dict, TestEnum.A, [python_pytree.treespec_leaf()])
10041016

10051017
serialized_spec = python_pytree.treespec_dumps(spec)
10061018
self.assertIsInstance(serialized_spec, str)
@@ -1163,12 +1175,20 @@ def test_saved_serialized(self):
11631175
OrderedDict,
11641176
[1, 2, 3],
11651177
[
1166-
python_pytree.TreeSpec(tuple, None, [python_leafspec, python_leafspec]),
1167-
python_leafspec,
1178+
python_pytree.TreeSpec(
1179+
tuple,
1180+
None,
1181+
[python_pytree.treespec_leaf(), python_pytree.treespec_leaf()],
1182+
),
1183+
python_pytree.treespec_leaf(),
11681184
python_pytree.TreeSpec(
11691185
dict,
11701186
[4, 5, 6],
1171-
[python_leafspec, python_leafspec, python_leafspec],
1187+
[
1188+
python_pytree.treespec_leaf(),
1189+
python_pytree.treespec_leaf(),
1190+
python_pytree.treespec_leaf(),
1191+
],
11721192
),
11731193
],
11741194
)
@@ -1453,7 +1473,7 @@ def setUp(self):
14531473
raise unittest.SkipTest("C++ pytree tests are not supported in fbcode")
14541474

14551475
def test_treespec_equality(self):
1456-
self.assertEqual(cxx_pytree.LeafSpec(), cxx_pytree.LeafSpec())
1476+
self.assertEqual(cxx_pytree.treespec_leaf(), cxx_pytree.treespec_leaf())
14571477

14581478
def test_treespec_repr(self):
14591479
# Check that it looks sane

0 commit comments

Comments
 (0)