Skip to content

Commit d71ee38

Browse files
committed
Add prefix test when calling
1 parent 2196674 commit d71ee38

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

tests/nnx/rngs_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,8 @@ class Model(nnx.Module):
221221
def __init__(self, rngs: nnx.Rngs):
222222
self.linear = nnx.Linear(20, 10, rngs=rngs)
223223
self.drop = nnx.Dropout(0.1, rngs=rngs)
224+
def __call__(self, x):
225+
return self.drop(self.linear(x))
224226

225227
with nnx.graphlib.set_graph_updates(False):
226228
with nnx.graphlib.set_graph_mode(False):
@@ -233,6 +235,10 @@ def __init__(self, rngs: nnx.Rngs):
233235
bias = model.linear.bias[...]
234236
assert all(jnp.allclose(x,y) for (x,y) in zip(bias, bias[1:]))
235237

238+
# This is the same as just using 0 as the model in_axes
239+
prefix2 = nnx.prefix(model, {nnx.Variable: 0})
240+
nnx.vmap(Model.__call__, in_axes=(prefix2,None))(model, jnp.ones(20))
241+
236242

237243
def test_random_helpers(self):
238244
rngs = nnx.Rngs(0, params=1)

0 commit comments

Comments
 (0)