Skip to content

Commit 5bcb1f5

Browse files
Cristian GarciaFlax Authors
authored andcommitted
Use linen_vars_to_nnx_attrs in ToLinen variable restoration
Refactor ToLinen.__call__ to use linen_vars_to_nnx_attrs for converting Linen variables back to NNX state, replacing manual AxisMetadata unboxing and merge_state. PiperOrigin-RevId: 874740856
1 parent 3ae65af commit 5bcb1f5

File tree

1 file changed

+4
-16
lines changed

1 file changed

+4
-16
lines changed

flax/nnx/bridge/wrappers.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,8 @@
1919
import dataclasses
2020

2121
from flax import linen
22-
from flax import core
2322
from flax import nnx
2423
from flax.core import FrozenDict
25-
from flax.core import meta
2624
from flax.nnx import graph
2725
from flax.nnx import variablelib
2826
from flax.nnx.bridge import variables as bv
@@ -32,7 +30,6 @@
3230
from flax.nnx.pytreelib import Pytree
3331
from flax.nnx.rnglib import Rngs
3432
import jax
35-
from jax import tree_util as jtu
3633

3734
M = tp.TypeVar('M', bound=Module)
3835

@@ -350,19 +347,10 @@ def _module_kwargs():
350347
module = self.nnx_class(*self.args, **_module_kwargs())
351348

352349
# update nnx module from linen variables
353-
def maybe_unbox(x):
354-
if isinstance(x, meta.AxisMetadata):
355-
return x.unbox()
356-
return x
357-
states = jtu.tree_map(
358-
maybe_unbox,
359-
list(core.unfreeze(self.variables).values()), # type: ignore[wrong-arg-types, arg-type]
360-
is_leaf=lambda x: isinstance(x, meta.AxisMetadata),
361-
)
362-
if not states:
363-
states = ({},)
364-
365-
new_state = nnx.merge_state(*states)
350+
if self.variables:
351+
new_state = nnx.State(bv.linen_vars_to_nnx_attrs(self.variables))
352+
else:
353+
new_state = nnx.State({})
366354
new_state_flat = nnx.traversals.flatten_mapping(new_state)
367355
current_state_flat = nnx.traversals.flatten_mapping(nnx.state(module))
368356
unknown_state_flat = {path: v for path, v in new_state_flat.items() if path not in current_state_flat}

0 commit comments

Comments
 (0)