Skip to content

Commit ef06607

Browse files
IvyZXGoogle-ML-Automation
authored andcommitted
Implement flatten one level with keys in C++ and use it for the prefix/equality error printing.
With this, we should be able to safely delete the python with-path registry after a new jaxlib release. Also changed all `std::string_view` to `absl::string_view` per requirements of TF repository. PiperOrigin-RevId: 705669465
1 parent eb3ea98 commit ef06607

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

jax/_src/tree_util.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from jax._src import traceback_util
2727
from jax._src.lib import pytree
28+
from jax._src.lib import xla_extension_version
2829
from jax._src.util import safe_zip, set_module
2930
from jax._src.util import unzip2
3031

@@ -607,6 +608,18 @@ def flatten_one_level(tree: Any) -> tuple[Iterable[Any], Hashable]:
607608
return out
608609

609610

611+
# flatten_one_level_with_keys is not exported.
612+
def flatten_one_level_with_keys(
613+
tree: Any,
614+
) -> tuple[Iterable[KeyLeafPair], Hashable]:
615+
"""Flatten the given pytree node by one level, with keys."""
616+
out = default_registry.flatten_one_level_with_keys(tree)
617+
if out is None:
618+
raise ValueError(f"can't tree-flatten type: {type(tree)}")
619+
else:
620+
return out
621+
622+
610623
# prefix_errors is not exported
611624
def prefix_errors(prefix_tree: Any, full_tree: Any,
612625
is_leaf: Callable[[Any], bool] | None = None,
@@ -728,7 +741,7 @@ def keystr(keys: KeyPath):
728741
return ''.join(map(str, keys))
729742

730743

731-
# TODO(ivyzheng): remove this after _child_keys() also moved to C++.
744+
# TODO(ivyzheng): remove this after another jaxlib release.
732745
class _RegistryWithKeypathsEntry(NamedTuple):
733746
flatten_with_keys: Callable[..., Any]
734747
unflatten_func: Callable[..., Any]
@@ -1146,6 +1159,8 @@ def tree_map_with_path(f: Callable[..., Any],
11461159

11471160
def _child_keys(pytree: Any) -> KeyPath:
11481161
assert not treedef_is_strict_leaf(tree_structure(pytree))
1162+
if xla_extension_version >= 301:
1163+
return tuple(k for k, _ in flatten_one_level_with_keys(pytree)[0])
11491164
handler = _registry_with_keypaths.get(type(pytree))
11501165
if handler:
11511166
return tuple(k for k, _ in handler.flatten_with_keys(pytree)[0])

0 commit comments

Comments
 (0)