Skip to content

Commit 2196674

Browse files
committed
Add nnx.prefix
1 parent a138d9f commit 2196674

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

flax/nnx/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from .pytreelib import is_data as is_data
4444
from .pytreelib import has_data as has_data
4545
from .pytreelib import check_pytree as check_pytree
46+
from .pytreelib import prefix
4647
from .helpers import Dict as Dict
4748
from .helpers import List as List
4849
from .helpers import Sequential as Sequential

flax/nnx/pytreelib.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from flax import config
4040
from flax.nnx.variablelib import Variable
41+
from flax.nnx import filterlib
4142
from flax.typing import MISSING, Missing, SizeBytes
4243

4344
BUILDING_DOCS = 'FLAX_DOC_BUILD' in os.environ
@@ -112,6 +113,18 @@ def __init__(self):
112113
metadata['static'] = False
113114
return dataclasses.field(**kwargs, metadata=metadata) # type: ignore[return-value]
114115

116+
def prefix(pytree, filter_map):
117+
preds = [(filterlib.to_predicate(k), v) for k, v in filter_map.items()]
118+
119+
def lookup(path, value):
120+
for (pred,obj) in preds:
121+
if pred(graphlib.jax_to_nnx_path(path), value):
122+
return obj
123+
return None
124+
125+
return jax.tree.map_with_path(lookup, pytree,
126+
is_leaf=lambda p, v: isinstance(v, variablelib.Variable) or lookup(p, v) is not None,
127+
is_leaf_takes_path=True)
115128

116129
def register_data_type(type_: T, /) -> T:
117130
"""Registers a type as pytree data type recognized by Object.
@@ -1061,4 +1074,4 @@ def _maybe_int(x):
10611074
return x
10621075

10631076
def _get_str(x):
1064-
return x if isinstance(x, str) else str(x)
1077+
return x if isinstance(x, str) else str(x)

tests/nnx/rngs_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,24 @@ def test_fork_rngs(self, graph):
216216
nnx.restore_rngs(backups)
217217
self.assertNotEqual(rngs.params.key, new_key)
218218

219+
def test_prefix(self):
220+
class Model(nnx.Module):
221+
def __init__(self, rngs: nnx.Rngs):
222+
self.linear = nnx.Linear(20, 10, rngs=rngs)
223+
self.drop = nnx.Dropout(0.1, rngs=rngs)
224+
225+
with nnx.graphlib.set_graph_updates(False):
226+
with nnx.graphlib.set_graph_mode(False):
227+
rngs = nnx.Rngs(0, dropout=jax.random.key(1))
228+
rngs = rngs.split({f'dropout': 5})
229+
prefix = nnx.prefix(rngs, {'dropout': 0})
230+
model = nnx.vmap(Model, in_axes=(prefix,))(rngs)
231+
assert model.drop.rngs.key[...].shape == (5,)
232+
assert model.drop.rngs.count[...].shape == (5,)
233+
bias = model.linear.bias[...]
234+
assert all(jnp.allclose(x,y) for (x,y) in zip(bias, bias[1:]))
235+
236+
219237
def test_random_helpers(self):
220238
rngs = nnx.Rngs(0, params=1)
221239

0 commit comments

Comments
 (0)