11# Copyright 2025 Pasteur Labs. All Rights Reserved.
22# SPDX-License-Identifier: Apache-2.0
33
4+ from functools import partial
5+
46import jax
57import numpy as np
68import pytest
@@ -149,23 +151,26 @@ def f(x, y):
149151
150152
151153@pytest .mark .parametrize ("use_jit" , [True , False ])
152- @pytest .mark .parametrize ("jacfun" , [jax .jacfwd , jax .jacrev ])
154+ @pytest .mark .parametrize (
155+ "jacfun" , [partial (jax .jacfwd , argnums = (0 , 1 )), partial (jax .jacrev , argnums = (0 , 1 ))]
156+ )
153157def test_univariate_tesseract_jacobian (
154158 served_univariate_tesseract_raw , use_jit , jacfun
155159):
156160 rosenbrock_tess = Tesseract (served_univariate_tesseract_raw )
157161
158162 # make things callable without keyword args
163+ @jacfun
159164 def f (x , y ):
160165 return apply_tesseract (rosenbrock_tess , inputs = dict (x = x , y = y ))["result" ]
161166
162- rosenbrock_raw = rosenbrock_impl
167+ rosenbrock_raw = jacfun ( rosenbrock_impl )
163168 if use_jit :
164169 f = jax .jit (f )
165170 rosenbrock_raw = jax .jit (rosenbrock_raw )
166171
167172 x , y = np .array (0.0 ), np .array (0.0 )
168- jac = jacfun ( f , argnums = ( 0 , 1 )) (x , y )
173+ jac = f (x , y )
169174
170175 # Test against Tesseract client
171176 jac_ref = rosenbrock_tess .jacobian (
@@ -177,10 +182,58 @@ def f(x, y):
177182 _assert_pytree_isequal (jac , jac_ref )
178183
179184 # Test against direct implementation
180- jac_raw = jacfun ( rosenbrock_raw , argnums = ( 0 , 1 )) (x , y )
185+ jac_raw = rosenbrock_raw (x , y )
181186 _assert_pytree_isequal (jac , jac_raw )
182187
183188
189+ @pytest .mark .parametrize ("use_jit" , [True , False ])
190+ def test_univariate_tesseract_vmap (served_univariate_tesseract_raw , use_jit ):
191+ rosenbrock_tess = Tesseract (served_univariate_tesseract_raw )
192+
193+ # make things callable without keyword args
194+ def f (x , y ):
195+ return apply_tesseract (rosenbrock_tess , inputs = dict (x = x , y = y ))["result" ]
196+
197+ # add one batch dimension
198+ for axes in [(0 , 0 ), (0 , None ), (None , 0 )]:
199+ x = np .arange (3 ) if axes [0 ] is not None else np .array (0.0 )
200+ y = np .arange (3 ) if axes [1 ] is not None else np .array (0.0 )
201+ f_vmapped = jax .vmap (f , in_axes = axes )
202+ raw_vmapped = jax .vmap (rosenbrock_impl , in_axes = axes )
203+
204+ if use_jit :
205+ f_vmapped = jax .jit (f_vmapped )
206+ raw_vmapped = jax .jit (raw_vmapped )
207+
208+ result = f_vmapped (x , y )
209+ result_raw = raw_vmapped (x , y )
210+
211+ _assert_pytree_isequal (result , result_raw )
212+
213+ # add an additional batch dimension
214+ for extra_dim in [0 , 1 , - 1 ]:
215+ if axes [0 ] is not None :
216+ x = np .arange (6 ).reshape (2 , 3 )
217+ if axes [1 ] is not None :
218+ y = np .arange (6 ).reshape (2 , 3 )
219+
220+ additional_axes = tuple (
221+ extra_dim if ax is not None else None for ax in axes
222+ )
223+
224+ f_vmappedtwice = jax .vmap (f_vmapped , in_axes = additional_axes )
225+ raw_vmappedtwice = jax .vmap (raw_vmapped , in_axes = additional_axes )
226+
227+ if use_jit :
228+ f_vmappedtwice = jax .jit (f_vmappedtwice )
229+ raw_vmappedtwice = jax .jit (raw_vmappedtwice )
230+
231+ result = f_vmappedtwice (x , y )
232+ result_raw = raw_vmappedtwice (x , y )
233+
234+ _assert_pytree_isequal (result , result_raw )
235+
236+
184237@pytest .mark .parametrize ("use_jit" , [True , False ])
185238def test_nested_tesseract_apply (served_nested_tesseract_raw , use_jit ):
186239 nested_tess = Tesseract (served_nested_tesseract_raw )
@@ -320,7 +373,9 @@ def f(a, v):
320373
321374
322375@pytest .mark .parametrize ("use_jit" , [True , False ])
323- @pytest .mark .parametrize ("jacfun" , [jax .jacfwd , jax .jacrev ])
376+ @pytest .mark .parametrize (
377+ "jacfun" , [partial (jax .jacfwd , argnums = (0 , 1 )), partial (jax .jacrev , argnums = (0 , 1 ))]
378+ )
324379def test_nested_tesseract_jacobian (served_nested_tesseract_raw , use_jit , jacfun ):
325380 nested_tess = Tesseract (served_nested_tesseract_raw )
326381 a , b = np .array (1.0 , dtype = "float32" ), np .array (2.0 , dtype = "float32" )
@@ -329,6 +384,7 @@ def test_nested_tesseract_jacobian(served_nested_tesseract_raw, use_jit, jacfun)
329384 np .array ([5.0 , 7.0 , 9.0 ], dtype = "float32" ),
330385 )
331386
387+ @jacfun
332388 def f (a , v ):
333389 return apply_tesseract (
334390 nested_tess ,
@@ -342,7 +398,7 @@ def f(a, v):
342398 if use_jit :
343399 f = jax .jit (f )
344400
345- jac = jacfun ( f , argnums = ( 0 , 1 )) (a , v )
401+ jac = f (a , v )
346402
347403 jac_ref = nested_tess .jacobian (
348404 inputs = dict (
0 commit comments