File tree Expand file tree Collapse file tree 1 file changed +5
-0
lines changed
Expand file tree Collapse file tree 1 file changed +5
-0
lines changed Original file line number Diff line number Diff line change @@ -221,6 +221,8 @@ class Model(nnx.Module):
221221 def __init__ (self , rngs : nnx .Rngs ):
222222 self .linear = nnx .Linear (20 , 10 , rngs = rngs )
223223 self .drop = nnx .Dropout (0.1 , rngs = rngs )
224+ def __call__ (self , x ):
225+ return self .drop (self .linear (x ))
224226
225227 with nnx .graphlib .set_graph_updates (False ):
226228 with nnx .graphlib .set_graph_mode (False ):
@@ -233,6 +235,9 @@ def __init__(self, rngs: nnx.Rngs):
233235 bias = model .linear .bias [...]
234236 assert all (jnp .allclose (x ,y ) for (x ,y ) in zip (bias , bias [1 :]))
235237
238+ prefix2 = nnx .prefix (rngs , {nnx .Param : 0 })
239+ nnx .vmap (Model .__call__ , in_axes = (prefix2 ,None ))(model , jnp .ones (20 ))
240+
236241
237242 def test_random_helpers (self ):
238243 rngs = nnx .Rngs (0 , params = 1 )
You can’t perform that action at this time.
0 commit comments