@@ -221,17 +221,22 @@ def f(x, y):
221221 extra_dim if ax is not None else None for ax in axes
222222 )
223223
224- f_vmappedtwice = jax .vmap (f_vmapped , in_axes = additional_axes )
225- raw_vmappedtwice = jax .vmap (raw_vmapped , in_axes = additional_axes )
226-
227- if use_jit :
228- f_vmappedtwice = jax .jit (f_vmappedtwice )
229- raw_vmappedtwice = jax .jit (raw_vmappedtwice )
230-
231- result = f_vmappedtwice (x , y )
232- result_raw = raw_vmappedtwice (x , y )
233-
234- _assert_pytree_isequal (result , result_raw )
224+ for out_axis in [0 , 1 , - 1 ]:
225+ f_vmappedtwice = jax .vmap (
226+ f_vmapped , in_axes = additional_axes , out_axes = out_axis
227+ )
228+ raw_vmappedtwice = jax .vmap (
229+ raw_vmapped , in_axes = additional_axes , out_axes = out_axis
230+ )
231+
232+ if use_jit :
233+ f_vmappedtwice = jax .jit (f_vmappedtwice )
234+ raw_vmappedtwice = jax .jit (raw_vmappedtwice )
235+
236+ result = f_vmappedtwice (x , y )
237+ result_raw = raw_vmappedtwice (x , y )
238+
239+ _assert_pytree_isequal (result , result_raw )
235240
236241
237242@pytest .mark .parametrize ("use_jit" , [True , False ])
0 commit comments