Skip to content

Commit a8246ea

Browse files
hawkinspjax authors
authored andcommitted
Issue a warning where code relies on a bug where treedef.flatten_up_to(...) was overly permissive for None treedefs.
For example, tree_map(..., None, [2, 3]) previously did not raise an error, but None is a container and only leaves can be considered tree prefixes in this case. In a future release of JAX, this behavior will become an error. PiperOrigin-RevId: 641690427
1 parent 14d87d3 commit a8246ea

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,13 @@ Remember to align the itemized text with the first line of an item within a list
5656
(https://github.com/openxla/xla/pull/13301).
5757
* Fixes a compiler crash on GPU (https://github.com/google/jax/issues/21396).
5858

59+
* Deprecations
60+
* `jax.tree.map(f, None, non-None)` now emits a `DeprecationWarning`, and will
61+
raise an error in a future version of jax. `None` is only a tree-prefix of
62+
itself. To preserve the current behavior, you can ask `jax.tree.map` to
63+
treat `None` as a leaf value by writing:
64+
`jax.tree.map(lambda x, y: None if x is None else f(x, y), a, b, is_leaf=lambda x: x is None)`.
65+
5966
## jax 0.4.28 (May 9, 2024)
6067

6168
* Bug fixes

0 commit comments

Comments
 (0)