Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 108 additions & 41 deletions tesseract_jax/primitive.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import functools
from collections.abc import Sequence
from typing import Any, TypeVar

import jax.core as jc
import jax.numpy as jnp
import jax.tree
import numpy as np
from jax import ShapeDtypeStruct, dtypes, extend
from jax.core import ShapedArray
from jax.interpreters import ad, batching, mlir, xla
from jax.interpreters import ad, batching, mlir
from jax.tree_util import PyTreeDef
from jax.typing import ArrayLike
from tesseract_core import Tesseract
Expand All @@ -21,9 +21,6 @@

tesseract_dispatch_p = extend.core.Primitive("tesseract_dispatch")
tesseract_dispatch_p.multiple_results = True
tesseract_dispatch_p.def_impl(
functools.partial(xla.apply_primitive, tesseract_dispatch_p)
)


class _Hashable:
Expand Down Expand Up @@ -184,12 +181,48 @@ def tesseract_dispatch_transpose_rule(
# I see it chokes on map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out),
# where eqn.invars ends up being longer than cts_out.

return tuple([None] * len(args) + vjp)
return tuple([None] * len(args) + list(vjp))


ad.primitive_transposes[tesseract_dispatch_p] = tesseract_dispatch_transpose_rule


def tesseract_dispatch(
*array_args: ArrayLike | ShapedArray | Any,
static_args: tuple[_Hashable, ...],
input_pytreedef: PyTreeDef,
output_pytreedef: PyTreeDef | None,
output_avals: tuple[ShapeDtypeStruct, ...] | None,
is_static_mask: tuple[bool, ...],
client: Jaxeract,
eval_func: str,
) -> Any:
"""Defines how to dispatch lowering the computation.

The dispatch that is not lowered is only called in cases where abstract eval is not needed.
"""

def _dispatch(*args: ArrayLike) -> Any:
static_args_ = tuple(_unpack_hashable(arg) for arg in static_args)
out = getattr(client, eval_func)(
args,
static_args_,
input_pytreedef,
output_pytreedef,
output_avals,
is_static_mask,
)
if not isinstance(out, tuple) and output_avals is not None:
out = (out,)
return out

result = _dispatch(*array_args)
return result


tesseract_dispatch_p.def_impl(tesseract_dispatch)


def tesseract_dispatch_lowering(
ctx: Any,
*array_args: ArrayLike | ShapedArray | Any,
Expand Down Expand Up @@ -233,6 +266,9 @@ def _dispatch(*args: ArrayLike) -> Any:
mlir.register_lowering(tesseract_dispatch_p, tesseract_dispatch_lowering)


mlir.register_lowering(tesseract_dispatch_p, tesseract_dispatch_lowering)


def tesseract_dispatch_batching(
array_args: ArrayLike | ShapedArray | Any,
axes: Sequence[Any],
Expand Down Expand Up @@ -344,53 +380,84 @@ def apply_tesseract(
f"Got {type(tesseract_client)} instead."
)

if "abstract_eval" not in tesseract_client.available_endpoints:
transformation = False

# determine if any array in the input pytree is a tracer
inputs_flat, _ = jax.tree.flatten(inputs)
for inp in inputs_flat:
if isinstance(inp, jc.Tracer):
transformation = True
break

if transformation and "abstract_eval" not in tesseract_client.available_endpoints:
raise ValueError(
"Given Tesseract object does not support abstract_eval, "
"which is required for compatibility with JAX."
"it is however called in combination with a JAX transformation "
"like jit, grad, vmap, or pmap. "
"Either remove the transformation or add an abstract_eval endpoint "
"to the Tesseract object."
)

client = Jaxeract(tesseract_client)

# Get abstract values for outputs, so we can unflatten them later

flat_args, input_pytreedef = jax.tree.flatten(inputs)
is_static_mask = tuple(_is_static(arg) for arg in flat_args)
array_args, static_args = split_args(flat_args, is_static_mask)
static_args = tuple(_make_hashable(arg) for arg in static_args)

# Get abstract values for outputs, so we can unflatten them later
output_pytreedef, avals = None, None
avals = client.abstract_eval(
array_args,
static_args,
input_pytreedef,
output_pytreedef,
avals,
is_static_mask,
)
if "abstract_eval" in tesseract_client.available_endpoints:
# Get abstract values for outputs, so we can unflatten them later
output_pytreedef, avals = None, None
avals = client.abstract_eval(
array_args,
static_args,
input_pytreedef,
output_pytreedef,
avals,
is_static_mask,
)

is_aval = lambda x: isinstance(x, dict) and "dtype" in x and "shape" in x
flat_avals, output_pytreedef = jax.tree.flatten(avals, is_leaf=is_aval)
for aval in flat_avals:
if not is_aval(aval):
continue
_check_dtype(aval["dtype"])
is_aval = lambda x: isinstance(x, dict) and "dtype" in x and "shape" in x
flat_avals, output_pytreedef = jax.tree.flatten(avals, is_leaf=is_aval)
for aval in flat_avals:
if not is_aval(aval):
continue
_check_dtype(aval["dtype"])

flat_avals = tuple(
jax.ShapeDtypeStruct(shape=tuple(aval["shape"]), dtype=aval["dtype"])
for aval in flat_avals
)
flat_avals = tuple(
jax.ShapeDtypeStruct(shape=tuple(aval["shape"]), dtype=aval["dtype"])
for aval in flat_avals
)

# Apply the primitive
out = tesseract_dispatch_p.bind(
*array_args,
static_args=static_args,
input_pytreedef=input_pytreedef,
output_pytreedef=output_pytreedef,
output_avals=flat_avals,
is_static_mask=is_static_mask,
client=client,
eval_func="apply",
)
# Apply the primitive
out = tesseract_dispatch_p.bind(
*array_args,
static_args=static_args,
input_pytreedef=input_pytreedef,
output_pytreedef=output_pytreedef,
output_avals=flat_avals,
is_static_mask=is_static_mask,
client=client,
eval_func="apply",
)

# Unflatten the output
return jax.tree.unflatten(output_pytreedef, out)
# Unflatten the output
return jax.tree.unflatten(output_pytreedef, out)

else:
# Apply the primitive
out = tesseract_dispatch_p.bind(
*array_args,
static_args=static_args,
input_pytreedef=input_pytreedef,
output_pytreedef=None,
output_avals=None,
is_static_mask=is_static_mask,
client=client,
eval_func="apply",
)

# Unflatten the output
return out
14 changes: 12 additions & 2 deletions tesseract_jax/tesseract_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,8 @@ def apply(
array_args: tuple[ArrayLike, ...],
static_args: tuple[Any, ...],
input_pytreedef: PyTreeDef,
output_pytreedef: PyTreeDef,
output_avals: tuple[ShapeDtypeStruct, ...],
output_pytreedef: PyTreeDef | None,
output_avals: tuple[ShapeDtypeStruct, ...] | None,
is_static_mask: tuple[bool, ...],
) -> PyTree:
"""Call the Tesseract's apply endpoint with the given arguments."""
Expand All @@ -159,9 +159,19 @@ def apply(

out_data = self.client.apply(inputs)

if output_avals is None:
return out_data

out_data = tuple(jax.tree.flatten(out_data)[0])
return out_data

def apply_pytree(
self,
inputs: PyTree,
) -> PyTree:
"""Call the Tesseract's apply endpoint with the given arguments."""
return self.client.apply(inputs)

def jacobian_vector_product(
self,
array_args: tuple[ArrayLike, ...],
Expand Down
1 change: 1 addition & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,4 @@ def served_tesseract():

served_univariate_tesseract_raw = make_tesseract_fixture("univariate_tesseract")
served_nested_tesseract_raw = make_tesseract_fixture("nested_tesseract")
served_non_abstract_tesseract = make_tesseract_fixture("non_abstract_tesseract")
41 changes: 41 additions & 0 deletions tests/non_abstract_tesseract/tesseract_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0


from typing import Any

import jax.numpy as jnp
from pydantic import BaseModel, Field
from tesseract_core.runtime import Array, Differentiable, Float32


class InputSchema(BaseModel):
a: Differentiable[Array[(None,), Float32]] = Field(
description="An arbitrary vector"
)


class OutputSchema(BaseModel):
b: Differentiable[Array[(None,), Float32]] = Field(
description="Vector s_a·a + s_b·b"
)
c: Array[(None,), Float32] = Field(description="Constant vector [1.0, 1.0, 1.0]")


def apply(inputs: InputSchema) -> OutputSchema:
"""Multiplies a vector `a` by `s`, and sums the result to `b`."""
return OutputSchema(
b=2.0 * inputs.a,
c=jnp.array([1.0, 1.0, 1.0], dtype="float32"),
)


def vector_jacobian_product(
inputs: InputSchema,
vjp_inputs: set[str],
vjp_outputs: set[str],
cotangent_vector: dict[str, Any],
):
return {
"a": 2.0 * cotangent_vector["b"],
}
9 changes: 9 additions & 0 deletions tests/non_abstract_tesseract/tesseract_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
name: non_abstract_tesseract
version: "2025-02-05"
description: |
Tesseract that adds/subtracts two vectors. Uses jax internally.

build_config:
target_platform: "native"
# package_data: []
# custom_build_steps: []
2 changes: 2 additions & 0 deletions tests/non_abstract_tesseract/tesseract_requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
jax[cpu]
equinox
39 changes: 39 additions & 0 deletions tests/test_endtoend.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,3 +538,42 @@ def f(x, y, tess):
result = f(x, y, tess)
result_ref = rosenbrock_impl(x, y)
_assert_pytree_isequal(result, result_ref)


@pytest.mark.parametrize("use_jit", [True, False])
def test_non_abstract_tesseract_apply(served_non_abstract_tesseract, use_jit):
non_abstract_tess = Tesseract(served_non_abstract_tesseract)
a = np.array([0.0, 1.0, 2.0], dtype="float32")

def f(a):
return apply_tesseract(non_abstract_tess, inputs=dict(a=a))

if use_jit:
f = jax.jit(f)

# make sure value error is raised if input shape is incorrect
with pytest.raises(ValueError):
f(a)

else:
# Test against Tesseract client
result = f(a)
result_ref = non_abstract_tess.apply(dict(a=a))
_assert_pytree_isequal(result, result_ref)


def test_non_abstract_tesseract_vjp(served_non_abstract_tesseract):
non_abstract_tess = Tesseract(served_non_abstract_tesseract)

a = np.array([1.0, 2.0, 3.0], dtype="float32")

def f(a):
return apply_tesseract(
non_abstract_tess,
inputs=dict(
a=a,
),
)

with pytest.raises(ValueError):
jax.vjp(f, a)
Loading