Skip to content

Commit 0a684c5

Browse files
committed
Use nnx vmap with graph True
1 parent ae78070 commit 0a684c5

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tests/nnx/prefix_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ def __init__(self, rngs: nnx.Rngs):
1414

1515
rngs = nnx.Rngs(0, dropout=jax.random.key(1))
1616
rngs = rngs.split({f'dropout': 5})
17-
model = jax.vmap(Model, in_axes=(nnx.prefix(rngs, {'dropout': 0}),))(rngs)
17+
model = nnx.vmap(Model, graph=False, graph_updates=False, in_axes=(nnx.prefix(rngs, {'dropout': 0}),))(rngs)
1818
assert model.drop.rngs.key[...].shape == (5,)
1919
assert model.drop.rngs.count[...].shape == (5,)
2020
bias = model.linear.bias[...]
2121
assert all(jnp.allclose(x,y) for (x,y) in zip(bias, bias[1:]))
2222

2323
# Problem: need nnx.vmap to work with prefix
24-
# Currently, vmap flattens to a graphdef first. We must disable this.
24+
# The out_axes argument in nnx.vmap is breaking things.

0 commit comments

Comments
 (0)