Skip to content

Commit 47efbc6

Browse files
committed
added jacobian tests
1 parent d9c6786 commit 47efbc6

File tree

2 files changed

+94
-1
lines changed

2 files changed

+94
-1
lines changed

tests/nested_tesseract/tesseract_api.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,17 @@ def vector_jacobian_product(
8181
return out
8282

8383

84+
def jacobian(inputs: InputSchema, jac_inputs: set[str], jac_outputs: set[str]):
85+
jac = {dy: {dx: [0.0, 0.0, 0.0] for dx in jac_inputs} for dy in jac_outputs}
86+
87+
if "scalars.a" in jac_inputs and "scalars.a" in jac_outputs:
88+
jac["scalars.a"]["scalars.a"] = 10.0
89+
if "vectors.v" in jac_inputs and "vectors.v" in jac_outputs:
90+
jac["vectors.v"]["vectors.v"] = [[10.0, 0, 0], [0, 10.0, 0], [0, 0, 10.0]]
91+
92+
return jac
93+
94+
8495
def abstract_eval(abstract_inputs):
8596
"""Calculate output shape of apply from the shape of its inputs."""
8697
return {

tests/test_endtoend.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def _assert_pytree_isequal(a, b, rtol=None, atol=None):
3636
else:
3737
assert a_elem == b_elem, f"Values are different: {a_elem} != {b_elem}"
3838
except AssertionError as e:
39-
failures.append(a_path, str(e))
39+
failures.append((a_path, str(e)))
4040

4141
if failures:
4242
msg = "\n".join(f"Path: {path}, Error: {error}" for path, error in failures)
@@ -148,6 +148,39 @@ def f(x, y):
148148
_assert_pytree_isequal(vjp, vjp_raw)
149149

150150

151+
@pytest.mark.parametrize("use_jit", [True, False])
152+
@pytest.mark.parametrize("jacfun", [jax.jacfwd, jax.jacrev])
153+
def test_univariate_tesseract_jacobian(
154+
served_univariate_tesseract_raw, use_jit, jacfun
155+
):
156+
rosenbrock_tess = Tesseract(served_univariate_tesseract_raw)
157+
158+
# make things callable without keyword args
159+
def f(x, y):
160+
return apply_tesseract(rosenbrock_tess, inputs=dict(x=x, y=y))["result"]
161+
162+
rosenbrock_raw = rosenbrock_impl
163+
if use_jit:
164+
f = jax.jit(f)
165+
rosenbrock_raw = jax.jit(rosenbrock_raw)
166+
167+
x, y = np.array(0.0), np.array(0.0)
168+
jac = jacfun(f, argnums=(0, 1))(x, y)
169+
170+
# Test against Tesseract client
171+
jac_ref = rosenbrock_tess.jacobian(
172+
inputs=dict(x=x, y=y), jac_inputs=["x", "y"], jac_outputs=["result"]
173+
)
174+
175+
# Convert from nested dict to nested tuplw
176+
jac_ref = tuple((jac_ref["result"]["x"], jac_ref["result"]["y"]))
177+
_assert_pytree_isequal(jac, jac_ref)
178+
179+
# Test against direct implementation
180+
jac_raw = jacfun(rosenbrock_raw, argnums=(0, 1))(x, y)
181+
_assert_pytree_isequal(jac, jac_raw)
182+
183+
151184
@pytest.mark.parametrize("use_jit", [True, False])
152185
def test_nested_tesseract_apply(served_nested_tesseract_raw, use_jit):
153186
nested_tess = Tesseract(served_nested_tesseract_raw)
@@ -286,6 +319,55 @@ def f(a, v):
286319
_assert_pytree_isequal(vjp, vjp_ref)
287320

288321

322+
@pytest.mark.parametrize("use_jit", [True, False])
323+
@pytest.mark.parametrize("jacfun", [jax.jacfwd, jax.jacrev])
324+
def test_nested_tesseract_jacobian(served_nested_tesseract_raw, use_jit, jacfun):
325+
nested_tess = Tesseract(served_nested_tesseract_raw)
326+
a, b = np.array(1.0, dtype="float32"), np.array(2.0, dtype="float32")
327+
v, w = (
328+
np.array([1.0, 2.0, 3.0], dtype="float32"),
329+
np.array([5.0, 7.0, 9.0], dtype="float32"),
330+
)
331+
332+
def f(a, v):
333+
return apply_tesseract(
334+
nested_tess,
335+
inputs=dict(
336+
scalars={"a": a, "b": b},
337+
vectors={"v": v, "w": w},
338+
other_stuff={"s": "hey!", "i": 1234, "f": 2.718},
339+
),
340+
)
341+
342+
if use_jit:
343+
f = jax.jit(f)
344+
345+
jac = jacfun(f, argnums=(0, 1))(a, v)
346+
347+
jac_ref = nested_tess.jacobian(
348+
inputs=dict(
349+
scalars={"a": a, "b": b},
350+
vectors={"v": v, "w": w},
351+
other_stuff={"s": "hey!", "i": 1234, "f": 2.718},
352+
),
353+
jac_inputs=["scalars.a", "vectors.v"],
354+
jac_outputs=["scalars.a", "vectors.v"],
355+
)
356+
# JAX returns a 2-layered nested dict containing tuples of arrays
357+
# we need to flatten it to match the Tesseract output (2 layered nested dict of arrays)
358+
jac = {
359+
"scalars.a": {
360+
"scalars.a": jac["scalars"]["a"][0],
361+
"vectors.v": jac["scalars"]["a"][1],
362+
},
363+
"vectors.v": {
364+
"scalars.a": jac["vectors"]["v"][0],
365+
"vectors.v": jac["vectors"]["v"][1],
366+
},
367+
}
368+
_assert_pytree_isequal(jac, jac_ref)
369+
370+
289371
@pytest.mark.parametrize("use_jit", [True, False])
290372
def test_partial_differentiation(served_univariate_tesseract_raw, use_jit):
291373
"""Test that differentiation works correctly in cases where some inputs are constants."""

0 commit comments

Comments
 (0)