Skip to content

Commit 96e05c5

Browse files
committed
Add flag for variables in nnx.merge
1 parent d747426 commit 96e05c5

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

flax/nnx/graphlib.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,6 +1089,7 @@ def unflatten( # type: ignore[invalid-annotation]
10891089
index_ref: IndexMap | None = None,
10901090
outer_index_outer_ref: IndexMap | None = None,
10911091
copy_variables: bool = False,
1092+
auto_create_variables: bool = True,
10921093
) -> Node:
10931094
"""Unflattens a graphdef into a node with the given state.
10941095
@@ -1150,6 +1151,7 @@ def unflatten( # type: ignore[invalid-annotation]
11501151
index_ref,
11511152
outer_index_outer_ref,
11521153
copy_variables,
1154+
auto_create_variables
11531155
)
11541156

11551157
try:
@@ -1171,6 +1173,7 @@ def _graph_unflatten(
11711173
index_ref: IndexMap,
11721174
outer_index_outer_ref: IndexMap | None,
11731175
copy_variables: bool,
1176+
auto_create_variables: bool
11741177
) -> Node:
11751178
"""Recursive helper for graph_unflatten.
11761179
@@ -1265,7 +1268,7 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf):
12651268
variable.set_raw_value(value)
12661269
else: # variabledef.index not in index_ref_cache
12671270
# variable reference does not exist outside, create a new one
1268-
if isinstance(value, Variable):
1271+
if isinstance(value, Variable) or not auto_create_variables:
12691272
variable = value
12701273
else:
12711274
variable = variabledef.type.from_metadata(
@@ -1314,6 +1317,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
13141317
index_ref,
13151318
outer_index_outer_ref,
13161319
copy_variables,
1320+
auto_create_variables
13171321
)
13181322
else:
13191323
raise RuntimeError(f'Unknown node definition: {node_def!r}')
@@ -2359,6 +2363,7 @@ def merge( # type: ignore[invalid-annotation]
23592363
/,
23602364
*states: tp.Any,
23612365
copy: bool = False,
2366+
auto_create_variables: bool = True,
23622367
) -> A:
23632368
"""The inverse of :func:`flax.nnx.split`.
23642369
@@ -2410,7 +2415,7 @@ def merge( # type: ignore[invalid-annotation]
24102415
_state = state
24112416
else:
24122417
_state = _merge_to_flat_state((state, *states))
2413-
node = unflatten(graphdef, _state, copy_variables=copy)
2418+
node = unflatten(graphdef, _state, copy_variables=copy, auto_create_variables=auto_create_variables)
24142419
return node
24152420

24162421

@@ -2534,6 +2539,7 @@ def map(
25342539
/,
25352540
*,
25362541
graph: bool | None = None,
2542+
auto_create_variables: bool = True,
25372543
) -> A:
25382544
"""Map a function over the state of a graph node.
25392545
@@ -2567,7 +2573,7 @@ def map(
25672573
"""
25682574
graphdef, state = split(node, graph=graph)
25692575
state = statelib.map_state(f, state)
2570-
return merge(graphdef, state)
2576+
return merge(graphdef, state, auto_create_variables=auto_create_variables)
25712577

25722578

25732579
def graphdef(

flax/nnx/pytreelib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1061,4 +1061,4 @@ def _maybe_int(x):
10611061
return x
10621062

10631063
def _get_str(x):
1064-
return x if isinstance(x, str) else str(x)
1064+
return x if isinstance(x, str) else str(x)

tests/nnx/graph_utils_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1699,6 +1699,14 @@ def test_map_replace(self):
16991699
np.testing.assert_array_equal(new_model.kernel[...], jnp.zeros((2, 3)))
17001700
np.testing.assert_array_equal(new_model.bias[...], jnp.zeros((3,)))
17011701

1702+
def test_map_auto_create_variables_false(self):
1703+
rngs = nnx.Rngs(0)
1704+
new_rngs = nnx.map(
1705+
lambda path, x: 0, rngs, auto_create_variables=False
1706+
)
1707+
self.assertNotIsInstance(new_rngs.default.count, nnx.Variable)
1708+
self.assertEqual(new_rngs.default.count, 0)
1709+
17021710

17031711
if __name__ == '__main__':
17041712
absltest.main()

0 commit comments

Comments
 (0)