Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 55 additions & 18 deletions tesseract_jax/primitive.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,20 @@

import functools
from collections.abc import Sequence
from typing import Any
from typing import Any, TypeVar

import jax.tree
import numpy as np
from jax import ShapeDtypeStruct, dtypes, extend
from jax.core import ShapedArray
from jax.interpreters import ad, mlir, xla
from jax.interpreters import ad, batching, mlir, xla
from jax.tree_util import PyTreeDef
from jax.typing import ArrayLike
from tesseract_core import Tesseract

from tesseract_jax.tesseract_compat import Jaxeract
from tesseract_jax.tesseract_compat import Jaxeract, combine_args

T = TypeVar("T")

# JAX primitive registered under the name "tesseract_dispatch".
tesseract_dispatch_p = extend.core.Primitive("tesseract_dispatch")
# Binding this primitive yields a sequence of outputs rather than a single value.
tesseract_dispatch_p.multiple_results = True
Expand All @@ -35,21 +37,13 @@ def __hash__(self) -> int:


def split_args(
flat_args: Sequence[Any], is_static_mask: Sequence[bool]
) -> tuple[tuple[ArrayLike, ...], tuple[_Hashable, ...]]:
"""Split a flat argument list into a tuple (array_args, static_args)."""
static_args = tuple(
_make_hashable(arg)
for arg, is_static in zip(flat_args, is_static_mask, strict=True)
if is_static
)
array_args = tuple(
arg
for arg, is_static in zip(flat_args, is_static_mask, strict=True) # fmt: skip
if not is_static
)

return array_args, static_args
flat_args: Sequence[T], mask: Sequence[bool]
) -> tuple[tuple[T, ...], tuple[T, ...]]:
"""Split a flat argument tuple according to mask (mask_False, mask_True)."""
lists = ([], [])
for a, m in zip(flat_args, mask, strict=True):
lists[m].append(a)
return tuple(tuple(args) for args in lists)


@tesseract_dispatch_p.def_abstract_eval
Expand Down Expand Up @@ -238,6 +232,48 @@ def _dispatch(*args: ArrayLike) -> Any:
mlir.register_lowering(tesseract_dispatch_p, tesseract_dispatch_lowering)


def tesseract_dispatch_batching(
    array_args: ArrayLike | ShapedArray | Any,
    axes: Sequence[Any],
    *,
    static_args: tuple[_Hashable, ...],
    input_pytreedef: PyTreeDef,
    output_pytreedef: PyTreeDef,
    output_avals: tuple[ShapeDtypeStruct, ...],
    is_static_mask: tuple[bool, ...],
    client: Jaxeract,
    eval_func: str,
) -> Any:
    """Batching rule for ``tesseract_dispatch_p`` (used by e.g. ``jax.vmap``).

    Each mapped operand has its batch axis moved to position 0. The primitive
    is then bound through ``jax.lax.map`` over the mapped operands, while
    unmapped operands are passed through unchanged on every call.

    Args:
        array_args: Operands of the primitive, one per entry in ``axes``.
        axes: Batch axis of each operand, or ``batching.not_mapped``.
        static_args, input_pytreedef, output_pytreedef, output_avals,
        is_static_mask, client, eval_func: Primitive parameters, forwarded
            unchanged to ``tesseract_dispatch_p.bind``.

    Returns:
        A pair ``(outputs, out_axes)``; every output is batched along axis 0.
    """
    # Normalise the layout: mapped operands get their batch axis in front.
    leading_axis_args = []
    for operand, ax in zip(array_args, axes, strict=True):
        if ax is not batching.not_mapped:
            operand = batching.moveaxis(operand, ax, 0)
        leading_axis_args.append(operand)

    is_batched_mask = [ax is not batching.not_mapped for ax in axes]
    unbatched_args, batched_args = split_args(leading_axis_args, is_batched_mask)

    # Parameters are identical for every call issued by the map below.
    bind_params = dict(
        static_args=static_args,
        input_pytreedef=input_pytreedef,
        output_pytreedef=output_pytreedef,
        output_avals=output_avals,
        is_static_mask=is_static_mask,
        client=client,
        eval_func=eval_func,
    )

    def _bind_single(batch_slice: tuple):
        # Re-interleave the per-call slices with the unmapped operands.
        merged = combine_args(unbatched_args, batch_slice, is_batched_mask)
        return tesseract_dispatch_p.bind(*merged, **bind_params)

    outvals = jax.lax.map(_bind_single, batched_args)
    out_axes = (0,) * len(outvals)

    return tuple(outvals), out_axes


# Install the batching rule for the primitive in JAX's batcher registry.
batching.primitive_batchers[tesseract_dispatch_p] = tesseract_dispatch_batching


def _check_dtype(dtype: Any) -> None:
dt = np.dtype(dtype)
if dtypes.canonicalize_dtype(dt) != dt:
Expand Down Expand Up @@ -318,6 +354,7 @@ def apply_tesseract(
flat_args, input_pytreedef = jax.tree.flatten(inputs)
is_static_mask = tuple(_is_static(arg) for arg in flat_args)
array_args, static_args = split_args(flat_args, is_static_mask)
static_args = tuple(_make_hashable(arg) for arg in static_args)

# Get abstract values for outputs, so we can unflatten them later
output_pytreedef, avals = None, None
Expand Down
32 changes: 16 additions & 16 deletions tesseract_jax/tesseract_compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from collections.abc import Sequence
from typing import Any, TypeAlias

import jax.tree
Expand All @@ -12,6 +13,14 @@
PyTree: TypeAlias = Any


def combine_args(args0: Sequence, args1: Sequence, mask: Sequence[bool]) -> tuple:
    """Merge the elements of two sequences based on a mask.

    The result has one element per mask entry: a truthy entry takes the next
    unused element of ``args1``, a falsy entry the next unused element of
    ``args0``. Relative order within each input sequence is preserved.

    Args:
        args0: Elements placed at positions where the mask is falsy.
        args1: Elements placed at positions where the mask is truthy.
        mask: Selector with ``len(args0) + len(args1)`` entries.

    Returns:
        The merged elements as a tuple.

    Raises:
        ValueError: If the number of truthy / falsy mask entries does not
            match ``len(args1)`` / ``len(args0)``.
    """
    # Count by truthiness so validation agrees with the selection below.
    n_true = sum(1 for m in mask if m)
    n_false = len(mask) - n_true
    # Explicit check instead of `assert`: asserts vanish under `python -O`,
    # after which a mismatch would raise StopIteration or drop arguments.
    if len(args1) != n_true or len(args0) != n_false:
        raise ValueError(
            f"Mask selects {n_false} element(s) from args0 and {n_true} from "
            f"args1, but got {len(args0)} and {len(args1)}."
        )

    args0_iter, args1_iter = iter(args0), iter(args1)
    combined_args = [next(args1_iter) if m else next(args0_iter) for m in mask]
    return tuple(combined_args)


def unflatten_args(
array_args: tuple[ArrayLike, ...],
static_args: tuple[Any, ...],
Expand All @@ -20,23 +29,14 @@ def unflatten_args(
remove_static_args: bool = False,
) -> PyTree:
"""Unflatten lists of arguments (static and not) into a pytree."""
combined_args = []
static_iter = iter(static_args)
array_iter = iter(array_args)

for is_static in is_static_mask:
if is_static:
elem = next(static_iter)
elem = elem.wrapped if hasattr(elem, "wrapped") else elem

if remove_static_args:
combined_args.append(None)
else:
combined_args.append(elem)

else:
combined_args.append(next(array_iter))
if remove_static_args:
static_args_converted = [None] * len(static_args)
else:
static_args_converted = [
elem.wrapped if hasattr(elem, "wrapped") else elem for elem in static_args
]

combined_args = combine_args(array_args, static_args_converted, is_static_mask)
result = jax.tree.unflatten(input_pytreedef, combined_args)

if remove_static_args:
Expand Down
11 changes: 11 additions & 0 deletions tests/nested_tesseract/tesseract_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,17 @@ def vector_jacobian_product(
return out


def jacobian(inputs: InputSchema, jac_inputs: set[str], jac_outputs: set[str]):
    """Return a fixed Jacobian as a nested dict ``jac[output][input]``.

    Every requested (output, input) pair defaults to a length-3 list of
    zeros. The two diagonal entries are then overridden when requested:
    ``scalars.a`` w.r.t. itself is ``10.0`` and ``vectors.v`` w.r.t. itself
    is 10 on the diagonal of a 3x3 matrix. ``inputs`` is not read.
    """
    scalar_path, vector_path = "scalars.a", "vectors.v"

    jac = {}
    for dy in jac_outputs:
        row = {}
        for dx in jac_inputs:
            row[dx] = [0.0, 0.0, 0.0]
        jac[dy] = row

    if scalar_path in jac_inputs and scalar_path in jac_outputs:
        jac[scalar_path][scalar_path] = 10.0

    if vector_path in jac_inputs and vector_path in jac_outputs:
        jac[vector_path][vector_path] = [
            [10.0 if i == j else 0 for j in range(3)] for i in range(3)
        ]

    return jac


def abstract_eval(abstract_inputs):
"""Calculate output shape of apply from the shape of its inputs."""
return {
Expand Down
140 changes: 139 additions & 1 deletion tests/test_endtoend.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from functools import partial

import jax
import numpy as np
import pytest
Expand Down Expand Up @@ -36,7 +38,7 @@ def _assert_pytree_isequal(a, b, rtol=None, atol=None):
else:
assert a_elem == b_elem, f"Values are different: {a_elem} != {b_elem}"
except AssertionError as e:
failures.append(a_path, str(e))
failures.append((a_path, str(e)))

if failures:
msg = "\n".join(f"Path: {path}, Error: {error}" for path, error in failures)
Expand Down Expand Up @@ -148,6 +150,90 @@ def f(x, y):
_assert_pytree_isequal(vjp, vjp_raw)


@pytest.mark.parametrize("use_jit", [True, False])
@pytest.mark.parametrize(
    "jacfun", [partial(jax.jacfwd, argnums=(0, 1)), partial(jax.jacrev, argnums=(0, 1))]
)
def test_univariate_tesseract_jacobian(
    served_univariate_tesseract_raw, use_jit, jacfun
):
    """Jacobians through apply_tesseract match the client and the raw impl."""
    rosenbrock_tess = Tesseract(served_univariate_tesseract_raw)

    # Positional wrapper so the Jacobian transform can address x and y by index.
    def _apply(x, y):
        return apply_tesseract(rosenbrock_tess, inputs=dict(x=x, y=y))["result"]

    f = jacfun(_apply)
    rosenbrock_raw = jacfun(rosenbrock_impl)
    if use_jit:
        f, rosenbrock_raw = jax.jit(f), jax.jit(rosenbrock_raw)

    x, y = np.array(0.0), np.array(0.0)
    jac = f(x, y)

    # Reference 1: the Tesseract client's own jacobian endpoint.
    client_jac = rosenbrock_tess.jacobian(
        inputs=dict(x=x, y=y), jac_inputs=["x", "y"], jac_outputs=["result"]
    )
    # Convert from nested dict to a tuple ordered like argnums=(0, 1).
    jac_ref = (client_jac["result"]["x"], client_jac["result"]["y"])
    _assert_pytree_isequal(jac, jac_ref)

    # Reference 2: the same Jacobian transform applied to the direct implementation.
    jac_raw = rosenbrock_raw(x, y)
    _assert_pytree_isequal(jac, jac_raw)


@pytest.mark.parametrize("use_jit", [True, False])
def test_univariate_tesseract_vmap(served_univariate_tesseract_raw, use_jit):
    """vmap (single and nested) over apply_tesseract matches the raw impl."""
    rosenbrock_tess = Tesseract(served_univariate_tesseract_raw)

    # Positional wrapper so in_axes can address x and y by index.
    def f(x, y):
        return apply_tesseract(rosenbrock_tess, inputs=dict(x=x, y=y))["result"]

    def _maybe_jit(fn):
        return jax.jit(fn) if use_jit else fn

    for axes in [(0, 0), (0, None), (None, 0)]:
        x_is_mapped, y_is_mapped = (ax is not None for ax in axes)

        # One batch dimension: mapped inputs are vectors, unmapped ones scalars.
        x = np.arange(3) if x_is_mapped else np.array(0.0)
        y = np.arange(3) if y_is_mapped else np.array(0.0)

        f_vmapped = _maybe_jit(jax.vmap(f, in_axes=axes))
        raw_vmapped = _maybe_jit(jax.vmap(rosenbrock_impl, in_axes=axes))

        result = f_vmapped(x, y)
        result_raw = raw_vmapped(x, y)
        _assert_pytree_isequal(result, result_raw)

        # Two batch dimensions: vmap the already-vmapped functions once more.
        for extra_dim in [0, 1, -1]:
            if x_is_mapped:
                x = np.arange(6).reshape(2, 3)
            if y_is_mapped:
                y = np.arange(6).reshape(2, 3)

            additional_axes = tuple(
                None if ax is None else extra_dim for ax in axes
            )

            f_vmappedtwice = _maybe_jit(jax.vmap(f_vmapped, in_axes=additional_axes))
            raw_vmappedtwice = _maybe_jit(
                jax.vmap(raw_vmapped, in_axes=additional_axes)
            )

            result = f_vmappedtwice(x, y)
            result_raw = raw_vmappedtwice(x, y)
            _assert_pytree_isequal(result, result_raw)


@pytest.mark.parametrize("use_jit", [True, False])
def test_nested_tesseract_apply(served_nested_tesseract_raw, use_jit):
nested_tess = Tesseract(served_nested_tesseract_raw)
Expand Down Expand Up @@ -286,6 +372,58 @@ def f(a, v):
_assert_pytree_isequal(vjp, vjp_ref)


@pytest.mark.parametrize("use_jit", [True, False])
@pytest.mark.parametrize(
    "jacfun", [partial(jax.jacfwd, argnums=(0, 1)), partial(jax.jacrev, argnums=(0, 1))]
)
def test_nested_tesseract_jacobian(served_nested_tesseract_raw, use_jit, jacfun):
    """Jacobians of the nested Tesseract match the client's jacobian endpoint."""
    nested_tess = Tesseract(served_nested_tesseract_raw)
    a = np.array(1.0, dtype="float32")
    b = np.array(2.0, dtype="float32")
    v = np.array([1.0, 2.0, 3.0], dtype="float32")
    w = np.array([5.0, 7.0, 9.0], dtype="float32")

    def _make_inputs(a_val, v_val):
        # Only a and v are differentiated; everything else is held fixed.
        return dict(
            scalars={"a": a_val, "b": b},
            vectors={"v": v_val, "w": w},
            other_stuff={"s": "hey!", "i": 1234, "f": 2.718},
        )

    def _apply(a, v):
        return apply_tesseract(nested_tess, inputs=_make_inputs(a, v))

    f = jacfun(_apply)
    if use_jit:
        f = jax.jit(f)

    nested_jac = f(a, v)

    diff_paths = ["scalars.a", "vectors.v"]
    jac_ref = nested_tess.jacobian(
        inputs=_make_inputs(a, v),
        jac_inputs=diff_paths,
        jac_outputs=diff_paths,
    )

    # JAX returns a 2-layered nested dict containing tuples of arrays (one
    # tuple entry per differentiated argument); flatten it to match the
    # Tesseract output (2-layered nested dict of arrays keyed by dotted paths).
    jac = {}
    for out_path in diff_paths:
        group, leaf = out_path.split(".")
        per_argument = nested_jac[group][leaf]
        jac[out_path] = {
            in_path: per_argument[idx] for idx, in_path in enumerate(diff_paths)
        }

    _assert_pytree_isequal(jac, jac_ref)


@pytest.mark.parametrize("use_jit", [True, False])
def test_partial_differentiation(served_univariate_tesseract_raw, use_jit):
"""Test that differentiation works correctly in cases where some inputs are constants."""
Expand Down