@@ -284,3 +284,97 @@ def unflatten(treedef: tree_util.PyTreeDef,
284284 - :func:`jax.tree.structure`
285285 """
286286 return tree_util .tree_unflatten (treedef , leaves )
287+
288+
289+ def flatten_with_path (
290+ tree : Any , is_leaf : Callable [[Any ], bool ] | None = None
291+ ) -> tuple [list [tuple [tree_util .KeyPath , Any ]], tree_util .PyTreeDef ]:
292+ """Flattens a pytree like ``tree_flatten``, but also returns each leaf's key path.
293+
294+ Args:
295+ tree: a pytree to flatten. If it contains a custom type, it is recommended
296+ to be registered with ``register_pytree_with_keys``.
297+
298+ Returns:
299+ A pair which the first element is a list of key-leaf pairs, each of
300+ which contains a leaf and its key path. The second element is a treedef
301+ representing the structure of the flattened tree.
302+
303+ Examples:
304+ >>> import jax
305+ >>> path_vals, treedef = jax.tree.flatten_with_path([1, {'x': 3}])
306+ >>> path_vals
307+ [((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
308+ >>> treedef
309+ PyTreeDef([*, {'x': *}])
310+
311+ See Also:
312+ - :func:`jax.tree.flatten`
313+ - :func:`jax.tree.map_with_path`
314+ - :func:`jax.tree_util.register_pytree_with_keys`
315+ """
316+ return tree_util .tree_flatten_with_path (tree , is_leaf )
317+
318+
319+ def leaves_with_path (
320+ tree : Any , is_leaf : Callable [[Any ], bool ] | None = None
321+ ) -> list [tuple [tree_util .KeyPath , Any ]]:
322+ """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path.
323+
324+ Args:
325+ tree: a pytree. If it contains a custom type, it is recommended to be
326+ registered with ``register_pytree_with_keys``.
327+
328+ Returns:
329+ A list of key-leaf pairs, each of which contains a leaf and its key path.
330+
331+ Examples:
332+ >>> import jax
333+ >>> jax.tree.leaves_with_path([1, {'x': 3}])
334+ [((SequenceKey(idx=0),), 1), ((SequenceKey(idx=1), DictKey(key='x')), 3)]
335+
336+ See Also:
337+ - :func:`jax.tree.leaves`
338+ - :func:`jax.tree.flatten_with_path`
339+ - :func:`jax.tree_util.register_pytree_with_keys`
340+ """
341+ return tree_util .tree_leaves_with_path (tree , is_leaf )
342+
343+
344+ def map_with_path (
345+ f : Callable [..., Any ],
346+ tree : Any ,
347+ * rest : Any ,
348+ is_leaf : Callable [[Any ], bool ] | None = None ,
349+ ) -> Any :
350+ """Maps a multi-input function over pytree key path and args to produce a new pytree.
351+
352+ This is a more powerful alternative of ``tree_map`` that can take the key path
353+ of each leaf as input argument as well.
354+
355+ Args:
356+ f: function that takes ``2 + len(rest)`` arguments, aka. the key path and
357+ each corresponding leaves of the pytrees.
358+ tree: a pytree to be mapped over, with each leaf's key path as the first
359+ positional argument and the leaf itself as the second argument to ``f``.
360+ *rest: a tuple of pytrees, each of which has the same structure as ``tree``
361+ or has ``tree`` as a prefix.
362+
363+ Returns:
364+ A new pytree with the same structure as ``tree`` but with the value at each
365+ leaf given by ``f(kp, x, *xs)`` where ``kp`` is the key path of the leaf at
366+ the corresponding leaf in ``tree``, ``x`` is the leaf value and ``xs`` is
367+ the tuple of values at corresponding nodes in ``rest``.
368+
369+ Examples:
370+ >>> import jax
371+ >>> jax.tree.map_with_path(lambda path, x: x + path[0].idx, [1, 2, 3])
372+ [1, 3, 5]
373+
374+ See Also:
375+ - :func:`jax.tree.map`
376+ - :func:`jax.tree.flatten_with_path`
377+ - :func:`jax.tree.leaves_with_path`
378+ - :func:`jax.tree_util.register_pytree_with_keys`
379+ """
380+ return tree_util .tree_map_with_path (f , tree , * rest , is_leaf = is_leaf )
0 commit comments