@@ -65,9 +65,6 @@ class TestEnum(enum.Enum):
6565 A = auto ()
6666
6767
68- python_leafspec = python_pytree .LeafSpec ()
69-
70-
7168class TestGenericPytree (TestCase ):
7269 def test_aligned_public_apis (self ):
7370 public_apis = python_pytree .__all__
@@ -197,7 +194,7 @@ def test_flatten_unflatten_leaf(self, pytree):
197194 def run_test_with_leaf (leaf ):
198195 values , treespec = pytree .tree_flatten (leaf )
199196 self .assertEqual (values , [leaf ])
200- self .assertEqual (treespec , pytree .LeafSpec ())
197+ self .assertEqual (treespec , pytree .treespec_leaf ())
201198
202199 unflattened = pytree .tree_unflatten (values , treespec )
203200 self .assertEqual (unflattened , leaf )
@@ -215,7 +212,7 @@ def run_test_with_leaf(leaf):
215212 (
216213 python_pytree ,
217214 lambda tup : python_pytree .TreeSpec (
218- tuple , None , [python_leafspec for _ in tup ]
215+ tuple , None , [python_pytree . treespec_leaf () for _ in tup ]
219216 ),
220217 ),
221218 name = "python" ,
@@ -250,7 +247,7 @@ def run_test(tup):
250247 (
251248 python_pytree ,
252249 lambda lst : python_pytree .TreeSpec (
253- list , None , [python_leafspec for _ in lst ]
250+ list , None , [python_pytree . treespec_leaf () for _ in lst ]
254251 ),
255252 ),
256253 name = "python" ,
@@ -286,7 +283,7 @@ def run_test(lst):
286283 lambda dct : python_pytree .TreeSpec (
287284 dict ,
288285 list (dct .keys ()),
289- [python_leafspec for _ in dct .values ()],
286+ [python_pytree . treespec_leaf () for _ in dct .values ()],
290287 ),
291288 ),
292289 name = "python" ,
@@ -327,7 +324,7 @@ def run_test(dct):
327324 lambda odict : python_pytree .TreeSpec (
328325 OrderedDict ,
329326 list (odict .keys ()),
330- [python_leafspec for _ in odict .values ()],
327+ [python_pytree . treespec_leaf () for _ in odict .values ()],
331328 ),
332329 ),
333330 name = "python" ,
@@ -371,7 +368,7 @@ def run_test(odict):
371368 lambda ddct : python_pytree .TreeSpec (
372369 defaultdict ,
373370 [ddct .default_factory , list (ddct .keys ())],
374- [python_leafspec for _ in ddct .values ()],
371+ [python_pytree . treespec_leaf () for _ in ddct .values ()],
375372 ),
376373 ),
377374 name = "python" ,
@@ -413,7 +410,7 @@ def run_test(ddct):
413410 (
414411 python_pytree ,
415412 lambda deq : python_pytree .TreeSpec (
416- deque , deq .maxlen , [python_leafspec for _ in deq ]
413+ deque , deq .maxlen , [python_pytree . treespec_leaf () for _ in deq ]
417414 ),
418415 ),
419416 name = "python" ,
@@ -453,7 +450,7 @@ def test_flatten_unflatten_namedtuple(self, pytree):
453450 def run_test (tup ):
454451 if pytree is python_pytree :
455452 expected_spec = python_pytree .TreeSpec (
456- namedtuple , Point , [python_leafspec for _ in tup ]
453+ namedtuple , Point , [python_pytree . treespec_leaf () for _ in tup ]
457454 )
458455 else :
459456 expected_spec = cxx_pytree .tree_structure (Point (0 , 1 ))
@@ -848,16 +845,16 @@ def test_import_pytree_doesnt_import_optree(self):
848845
849846 def test_treespec_equality (self ):
850847 self .assertEqual (
851- python_pytree .LeafSpec (),
852- python_pytree .LeafSpec (),
848+ python_pytree .treespec_leaf (),
849+ python_pytree .treespec_leaf (),
853850 )
854851 self .assertEqual (
855852 python_pytree .TreeSpec (list , None , []),
856853 python_pytree .TreeSpec (list , None , []),
857854 )
858855 self .assertEqual (
859- python_pytree .TreeSpec (list , None , [python_pytree .LeafSpec ()]),
860- python_pytree .TreeSpec (list , None , [python_pytree .LeafSpec ()]),
856+ python_pytree .TreeSpec (list , None , [python_pytree .treespec_leaf ()]),
857+ python_pytree .TreeSpec (list , None , [python_pytree .treespec_leaf ()]),
861858 )
862859 self .assertFalse (
863860 python_pytree .TreeSpec (tuple , None , [])
@@ -892,24 +889,32 @@ def test_treespec_repr(self):
892889 # python_pytree.tree_structure({})
893890 python_pytree .TreeSpec (dict , [], []),
894891 # python_pytree.tree_structure([0])
895- python_pytree .TreeSpec (list , None , [python_leafspec ]),
892+ python_pytree .TreeSpec (list , None , [python_pytree . treespec_leaf () ]),
896893 # python_pytree.tree_structure([0, 1])
897894 python_pytree .TreeSpec (
898895 list ,
899896 None ,
900- [python_leafspec , python_leafspec ],
897+ [python_pytree . treespec_leaf (), python_pytree . treespec_leaf () ],
901898 ),
902899 # python_pytree.tree_structure((0, 1, 2))
903900 python_pytree .TreeSpec (
904901 tuple ,
905902 None ,
906- [python_leafspec , python_leafspec , python_leafspec ],
903+ [
904+ python_pytree .treespec_leaf (),
905+ python_pytree .treespec_leaf (),
906+ python_pytree .treespec_leaf (),
907+ ],
907908 ),
908909 # python_pytree.tree_structure({"a": 0, "b": 1, "c": 2})
909910 python_pytree .TreeSpec (
910911 dict ,
911912 ["a" , "b" , "c" ],
912- [python_leafspec , python_leafspec , python_leafspec ],
913+ [
914+ python_pytree .treespec_leaf (),
915+ python_pytree .treespec_leaf (),
916+ python_pytree .treespec_leaf (),
917+ ],
913918 ),
914919 # python_pytree.tree_structure(OrderedDict([("a", (0, 1)), ("b", 2), ("c", {"a": 3, "b": 4, "c": 5})])
915920 python_pytree .TreeSpec (
@@ -919,13 +924,17 @@ def test_treespec_repr(self):
919924 python_pytree .TreeSpec (
920925 tuple ,
921926 None ,
922- [python_leafspec , python_leafspec ],
927+ [python_pytree . treespec_leaf (), python_pytree . treespec_leaf () ],
923928 ),
924- python_leafspec ,
929+ python_pytree . treespec_leaf () ,
925930 python_pytree .TreeSpec (
926931 dict ,
927932 ["a" , "b" , "c" ],
928- [python_leafspec , python_leafspec , python_leafspec ],
933+ [
934+ python_pytree .treespec_leaf (),
935+ python_pytree .treespec_leaf (),
936+ python_pytree .treespec_leaf (),
937+ ],
929938 ),
930939 ],
931940 ),
@@ -938,12 +947,15 @@ def test_treespec_repr(self):
938947 tuple ,
939948 None ,
940949 [
941- python_leafspec ,
942- python_leafspec ,
950+ python_pytree . treespec_leaf () ,
951+ python_pytree . treespec_leaf () ,
943952 python_pytree .TreeSpec (
944953 list ,
945954 None ,
946- [python_leafspec , python_leafspec ],
955+ [
956+ python_pytree .treespec_leaf (),
957+ python_pytree .treespec_leaf (),
958+ ],
947959 ),
948960 ],
949961 ),
@@ -957,12 +969,12 @@ def test_treespec_repr(self):
957969 python_pytree .TreeSpec (
958970 list ,
959971 None ,
960- [python_leafspec , python_leafspec ],
972+ [python_pytree . treespec_leaf (), python_pytree . treespec_leaf () ],
961973 ),
962974 python_pytree .TreeSpec (
963975 list ,
964976 None ,
965- [python_leafspec , python_leafspec ],
977+ [python_pytree . treespec_leaf (), python_pytree . treespec_leaf () ],
966978 ),
967979 python_pytree .TreeSpec (dict , [], []),
968980 ],
@@ -991,7 +1003,7 @@ def test_pytree_serialize_defaultdict_enum(self):
9911003 list ,
9921004 None ,
9931005 [
994- python_leafspec ,
1006+ python_pytree . treespec_leaf () ,
9951007 ],
9961008 ),
9971009 ],
@@ -1000,7 +1012,7 @@ def test_pytree_serialize_defaultdict_enum(self):
10001012 self .assertIsInstance (serialized_spec , str )
10011013
10021014 def test_pytree_serialize_enum (self ):
1003- spec = python_pytree .TreeSpec (dict , TestEnum .A , [python_leafspec ])
1015+ spec = python_pytree .TreeSpec (dict , TestEnum .A , [python_pytree . treespec_leaf () ])
10041016
10051017 serialized_spec = python_pytree .treespec_dumps (spec )
10061018 self .assertIsInstance (serialized_spec , str )
@@ -1163,12 +1175,20 @@ def test_saved_serialized(self):
11631175 OrderedDict ,
11641176 [1 , 2 , 3 ],
11651177 [
1166- python_pytree .TreeSpec (tuple , None , [python_leafspec , python_leafspec ]),
1167- python_leafspec ,
1178+ python_pytree .TreeSpec (
1179+ tuple ,
1180+ None ,
1181+ [python_pytree .treespec_leaf (), python_pytree .treespec_leaf ()],
1182+ ),
1183+ python_pytree .treespec_leaf (),
11681184 python_pytree .TreeSpec (
11691185 dict ,
11701186 [4 , 5 , 6 ],
1171- [python_leafspec , python_leafspec , python_leafspec ],
1187+ [
1188+ python_pytree .treespec_leaf (),
1189+ python_pytree .treespec_leaf (),
1190+ python_pytree .treespec_leaf (),
1191+ ],
11721192 ),
11731193 ],
11741194 )
@@ -1453,7 +1473,7 @@ def setUp(self):
14531473 raise unittest .SkipTest ("C++ pytree tests are not supported in fbcode" )
14541474
14551475 def test_treespec_equality (self ):
1456- self .assertEqual (cxx_pytree .LeafSpec (), cxx_pytree .LeafSpec ())
1476+ self .assertEqual (cxx_pytree .treespec_leaf (), cxx_pytree .treespec_leaf ())
14571477
14581478 def test_treespec_repr (self ):
14591479 # Check that it looks sane
0 commit comments