diff --git a/tesseract_jax/primitive.py b/tesseract_jax/primitive.py index 78e90eb..cfe844f 100644 --- a/tesseract_jax/primitive.py +++ b/tesseract_jax/primitive.py @@ -1,16 +1,16 @@ # Copyright 2025 Pasteur Labs. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -import functools from collections.abc import Sequence from typing import Any, TypeVar +import jax.core as jc import jax.numpy as jnp import jax.tree import numpy as np from jax import ShapeDtypeStruct, dtypes, extend from jax.core import ShapedArray -from jax.interpreters import ad, batching, mlir, xla +from jax.interpreters import ad, batching, mlir from jax.tree_util import PyTreeDef from jax.typing import ArrayLike from tesseract_core import Tesseract @@ -21,9 +21,6 @@ tesseract_dispatch_p = extend.core.Primitive("tesseract_dispatch") tesseract_dispatch_p.multiple_results = True -tesseract_dispatch_p.def_impl( - functools.partial(xla.apply_primitive, tesseract_dispatch_p) -) class _Hashable: @@ -184,12 +181,48 @@ def tesseract_dispatch_transpose_rule( # I see it chokes on map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out), # where eqn.invars ends up being longer than cts_out. - return tuple([None] * len(args) + vjp) + return tuple([None] * len(args) + list(vjp)) ad.primitive_transposes[tesseract_dispatch_p] = tesseract_dispatch_transpose_rule +def tesseract_dispatch( + *array_args: ArrayLike | ShapedArray | Any, + static_args: tuple[_Hashable, ...], + input_pytreedef: PyTreeDef, + output_pytreedef: PyTreeDef | None, + output_avals: tuple[ShapeDtypeStruct, ...] | None, + is_static_mask: tuple[bool, ...], + client: Jaxeract, + eval_func: str, +) -> Any: + """Defines how to dispatch lowering the computation. + + The dispatch that is not lowered is only called in cases where abstract eval is not needed. 
+ """ + + def _dispatch(*args: ArrayLike) -> Any: + static_args_ = tuple(_unpack_hashable(arg) for arg in static_args) + out = getattr(client, eval_func)( + args, + static_args_, + input_pytreedef, + output_pytreedef, + output_avals, + is_static_mask, + ) + if not isinstance(out, tuple) and output_avals is not None: + out = (out,) + return out + + result = _dispatch(*array_args) + return result + + +tesseract_dispatch_p.def_impl(tesseract_dispatch) + + def tesseract_dispatch_lowering( ctx: Any, *array_args: ArrayLike | ShapedArray | Any, @@ -344,10 +377,25 @@ def apply_tesseract( f"Got {type(tesseract_client)} instead." ) - if "abstract_eval" not in tesseract_client.available_endpoints: + has_func_transformation = False + + # determine if any array in the input pytree is a tracer + inputs_flat, _ = jax.tree.flatten(inputs) + for inp in inputs_flat: + if isinstance(inp, jc.Tracer): + has_func_transformation = True + break + + if ( + has_func_transformation + and "abstract_eval" not in tesseract_client.available_endpoints + ): raise ValueError( "Given Tesseract object does not support abstract_eval, " - "which is required for compatibility with JAX." + "it is however called in combination with a JAX transformation " + "like jit, grad, vmap, or pmap. " + "Either remove the transformation or add an abstract_eval endpoint " + "to the Tesseract object." 
) client = Jaxeract(tesseract_client) @@ -357,40 +405,59 @@ def apply_tesseract( array_args, static_args = split_args(flat_args, is_static_mask) static_args = tuple(_make_hashable(arg) for arg in static_args) - # Get abstract values for outputs, so we can unflatten them later - output_pytreedef, avals = None, None - avals = client.abstract_eval( - array_args, - static_args, - input_pytreedef, - output_pytreedef, - avals, - is_static_mask, - ) + if "abstract_eval" in tesseract_client.available_endpoints: + # Get abstract values for outputs, so we can unflatten them later + output_pytreedef, avals = None, None + avals = client.abstract_eval( + array_args, + static_args, + input_pytreedef, + output_pytreedef, + avals, + is_static_mask, + ) - is_aval = lambda x: isinstance(x, dict) and "dtype" in x and "shape" in x - flat_avals, output_pytreedef = jax.tree.flatten(avals, is_leaf=is_aval) - for aval in flat_avals: - if not is_aval(aval): - continue - _check_dtype(aval["dtype"]) + is_aval = lambda x: isinstance(x, dict) and "dtype" in x and "shape" in x + flat_avals, output_pytreedef = jax.tree.flatten(avals, is_leaf=is_aval) + for aval in flat_avals: + if not is_aval(aval): + continue + _check_dtype(aval["dtype"]) - flat_avals = tuple( - jax.ShapeDtypeStruct(shape=tuple(aval["shape"]), dtype=aval["dtype"]) - for aval in flat_avals - ) + flat_avals = tuple( + jax.ShapeDtypeStruct(shape=tuple(aval["shape"]), dtype=aval["dtype"]) + for aval in flat_avals + ) - # Apply the primitive - out = tesseract_dispatch_p.bind( - *array_args, - static_args=static_args, - input_pytreedef=input_pytreedef, - output_pytreedef=output_pytreedef, - output_avals=flat_avals, - is_static_mask=is_static_mask, - client=client, - eval_func="apply", - ) + # Apply the primitive + out = tesseract_dispatch_p.bind( + *array_args, + static_args=static_args, + input_pytreedef=input_pytreedef, + output_pytreedef=output_pytreedef, + output_avals=flat_avals, + is_static_mask=is_static_mask, + client=client, 
+ eval_func="apply", + ) + + # Unflatten the output + return jax.tree.unflatten(output_pytreedef, out) - # Unflatten the output - return jax.tree.unflatten(output_pytreedef, out) + else: + # If there is no abstract_eval endpoint, we cannot determine the output structure + # In this case we send None for output_pytreedef and output_avals + # and the primitive will return an unflattened output + out = tesseract_dispatch_p.bind( + *array_args, + static_args=static_args, + input_pytreedef=input_pytreedef, + output_pytreedef=None, + output_avals=None, + is_static_mask=is_static_mask, + client=client, + eval_func="apply", + ) + + # The output is already an unflattened pytree, so return it as-is + return out diff --git a/tesseract_jax/tesseract_compat.py b/tesseract_jax/tesseract_compat.py index b9c9279..b0415f4 100644 --- a/tesseract_jax/tesseract_compat.py +++ b/tesseract_jax/tesseract_compat.py @@ -148,8 +148,8 @@ def apply( array_args: tuple[ArrayLike, ...], static_args: tuple[Any, ...], input_pytreedef: PyTreeDef, - output_pytreedef: PyTreeDef, - output_avals: tuple[ShapeDtypeStruct, ...], + output_pytreedef: PyTreeDef | None, + output_avals: tuple[ShapeDtypeStruct, ...] 
| None, is_static_mask: tuple[bool, ...], ) -> PyTree: """Call the Tesseract's apply endpoint with the given arguments.""" @@ -159,9 +159,19 @@ def apply( out_data = self.client.apply(inputs) + if output_avals is None: + return out_data + out_data = tuple(jax.tree.flatten(out_data)[0]) return out_data + def apply_pytree( + self, + inputs: PyTree, + ) -> PyTree: + """Call the Tesseract's apply endpoint with the given arguments.""" + return self.client.apply(inputs) + def jacobian_vector_product( self, array_args: tuple[ArrayLike, ...], diff --git a/tests/conftest.py b/tests/conftest.py index d35c6ef..47998ae 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -85,3 +85,4 @@ def served_tesseract(): served_univariate_tesseract_raw = make_tesseract_fixture("univariate_tesseract") served_nested_tesseract_raw = make_tesseract_fixture("nested_tesseract") +served_non_abstract_tesseract = make_tesseract_fixture("non_abstract_tesseract") diff --git a/tests/non_abstract_tesseract/tesseract_api.py b/tests/non_abstract_tesseract/tesseract_api.py new file mode 100644 index 0000000..4bc0344 --- /dev/null +++ b/tests/non_abstract_tesseract/tesseract_api.py @@ -0,0 +1,41 @@ +# Copyright 2025 Pasteur Labs. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + + +from typing import Any + +import jax.numpy as jnp +from pydantic import BaseModel, Field +from tesseract_core.runtime import Array, Differentiable, Float32 + + +class InputSchema(BaseModel): + a: Differentiable[Array[(None,), Float32]] = Field( + description="An arbitrary vector" + ) + + +class OutputSchema(BaseModel): + b: Differentiable[Array[(None,), Float32]] = Field( + description="Vector s_a·a + s_b·b" + ) + c: Array[(None,), Float32] = Field(description="Constant vector [1.0, 1.0, 1.0]") + + +def apply(inputs: InputSchema) -> OutputSchema: + """Returns `b = 2 * a` together with a constant vector `c` of ones.""" + return OutputSchema( + b=2.0 * inputs.a, + c=jnp.array([1.0, 1.0, 1.0], dtype="float32"), + ) + + +def vector_jacobian_product( + inputs: InputSchema, + vjp_inputs: set[str], + vjp_outputs: set[str], + cotangent_vector: dict[str, Any], +): + return { + "a": 2.0 * cotangent_vector["b"], + } diff --git a/tests/non_abstract_tesseract/tesseract_config.yaml b/tests/non_abstract_tesseract/tesseract_config.yaml new file mode 100644 index 0000000..30e4ce2 --- /dev/null +++ b/tests/non_abstract_tesseract/tesseract_config.yaml @@ -0,0 +1,9 @@ +name: non_abstract_tesseract +version: "2025-02-05" +description: | + Tesseract that adds/subtracts two vectors. Uses jax internally. 
+ +build_config: + target_platform: "native" + # package_data: [] + # custom_build_steps: [] diff --git a/tests/non_abstract_tesseract/tesseract_requirements.txt b/tests/non_abstract_tesseract/tesseract_requirements.txt new file mode 100644 index 0000000..3878018 --- /dev/null +++ b/tests/non_abstract_tesseract/tesseract_requirements.txt @@ -0,0 +1,2 @@ +jax[cpu] +equinox diff --git a/tests/test_endtoend.py b/tests/test_endtoend.py index 92cb5f3..f6eebbd 100644 --- a/tests/test_endtoend.py +++ b/tests/test_endtoend.py @@ -538,3 +538,42 @@ def f(x, y, tess): result = f(x, y, tess) result_ref = rosenbrock_impl(x, y) _assert_pytree_isequal(result, result_ref) + + +@pytest.mark.parametrize("use_jit", [True, False]) +def test_non_abstract_tesseract_apply(served_non_abstract_tesseract, use_jit): + non_abstract_tess = Tesseract(served_non_abstract_tesseract) + a = np.array([0.0, 1.0, 2.0], dtype="float32") + + def f(a): + return apply_tesseract(non_abstract_tess, inputs=dict(a=a)) + + if use_jit: + f = jax.jit(f) + + # make sure a ValueError is raised: under jit the inputs are tracers, which requires abstract_eval + with pytest.raises(ValueError): + f(a) + + else: + # Test against Tesseract client + result = f(a) + result_ref = non_abstract_tess.apply(dict(a=a)) + _assert_pytree_isequal(result, result_ref) + + +def test_non_abstract_tesseract_vjp(served_non_abstract_tesseract): + non_abstract_tess = Tesseract(served_non_abstract_tesseract) + + a = np.array([1.0, 2.0, 3.0], dtype="float32") + + def f(a): + return apply_tesseract( + non_abstract_tess, + inputs=dict( + a=a, + ), + ) + + with pytest.raises(ValueError): + jax.vjp(f, a)