2222from functools import partial
2323import operator as op
2424import textwrap
25- from typing import Any , NamedTuple , TypeVar , Union , overload
25+ from typing import Any , NamedTuple , TypeVar , overload
2626
2727from jax ._src import traceback_util
2828from jax ._src .lib import pytree
29+ from jax ._src .lib import xla_extension_version
2930from jax ._src .util import safe_zip , set_module
3031from jax ._src .util import unzip2
3132
@@ -209,12 +210,21 @@ def all_leaves(iterable: Iterable[Any],
209210
210211_Children = TypeVar ("_Children" , bound = Iterable [Any ])
211212_AuxData = TypeVar ("_AuxData" , bound = Hashable )
213+ KeyEntry = TypeVar ("KeyEntry" , bound = Any )
214+ KeyLeafPair = tuple [KeyEntry , Any ]
215+ KeyLeafPairs = Iterable [KeyLeafPair ]
216+ KeyPath = tuple [KeyEntry , ...]
212217
213218
214219@export
215- def register_pytree_node (nodetype : type [T ],
216- flatten_func : Callable [[T ], tuple [_Children , _AuxData ]],
217- unflatten_func : Callable [[_AuxData , _Children ], T ]) -> None :
220+ def register_pytree_node (
221+ nodetype : type [T ],
222+ flatten_func : Callable [[T ], tuple [_Children , _AuxData ]],
223+ unflatten_func : Callable [[_AuxData , _Children ], T ],
224+ flatten_with_keys_func : (
225+ Callable [[T ], tuple [KeyLeafPairs , _AuxData ]] | None
226+ ) = None ,
227+ ) -> None :
218228 """Extends the set of types that are considered internal nodes in pytrees.
219229
220230 See :ref:`example usage <pytrees>`.
@@ -279,9 +289,20 @@ def register_pytree_node(nodetype: type[T],
279289 >>> jax.jit(f)(m)
280290 Array([1., 2., 3., 4., 5.], dtype=float32)
281291 """
282- default_registry .register_node (nodetype , flatten_func , unflatten_func )
283- none_leaf_registry .register_node (nodetype , flatten_func , unflatten_func )
284- dispatch_registry .register_node (nodetype , flatten_func , unflatten_func )
292+ if xla_extension_version >= 299 :
293+ default_registry .register_node ( # type: ignore[call-arg]
294+ nodetype , flatten_func , unflatten_func , flatten_with_keys_func
295+ )
296+ none_leaf_registry .register_node ( # type: ignore[call-arg]
297+ nodetype , flatten_func , unflatten_func , flatten_with_keys_func
298+ )
299+ dispatch_registry .register_node ( # type: ignore[call-arg]
300+ nodetype , flatten_func , unflatten_func , flatten_with_keys_func
301+ )
302+ else :
303+ default_registry .register_node (nodetype , flatten_func , unflatten_func )
304+ none_leaf_registry .register_node (nodetype , flatten_func , unflatten_func )
305+ dispatch_registry .register_node (nodetype , flatten_func , unflatten_func )
285306 _registry [nodetype ] = _RegistryEntry (flatten_func , unflatten_func )
286307
287308
@@ -452,21 +473,6 @@ def tree_all(tree: Any, *, is_leaf: Callable[[Any], bool] | None = None) -> bool
452473 return all (tree_leaves (tree , is_leaf = is_leaf ))
453474
454475
455- register_pytree_node (
456- collections .OrderedDict ,
457- lambda x : (tuple (x .values ()), tuple (x .keys ())),
458- lambda keys , values : collections .OrderedDict (safe_zip (keys , values )))
459-
460- def _flatten_defaultdict (d ):
461- keys = tuple (sorted (d ))
462- return tuple (d [k ] for k in keys ), (d .default_factory , keys )
463-
464- register_pytree_node (
465- collections .defaultdict ,
466- _flatten_defaultdict ,
467- lambda s , values : collections .defaultdict (s [0 ], safe_zip (s [1 ], values )))
468-
469-
470476class _HashableCallableShim :
471477 """Object that delegates __call__, __hash__, and __eq__ to another object."""
472478
@@ -578,11 +584,11 @@ def broadcast_prefix(prefix_tree: Any, full_tree: Any,
578584
579585
580586# flatten_one_level is not exported.
581- def flatten_one_level (pytree : Any ) -> tuple [Iterable [Any ], Hashable ]:
587+ def flatten_one_level (tree : Any ) -> tuple [Iterable [Any ], Hashable ]:
582588 """Flatten the given pytree node by one level.
583589
584590 Args:
585- pytree : A valid pytree node, either built-in or registered via
591+ tree : A valid pytree node, either built-in or registered via
586592 :func:`register_pytree_node` or related functions.
587593
588594 Returns:
@@ -601,9 +607,9 @@ def flatten_one_level(pytree: Any) -> tuple[Iterable[Any], Hashable]:
601607 >>> meta
602608 ('a', 'b')
603609 """
604- out = default_registry .flatten_one_level (pytree )
610+ out = default_registry .flatten_one_level (tree )
605611 if out is None :
606- raise ValueError (f"can't tree-flatten type: { type (pytree )} " )
612+ raise ValueError (f"can't tree-flatten type: { type (tree )} " )
607613 else :
608614 return out
609615
@@ -739,10 +745,12 @@ class FlattenedIndexKey():
739745 def __str__ (self ):
740746 return f'[<flat index { self .key } >]'
741747
742- BuiltInKeyEntry = Union [SequenceKey , DictKey , GetAttrKey , FlattenedIndexKey ]
743748
744- KeyEntry = TypeVar ("KeyEntry" , bound = Hashable )
745- KeyPath = tuple [KeyEntry , ...]
749+ if xla_extension_version >= 299 :
750+ SequenceKey = pytree .SequenceKey # type: ignore
751+ DictKey = pytree .DictKey # type: ignore
752+ GetAttrKey = pytree .GetAttrKey # type: ignore
753+ FlattenedIndexKey = pytree .FlattenedIndexKey # type: ignore
746754
747755
748756@export
@@ -764,6 +772,7 @@ def keystr(keys: KeyPath):
764772 return '' .join (map (str , keys ))
765773
766774
775+ # TODO(ivyzheng): remove this after _child_keys() also moved to C++.
767776class _RegistryWithKeypathsEntry (NamedTuple ):
768777 flatten_with_keys : Callable [..., Any ]
769778 unflatten_func : Callable [..., Any ]
@@ -780,7 +789,6 @@ def flatten_with_keys(xs):
780789 flatten_with_keys , _registry [ty ].from_iter
781790 )
782791
783-
784792_registry_with_keypaths : dict [type [Any ], _RegistryWithKeypathsEntry ] = {}
785793
786794_register_keypaths (
@@ -803,13 +811,9 @@ def flatten_with_keys(xs):
803811@export
804812def register_pytree_with_keys (
805813 nodetype : type [T ],
806- flatten_with_keys : Callable [
807- [T ], tuple [Iterable [tuple [KeyEntry , Any ]], _AuxData ]
808- ],
814+ flatten_with_keys : Callable [[T ], tuple [Iterable [KeyLeafPair ], _AuxData ]],
809815 unflatten_func : Callable [[_AuxData , Iterable [Any ]], T ],
810- flatten_func : None | (
811- Callable [[T ], tuple [Iterable [Any ], _AuxData ]]
812- ) = None ,
816+ flatten_func : None | (Callable [[T ], tuple [Iterable [Any ], _AuxData ]]) = None ,
813817):
814818 """Extends the set of types that are considered internal nodes in pytrees.
815819
@@ -870,7 +874,9 @@ def flatten_func_impl(tree):
870874 return [c for _ , c in key_children ], treedef
871875 flatten_func = flatten_func_impl
872876
873- register_pytree_node (nodetype , flatten_func , unflatten_func )
877+ register_pytree_node (
878+ nodetype , flatten_func , unflatten_func , flatten_with_keys
879+ )
874880 _registry_with_keypaths [nodetype ] = _RegistryWithKeypathsEntry (
875881 flatten_with_keys , unflatten_func
876882 )
@@ -1092,6 +1098,40 @@ def flatten_func(x):
10921098 return nodetype
10931099
10941100
1101+ if xla_extension_version >= 299 :
1102+ register_pytree_with_keys (
1103+ collections .OrderedDict ,
1104+ lambda x : (tuple ((DictKey (k ), x [k ]) for k in x .keys ()), tuple (x .keys ())),
1105+ lambda keys , values : collections .OrderedDict (safe_zip (keys , values )),
1106+ )
1107+
1108+ def _flatten_defaultdict_with_keys (d ):
1109+ keys = tuple (sorted (d ))
1110+ return tuple ((DictKey (k ), d [k ]) for k in keys ), (d .default_factory , keys )
1111+
1112+ register_pytree_with_keys (
1113+ collections .defaultdict ,
1114+ _flatten_defaultdict_with_keys ,
1115+ lambda s , values : collections .defaultdict (s [0 ], safe_zip (s [1 ], values )),
1116+ )
1117+ else :
1118+ register_pytree_node (
1119+ collections .OrderedDict ,
1120+ lambda x : (tuple (x .values ()), tuple (x .keys ())),
1121+ lambda keys , values : collections .OrderedDict (safe_zip (keys , values )),
1122+ )
1123+
1124+ def _flatten_defaultdict (d ):
1125+ keys = tuple (sorted (d ))
1126+ return tuple (d [k ] for k in keys ), (d .default_factory , keys )
1127+
1128+ register_pytree_node (
1129+ collections .defaultdict ,
1130+ _flatten_defaultdict ,
1131+ lambda s , values : collections .defaultdict (s [0 ], safe_zip (s [1 ], values )),
1132+ )
1133+
1134+
10951135@export
10961136def register_static (cls : type [H ]) -> type [H ]:
10971137 """Registers `cls` as a pytree with no leaves.
@@ -1144,6 +1184,8 @@ def tree_flatten_with_path(
11441184 which contains a leaf and its key path. The second element is a treedef
11451185 representing the structure of the flattened tree.
11461186 """
1187+ if xla_extension_version >= 299 :
1188+ return default_registry .flatten_with_path (tree , is_leaf )
11471189 _ , tree_def = tree_flatten (tree , is_leaf )
11481190 return _generate_key_paths (tree , is_leaf ), tree_def
11491191
@@ -1164,13 +1206,15 @@ def tree_leaves_with_path(
11641206 - :func:`jax.tree_util.tree_leaves`
11651207 - :func:`jax.tree_util.tree_flatten_with_path`
11661208 """
1167- return _generate_key_paths (tree , is_leaf )
1209+ return tree_flatten_with_path (tree , is_leaf )[ 0 ]
11681210
11691211
11701212# generate_key_paths is not exported.
11711213def generate_key_paths (
11721214 tree : Any , is_leaf : Callable [[Any ], bool ] | None = None
11731215) -> list [tuple [KeyPath , Any ]]:
1216+ if xla_extension_version >= 299 :
1217+ return tree_leaves_with_path (tree , is_leaf )
11741218 return list (_generate_key_paths_ ((), tree , is_leaf ))
11751219_generate_key_paths = generate_key_paths # alias for backward compat
11761220
0 commit comments