diff --git a/flax/nnx/__init__.py b/flax/nnx/__init__.py
index ac0fd2391..d45433b8e 100644
--- a/flax/nnx/__init__.py
+++ b/flax/nnx/__init__.py
@@ -43,6 +43,7 @@
 from .pytreelib import is_data as is_data
 from .pytreelib import has_data as has_data
 from .pytreelib import check_pytree as check_pytree
+from .pytreelib import prefix as prefix
 from .helpers import Dict as Dict
 from .helpers import List as List
 from .helpers import Sequential as Sequential
diff --git a/flax/nnx/pytreelib.py b/flax/nnx/pytreelib.py
index c53a71314..64756f178 100644
--- a/flax/nnx/pytreelib.py
+++ b/flax/nnx/pytreelib.py
@@ -38,6 +38,7 @@
 )
 from flax import config
 from flax.nnx.variablelib import Variable
+from flax.nnx import filterlib
 from flax.typing import MISSING, Missing, SizeBytes
 
 BUILDING_DOCS = 'FLAX_DOC_BUILD' in os.environ
@@ -112,6 +113,39 @@ def __init__(self):
   metadata['static'] = False
   return dataclasses.field(**kwargs, metadata=metadata)  # type: ignore[return-value]
 
+def prefix(pytree, filter_map):
+  """Maps filters to values over ``pytree``, e.g. to build transform prefixes.
+
+  Useful for constructing ``in_axes``-style prefix pytrees for ``nnx.vmap``
+  and friends.
+
+  Args:
+    pytree: the pytree whose structure is mirrored.
+    filter_map: a mapping from ``filterlib`` filters to the value placed at
+      every leaf (or whole subtree) that matches the filter. Leaves matching
+      no filter map to ``None``.
+
+  Returns:
+    A pytree of the same structure whose leaves are values from
+    ``filter_map`` (or ``None`` where nothing matched).
+  """
+  preds = [(filterlib.to_predicate(k), v) for k, v in filter_map.items()]
+
+  def lookup(path, value):
+    # First matching filter wins; order follows filter_map insertion order.
+    for pred, obj in preds:
+      if pred(graphlib.jax_to_nnx_path(path), value):
+        return obj
+    return None
+
+  return jax.tree.map_with_path(
+    lookup,
+    pytree,
+    # Treat Variables, and any subtree matched as a whole by some filter, as
+    # leaves so the mapped value replaces the entire subtree.
+    is_leaf=lambda p, v: isinstance(v, Variable) or lookup(p, v) is not None,
+    is_leaf_takes_path=True,
+  )
 
 def register_data_type(type_: T, /) -> T:
   """Registers a type as pytree data type recognized by Object.
@@ -1061,4 +1074,4 @@ def _maybe_int(x):
   return x
 
 def _get_str(x):
-  return x if isinstance(x, str) else str(x)
\ No newline at end of file
+  return x if isinstance(x, str) else str(x)
diff --git a/tests/nnx/rngs_test.py b/tests/nnx/rngs_test.py
index 9f01582bd..9500fc081 100644
--- a/tests/nnx/rngs_test.py
+++ b/tests/nnx/rngs_test.py
@@ -216,6 +216,33 @@ def test_fork_rngs(self, graph):
     nnx.restore_rngs(backups)
     self.assertNotEqual(rngs.params.key, new_key)
 
+  def test_prefix(self):
+    class Model(nnx.Module):
+      def __init__(self, rngs: nnx.Rngs):
+        self.linear = nnx.Linear(20, 10, rngs=rngs)
+        self.drop = nnx.Dropout(0.1, rngs=rngs)
+
+      def __call__(self, x):
+        return self.drop(self.linear(x))
+
+    with nnx.graphlib.set_graph_updates(False):
+      with nnx.graphlib.set_graph_mode(False):
+        rngs = nnx.Rngs(0, dropout=jax.random.key(1))
+        rngs = rngs.split({'dropout': 5})
+        prefix = nnx.prefix(rngs, {'dropout': 0})
+        model = nnx.vmap(Model, in_axes=(prefix,))(rngs)
+        # Only the dropout stream was split, so only it carries a vmap axis.
+        assert model.drop.rngs.key[...].shape == (5,)
+        assert model.drop.rngs.count[...].shape == (5,)
+        # Params were broadcast: every vmapped copy holds identical values.
+        bias = model.linear.bias[...]
+        assert all(jnp.allclose(x, y) for x, y in zip(bias, bias[1:]))
+
+        # This is the same as just using 0 as the model in_axes.
+        prefix2 = nnx.prefix(model, {nnx.Variable: 0})
+        nnx.vmap(Model.__call__, in_axes=(prefix2, None))(model, jnp.ones(20))
+
+
   def test_random_helpers(self):
     rngs = nnx.Rngs(0, params=1)