Skip to content

Commit 0ecf6c4

Browse files
committed
add vmap test for univariate tesseract
1 parent 070c72e commit 0ecf6c4

File tree

1 file changed

+62
-6
lines changed

1 file changed

+62
-6
lines changed

tests/test_endtoend.py

Lines changed: 62 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Copyright 2025 Pasteur Labs. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from functools import partial
5+
46
import jax
57
import numpy as np
68
import pytest
@@ -149,23 +151,26 @@ def f(x, y):
149151

150152

151153
@pytest.mark.parametrize("use_jit", [True, False])
152-
@pytest.mark.parametrize("jacfun", [jax.jacfwd, jax.jacrev])
154+
@pytest.mark.parametrize(
155+
"jacfun", [partial(jax.jacfwd, argnums=(0, 1)), partial(jax.jacrev, argnums=(0, 1))]
156+
)
153157
def test_univariate_tesseract_jacobian(
154158
served_univariate_tesseract_raw, use_jit, jacfun
155159
):
156160
rosenbrock_tess = Tesseract(served_univariate_tesseract_raw)
157161

158162
# make things callable without keyword args
163+
@jacfun
159164
def f(x, y):
160165
return apply_tesseract(rosenbrock_tess, inputs=dict(x=x, y=y))["result"]
161166

162-
rosenbrock_raw = rosenbrock_impl
167+
rosenbrock_raw = jacfun(rosenbrock_impl)
163168
if use_jit:
164169
f = jax.jit(f)
165170
rosenbrock_raw = jax.jit(rosenbrock_raw)
166171

167172
x, y = np.array(0.0), np.array(0.0)
168-
jac = jacfun(f, argnums=(0, 1))(x, y)
173+
jac = f(x, y)
169174

170175
# Test against Tesseract client
171176
jac_ref = rosenbrock_tess.jacobian(
@@ -177,10 +182,58 @@ def f(x, y):
177182
_assert_pytree_isequal(jac, jac_ref)
178183

179184
# Test against direct implementation
180-
jac_raw = jacfun(rosenbrock_raw, argnums=(0, 1))(x, y)
185+
jac_raw = rosenbrock_raw(x, y)
181186
_assert_pytree_isequal(jac, jac_raw)
182187

183188

189+
@pytest.mark.parametrize("use_jit", [True, False])
190+
def test_univariate_tesseract_vmap(served_univariate_tesseract_raw, use_jit):
191+
rosenbrock_tess = Tesseract(served_univariate_tesseract_raw)
192+
193+
# make things callable without keyword args
194+
def f(x, y):
195+
return apply_tesseract(rosenbrock_tess, inputs=dict(x=x, y=y))["result"]
196+
197+
# add one batch dimension
198+
for axes in [(0, 0), (0, None), (None, 0)]:
199+
x = np.arange(3) if axes[0] is not None else np.array(0.0)
200+
y = np.arange(3) if axes[1] is not None else np.array(0.0)
201+
f_vmapped = jax.vmap(f, in_axes=axes)
202+
raw_vmapped = jax.vmap(rosenbrock_impl, in_axes=axes)
203+
204+
if use_jit:
205+
f_vmapped = jax.jit(f_vmapped)
206+
raw_vmapped = jax.jit(raw_vmapped)
207+
208+
result = f_vmapped(x, y)
209+
result_raw = raw_vmapped(x, y)
210+
211+
_assert_pytree_isequal(result, result_raw)
212+
213+
# add an additional batch dimension
214+
for extra_dim in [0, 1, -1]:
215+
if axes[0] is not None:
216+
x = np.arange(6).reshape(2, 3)
217+
if axes[1] is not None:
218+
y = np.arange(6).reshape(2, 3)
219+
220+
additional_axes = tuple(
221+
extra_dim if ax is not None else None for ax in axes
222+
)
223+
224+
f_vmappedtwice = jax.vmap(f_vmapped, in_axes=additional_axes)
225+
raw_vmappedtwice = jax.vmap(raw_vmapped, in_axes=additional_axes)
226+
227+
if use_jit:
228+
f_vmappedtwice = jax.jit(f_vmappedtwice)
229+
raw_vmappedtwice = jax.jit(raw_vmappedtwice)
230+
231+
result = f_vmappedtwice(x, y)
232+
result_raw = raw_vmappedtwice(x, y)
233+
234+
_assert_pytree_isequal(result, result_raw)
235+
236+
184237
@pytest.mark.parametrize("use_jit", [True, False])
185238
def test_nested_tesseract_apply(served_nested_tesseract_raw, use_jit):
186239
nested_tess = Tesseract(served_nested_tesseract_raw)
@@ -320,7 +373,9 @@ def f(a, v):
320373

321374

322375
@pytest.mark.parametrize("use_jit", [True, False])
323-
@pytest.mark.parametrize("jacfun", [jax.jacfwd, jax.jacrev])
376+
@pytest.mark.parametrize(
377+
"jacfun", [partial(jax.jacfwd, argnums=(0, 1)), partial(jax.jacrev, argnums=(0, 1))]
378+
)
324379
def test_nested_tesseract_jacobian(served_nested_tesseract_raw, use_jit, jacfun):
325380
nested_tess = Tesseract(served_nested_tesseract_raw)
326381
a, b = np.array(1.0, dtype="float32"), np.array(2.0, dtype="float32")
@@ -329,6 +384,7 @@ def test_nested_tesseract_jacobian(served_nested_tesseract_raw, use_jit, jacfun)
329384
np.array([5.0, 7.0, 9.0], dtype="float32"),
330385
)
331386

387+
@jacfun
332388
def f(a, v):
333389
return apply_tesseract(
334390
nested_tess,
@@ -342,7 +398,7 @@ def f(a, v):
342398
if use_jit:
343399
f = jax.jit(f)
344400

345-
jac = jacfun(f, argnums=(0, 1))(a, v)
401+
jac = f(a, v)
346402

347403
jac_ref = nested_tess.jacobian(
348404
inputs=dict(

0 commit comments

Comments
 (0)