@@ -1089,6 +1089,7 @@ def unflatten( # type: ignore[invalid-annotation]
10891089 index_ref : IndexMap | None = None ,
10901090 outer_index_outer_ref : IndexMap | None = None ,
10911091 copy_variables : bool = False ,
1092+ auto_create_variables : bool = True ,
10921093) -> Node :
10931094 """Unflattens a graphdef into a node with the given state.
10941095
@@ -1150,6 +1151,7 @@ def unflatten( # type: ignore[invalid-annotation]
11501151 index_ref ,
11511152 outer_index_outer_ref ,
11521153 copy_variables ,
1154+ auto_create_variables
11531155 )
11541156
11551157 try :
@@ -1171,6 +1173,7 @@ def _graph_unflatten(
11711173 index_ref : IndexMap ,
11721174 outer_index_outer_ref : IndexMap | None ,
11731175 copy_variables : bool ,
1176+ auto_create_variables : bool
11741177) -> Node :
11751178 """Recursive helper for graph_unflatten.
11761179
@@ -1265,7 +1268,7 @@ def get_mutable_array(array_refdef: ArrayRefDef, leaf):
12651268 variable .set_raw_value (value )
12661269 else : # variabledef.index not in index_ref_cache
12671270 # variable reference does not exist outside, create a new one
1268- if isinstance (value , Variable ):
1271+ if isinstance (value , Variable ) or not auto_create_variables :
12691272 variable = value
12701273 else :
12711274 variable = variabledef .type .from_metadata (
@@ -1314,6 +1317,7 @@ def _get_children() -> list[tuple[Key, tp.Any]]:
13141317 index_ref ,
13151318 outer_index_outer_ref ,
13161319 copy_variables ,
1320+ auto_create_variables
13171321 )
13181322 else :
13191323 raise RuntimeError (f'Unknown node definition: { node_def !r} ' )
@@ -2359,6 +2363,7 @@ def merge( # type: ignore[invalid-annotation]
23592363 / ,
23602364 * states : tp .Any ,
23612365 copy : bool = False ,
2366+ auto_create_variables : bool = True ,
23622367) -> A :
23632368 """The inverse of :func:`flax.nnx.split`.
23642369
@@ -2410,7 +2415,7 @@ def merge( # type: ignore[invalid-annotation]
24102415 _state = state
24112416 else :
24122417 _state = _merge_to_flat_state ((state , * states ))
2413- node = unflatten (graphdef , _state , copy_variables = copy )
2418+ node = unflatten (graphdef , _state , copy_variables = copy , auto_create_variables = auto_create_variables )
24142419 return node
24152420
24162421
@@ -2534,6 +2539,7 @@ def map(
25342539 / ,
25352540 * ,
25362541 graph : bool | None = None ,
2542+ auto_create_variables : bool = True ,
25372543) -> A :
25382544 """Map a function over the state of a graph node.
25392545
@@ -2567,7 +2573,7 @@ def map(
25672573 """
25682574 graphdef , state = split (node , graph = graph )
25692575 state = statelib .map_state (f , state )
2570- return merge (graphdef , state )
2576+ return merge (graphdef , state , auto_create_variables = auto_create_variables )
25712577
25722578
25732579def graphdef (
0 commit comments