Skip to content

Commit 894cf76

Browse files
committed
iterate over out axes when vmapped twice
1 parent 0ecf6c4 commit 894cf76

File tree

1 file changed

+16
-11
lines changed

1 file changed

+16
-11
lines changed

tests/test_endtoend.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -221,17 +221,22 @@ def f(x, y):
221221
extra_dim if ax is not None else None for ax in axes
222222
)
223223

224-
f_vmappedtwice = jax.vmap(f_vmapped, in_axes=additional_axes)
225-
raw_vmappedtwice = jax.vmap(raw_vmapped, in_axes=additional_axes)
226-
227-
if use_jit:
228-
f_vmappedtwice = jax.jit(f_vmappedtwice)
229-
raw_vmappedtwice = jax.jit(raw_vmappedtwice)
230-
231-
result = f_vmappedtwice(x, y)
232-
result_raw = raw_vmappedtwice(x, y)
233-
234-
_assert_pytree_isequal(result, result_raw)
224+
for out_axis in [0, 1, -1]:
225+
f_vmappedtwice = jax.vmap(
226+
f_vmapped, in_axes=additional_axes, out_axes=out_axis
227+
)
228+
raw_vmappedtwice = jax.vmap(
229+
raw_vmapped, in_axes=additional_axes, out_axes=out_axis
230+
)
231+
232+
if use_jit:
233+
f_vmappedtwice = jax.jit(f_vmappedtwice)
234+
raw_vmappedtwice = jax.jit(raw_vmappedtwice)
235+
236+
result = f_vmappedtwice(x, y)
237+
result_raw = raw_vmappedtwice(x, y)
238+
239+
_assert_pytree_isequal(result, result_raw)
235240

236241

237242
@pytest.mark.parametrize("use_jit", [True, False])

0 commit comments

Comments
 (0)