File tree Expand file tree Collapse file tree 1 file changed +6
-0
lines changed
Expand file tree Collapse file tree 1 file changed +6
-0
lines changed Original file line number Diff line number Diff line change @@ -221,6 +221,8 @@ class Model(nnx.Module):
221221 def __init__ (self , rngs : nnx .Rngs ):
222222 self .linear = nnx .Linear (20 , 10 , rngs = rngs )
223223 self .drop = nnx .Dropout (0.1 , rngs = rngs )
224+ def __call__ (self , x ):
225+ return self .drop (self .linear (x ))
224226
225227 with nnx .graphlib .set_graph_updates (False ):
226228 with nnx .graphlib .set_graph_mode (False ):
@@ -233,6 +235,10 @@ def __init__(self, rngs: nnx.Rngs):
233235 bias = model .linear .bias [...]
234236 assert all (jnp .allclose (x ,y ) for (x ,y ) in zip (bias , bias [1 :]))
235237
238+ # This is the same as just using 0 as the model in_axes
239+ prefix2 = nnx .prefix (model , {nnx .Variable : 0 })
240+ nnx .vmap (Model .__call__ , in_axes = (prefix2 ,None ))(model , jnp .ones (20 ))
241+
236242
237243 def test_random_helpers (self ):
238244 rngs = nnx .Rngs (0 , params = 1 )
You can’t perform that action at this time.
0 commit comments