@@ -36,7 +36,7 @@ def _assert_pytree_isequal(a, b, rtol=None, atol=None):
3636 else :
3737 assert a_elem == b_elem , f"Values are different: { a_elem } != { b_elem } "
3838 except AssertionError as e :
39- failures .append (a_path , str (e ))
39+ failures .append (( a_path , str (e ) ))
4040
4141 if failures :
4242 msg = "\n " .join (f"Path: { path } , Error: { error } " for path , error in failures )
@@ -148,6 +148,39 @@ def f(x, y):
148148 _assert_pytree_isequal (vjp , vjp_raw )
149149
150150
151+ @pytest .mark .parametrize ("use_jit" , [True , False ])
152+ @pytest .mark .parametrize ("jacfun" , [jax .jacfwd , jax .jacrev ])
153+ def test_univariate_tesseract_jacobian (
154+ served_univariate_tesseract_raw , use_jit , jacfun
155+ ):
156+ rosenbrock_tess = Tesseract (served_univariate_tesseract_raw )
157+
158+ # make things callable without keyword args
159+ def f (x , y ):
160+ return apply_tesseract (rosenbrock_tess , inputs = dict (x = x , y = y ))["result" ]
161+
162+ rosenbrock_raw = rosenbrock_impl
163+ if use_jit :
164+ f = jax .jit (f )
165+ rosenbrock_raw = jax .jit (rosenbrock_raw )
166+
167+ x , y = np .array (0.0 ), np .array (0.0 )
168+ jac = jacfun (f , argnums = (0 , 1 ))(x , y )
169+
170+ # Test against Tesseract client
171+ jac_ref = rosenbrock_tess .jacobian (
172+ inputs = dict (x = x , y = y ), jac_inputs = ["x" , "y" ], jac_outputs = ["result" ]
173+ )
174+
175+ # Convert from nested dict to nested tuplw
176+ jac_ref = tuple ((jac_ref ["result" ]["x" ], jac_ref ["result" ]["y" ]))
177+ _assert_pytree_isequal (jac , jac_ref )
178+
179+ # Test against direct implementation
180+ jac_raw = jacfun (rosenbrock_raw , argnums = (0 , 1 ))(x , y )
181+ _assert_pytree_isequal (jac , jac_raw )
182+
183+
151184@pytest .mark .parametrize ("use_jit" , [True , False ])
152185def test_nested_tesseract_apply (served_nested_tesseract_raw , use_jit ):
153186 nested_tess = Tesseract (served_nested_tesseract_raw )
@@ -286,6 +319,55 @@ def f(a, v):
286319 _assert_pytree_isequal (vjp , vjp_ref )
287320
288321
322+ @pytest .mark .parametrize ("use_jit" , [True , False ])
323+ @pytest .mark .parametrize ("jacfun" , [jax .jacfwd , jax .jacrev ])
324+ def test_nested_tesseract_jacobian (served_nested_tesseract_raw , use_jit , jacfun ):
325+ nested_tess = Tesseract (served_nested_tesseract_raw )
326+ a , b = np .array (1.0 , dtype = "float32" ), np .array (2.0 , dtype = "float32" )
327+ v , w = (
328+ np .array ([1.0 , 2.0 , 3.0 ], dtype = "float32" ),
329+ np .array ([5.0 , 7.0 , 9.0 ], dtype = "float32" ),
330+ )
331+
332+ def f (a , v ):
333+ return apply_tesseract (
334+ nested_tess ,
335+ inputs = dict (
336+ scalars = {"a" : a , "b" : b },
337+ vectors = {"v" : v , "w" : w },
338+ other_stuff = {"s" : "hey!" , "i" : 1234 , "f" : 2.718 },
339+ ),
340+ )
341+
342+ if use_jit :
343+ f = jax .jit (f )
344+
345+ jac = jacfun (f , argnums = (0 , 1 ))(a , v )
346+
347+ jac_ref = nested_tess .jacobian (
348+ inputs = dict (
349+ scalars = {"a" : a , "b" : b },
350+ vectors = {"v" : v , "w" : w },
351+ other_stuff = {"s" : "hey!" , "i" : 1234 , "f" : 2.718 },
352+ ),
353+ jac_inputs = ["scalars.a" , "vectors.v" ],
354+ jac_outputs = ["scalars.a" , "vectors.v" ],
355+ )
356+ # JAX returns a 2-layered nested dict containing tuples of arrays
357+ # we need to flatten it to match the Tesseract output (2 layered nested dict of arrays)
358+ jac = {
359+ "scalars.a" : {
360+ "scalars.a" : jac ["scalars" ]["a" ][0 ],
361+ "vectors.v" : jac ["scalars" ]["a" ][1 ],
362+ },
363+ "vectors.v" : {
364+ "scalars.a" : jac ["vectors" ]["v" ][0 ],
365+ "vectors.v" : jac ["vectors" ]["v" ][1 ],
366+ },
367+ }
368+ _assert_pytree_isequal (jac , jac_ref )
369+
370+
289371@pytest .mark .parametrize ("use_jit" , [True , False ])
290372def test_partial_differentiation (served_univariate_tesseract_raw , use_jit ):
291373 """Test that differentiation works correctly in cases where some inputs are constants."""
0 commit comments