Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flax/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from .pytreelib import is_data as is_data
from .pytreelib import has_data as has_data
from .pytreelib import check_pytree as check_pytree
from .pytreelib import prefix
from .helpers import Dict as Dict
from .helpers import List as List
from .helpers import Sequential as Sequential
Expand Down
15 changes: 14 additions & 1 deletion flax/nnx/pytreelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
)
from flax import config
from flax.nnx.variablelib import Variable
from flax.nnx import filterlib
from flax.typing import MISSING, Missing, SizeBytes

BUILDING_DOCS = 'FLAX_DOC_BUILD' in os.environ
Expand Down Expand Up @@ -112,6 +113,18 @@ def __init__(self):
metadata['static'] = False
return dataclasses.field(**kwargs, metadata=metadata) # type: ignore[return-value]

def prefix(pytree, filter_map):
preds = [(filterlib.to_predicate(k), v) for k, v in filter_map.items()]

def lookup(path, value):
for (pred,obj) in preds:
if pred(graphlib.jax_to_nnx_path(path), value):
return obj
return None

return jax.tree.map_with_path(lookup, pytree,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

recently added nnx.map which can be used here, else you have to convert from the jax path format to the nnx path format

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would require that the prefix filter only select Variable nodes. Which is probably what we usually want, but a little less flexible than using jax.tree.map_with_path as I have here.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get a strange error if I try to switch to nnx.map:

E     At that key path, the prefix pytree vmap out_axes has a subtree of type
E         <class 'flax.nnx.rnglib.RngKey'>
E     but at the same key path the full pytree has a subtree of different type
E         <class 'flax.nnx.extract.Mask'>.

What's nnx.extract.Mask? I guess I have some digging to do.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, I'm just converting the jax path format to the nnx one, which makes the tests pass. But I'll investigate why nnx.map produces different behavior.

is_leaf=lambda p, v: isinstance(v, variablelib.Variable) or lookup(p, v) is not None,
is_leaf_takes_path=True)

def register_data_type(type_: T, /) -> T:
"""Registers a type as pytree data type recognized by Object.
Expand Down Expand Up @@ -1061,4 +1074,4 @@ def _maybe_int(x):
return x

def _get_str(x):
return x if isinstance(x, str) else str(x)
return x if isinstance(x, str) else str(x)
24 changes: 24 additions & 0 deletions tests/nnx/rngs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,30 @@ def test_fork_rngs(self, graph):
nnx.restore_rngs(backups)
self.assertNotEqual(rngs.params.key, new_key)

def test_prefix(self):
class Model(nnx.Module):
def __init__(self, rngs: nnx.Rngs):
self.linear = nnx.Linear(20, 10, rngs=rngs)
self.drop = nnx.Dropout(0.1, rngs=rngs)
def __call__(self, x):
return self.drop(self.linear(x))

with nnx.graphlib.set_graph_updates(False):
with nnx.graphlib.set_graph_mode(False):
rngs = nnx.Rngs(0, dropout=jax.random.key(1))
rngs = rngs.split({f'dropout': 5})
prefix = nnx.prefix(rngs, {'dropout': 0})
model = nnx.vmap(Model, in_axes=(prefix,))(rngs)
assert model.drop.rngs.key[...].shape == (5,)
assert model.drop.rngs.count[...].shape == (5,)
bias = model.linear.bias[...]
assert all(jnp.allclose(x,y) for (x,y) in zip(bias, bias[1:]))

# This is the same as just using 0 as the model in_axes
prefix2 = nnx.prefix(model, {nnx.Variable: 0})
nnx.vmap(Model.__call__, in_axes=(prefix2,None))(model, jnp.ones(20))


def test_random_helpers(self):
rngs = nnx.Rngs(0, params=1)

Expand Down
Loading