|
25 | 25 |
|
26 | 26 | from jax._src import traceback_util |
27 | 27 | from jax._src.lib import pytree |
| 28 | +from jax._src.lib import xla_extension_version |
28 | 29 | from jax._src.util import safe_zip, set_module |
29 | 30 | from jax._src.util import unzip2 |
30 | 31 |
|
@@ -607,6 +608,18 @@ def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]: |
607 | 608 | return out |
608 | 609 |
|
609 | 610 |
|
| 611 | +# flatten_one_level_with_keys is not exported. |
| 612 | +def flatten_one_level_with_keys( |
| 613 | + tree: Any, |
| 614 | +) -> tuple[Iterable[KeyLeafPair], Hashable]: |
| 615 | + """Flatten the given pytree node by one level, with keys.""" |
| 616 | + out = default_registry.flatten_one_level_with_keys(tree) |
| 617 | + if out is None: |
| 618 | + raise ValueError(f"can't tree-flatten type: {type(tree)}") |
| 619 | + else: |
| 620 | + return out |
| 621 | + |
| 622 | + |
610 | 623 | # prefix_errors is not exported |
611 | 624 | def prefix_errors(prefix_tree: Any, full_tree: Any, |
612 | 625 | is_leaf: Callable[[Any], bool] | None = None, |
@@ -728,7 +741,7 @@ def keystr(keys: KeyPath): |
728 | 741 | return ''.join(map(str, keys)) |
729 | 742 |
|
730 | 743 |
|
731 | | -# TODO(ivyzheng): remove this after _child_keys() also moved to C++. |
| 744 | +# TODO(ivyzheng): remove this after another jaxlib release. |
732 | 745 | class _RegistryWithKeypathsEntry(NamedTuple): |
733 | 746 | flatten_with_keys: Callable[..., Any] |
734 | 747 | unflatten_func: Callable[..., Any] |
@@ -1146,6 +1159,8 @@ def tree_map_with_path(f: Callable[..., Any], |
1146 | 1159 |
|
1147 | 1160 | def _child_keys(pytree: Any) -> KeyPath: |
1148 | 1161 | assert not treedef_is_strict_leaf(tree_structure(pytree)) |
| 1162 | + if xla_extension_version >= 301: |
| 1163 | + return tuple(k for k, _ in flatten_one_level_with_keys(pytree)[0]) |
1149 | 1164 | handler = _registry_with_keypaths.get(type(pytree)) |
1150 | 1165 | if handler: |
1151 | 1166 | return tuple(k for k, _ in handler.flatten_with_keys(pytree)[0]) |
|
0 commit comments