Skip to content

Commit 9a1bc4e

Browse files
Cristian GarciaFlax Authors
authored andcommitted
Add tree-mode-only nnx.{vjp,jvp,jit_partial}
PiperOrigin-RevId: 874826112
1 parent 5bcb1f5 commit 9a1bc4e

File tree

10 files changed

+877
-37
lines changed

10 files changed

+877
-37
lines changed

docs_nnx/api_reference/flax.nnx/transforms.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ transforms
1313
.. autofunction:: vmap
1414
.. autofunction:: eval_shape
1515
.. autofunction:: custom_vjp
16+
.. autofunction:: vjp
17+
.. autofunction:: jvp
1618
.. autofunction:: cond
1719
.. autofunction:: switch
1820
.. autofunction:: while_loop

flax/configurations.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class Config:
2828
flax_max_repr_depth: int | None
2929
flax_always_shard_variable: bool
3030
flax_hijax_variable: bool
31-
flax_nnx_graph_mode: bool
31+
nnx_graph_mode: bool
3232
# See https://google.github.io/pytype/faq.html.
3333
_HAS_DYNAMIC_ATTRIBUTES = True
3434

@@ -298,7 +298,7 @@ def static_int_env(varname: str, default: int | None) -> int | None:
298298
help='Whether to enable HiJAX support for `nnx.Variable`.',
299299
)
300300
nnx_graph_mode = bool_flag(
301-
name='flax_nnx_graph_mode',
301+
name='nnx_graph_mode',
302302
default=True,
303303
help='Whether NNX APIs default to graph-mode (True) or tree-mode (False).',
304304
)

flax/nnx/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,11 @@
174174
from .transforms.autodiff import grad as grad
175175
from .transforms.autodiff import value_and_grad as value_and_grad
176176
from .transforms.autodiff import custom_vjp as custom_vjp
177+
from .transforms.autodiff import vjp as vjp
178+
from .transforms.autodiff import jvp as jvp
177179
from .transforms.autodiff import remat as remat
178180
from .transforms.compilation import jit as jit
181+
from .transforms.compilation import jit_partial as jit_partial
179182
from .transforms.compilation import shard_map as shard_map
180183
from .transforms.compilation import StateSharding as StateSharding
181184
from .transforms.iteration import Carry as Carry

flax/nnx/graph.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -769,7 +769,7 @@ def flatten( # type: ignore[invalid-annotation]
769769
the overhead of the graph protocol.
770770
"""
771771
if graph is None:
772-
graph = config.flax_nnx_graph_mode
772+
graph = config.nnx_graph_mode
773773
if ref_index is None:
774774
ref_index = RefMap()
775775
leaves: list[tp.Any] = []
@@ -2243,7 +2243,7 @@ def split( # type: ignore[invalid-annotation]
22432243
filters are passed, a single ``State`` is returned.
22442244
"""
22452245
if graph is None:
2246-
graph = config.flax_nnx_graph_mode
2246+
graph = config.nnx_graph_mode
22472247
graphdef, flat_state = flatten(node, graph=graph)
22482248
flat_states = _split_state(flat_state, filters)
22492249
states = _to_nested_state(graphdef, flat_states)
@@ -2451,7 +2451,7 @@ def state(
24512451
One or more :class:`State` mappings.
24522452
"""
24532453
if graph is None:
2454-
graph = config.flax_nnx_graph_mode
2454+
graph = config.nnx_graph_mode
24552455
_, flat_state = flatten(node, graph=graph)
24562456
state = flat_state.to_nested_state()
24572457

@@ -2492,7 +2492,7 @@ def graphdef(
24922492
The :class:`GraphDef` of the :class:`Module` object.
24932493
"""
24942494
if graph is None:
2495-
graph = config.flax_nnx_graph_mode
2495+
graph = config.nnx_graph_mode
24962496
graphdef, _ = flatten(node, graph=graph)
24972497
return graphdef
24982498

@@ -2578,7 +2578,7 @@ def pop(
25782578
return states
25792579

25802580

2581-
def clone(node: Node, variables: bool = True, *, graph: bool = True) -> Node:
2581+
def clone(node: Node, variables: bool = True, *, graph: bool | None = None) -> Node:
25822582
"""Create a deep copy of the given graph node.
25832583
25842584
Example usage::
@@ -2877,7 +2877,7 @@ def iter_graph(
28772877
the overhead of the graph protocol.
28782878
"""
28792879
if graph is None:
2880-
graph = config.flax_nnx_graph_mode
2880+
graph = config.nnx_graph_mode
28812881
if graph:
28822882
return _iter_graph(node)
28832883
else:
@@ -2987,7 +2987,7 @@ def recursive_map(
29872987
the overhead of the graph protocol.
29882988
"""
29892989
if graph is None:
2990-
graph = config.flax_nnx_graph_mode
2990+
graph = config.nnx_graph_mode
29912991
if graph:
29922992
node = clone(node, variables=False)
29932993
path_parts: PathParts = ()

flax/nnx/rnglib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -817,7 +817,7 @@ def split_rngs(
817817
818818
"""
819819
if graph is None:
820-
graph = config.flax_nnx_graph_mode
820+
graph = config.nnx_graph_mode
821821

822822
if isinstance(node, Missing):
823823

0 commit comments

Comments
 (0)