Skip to content

Commit 2864729

Browse files
authored
feat: Selective abstract eval (#85)
#### Relevant issue or PR Tesseract need to have an abstract eval defined, even if jax would not require them to. #### Description of changes This PR avoids calls to abstract_eval if it is not defined in the Tesseract API. If a function transformation is applied without the abstract eval, an error is raised. Furthermore we add a concrete implementation "tesseract_dispatch()" that does not rely on any function transformations. #### Testing done added new tests
1 parent 574374a commit 2864729

File tree

7 files changed

+212
-43
lines changed

7 files changed

+212
-43
lines changed

tesseract_jax/primitive.py

Lines changed: 108 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# Copyright 2025 Pasteur Labs. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4-
import functools
54
from collections.abc import Sequence
65
from typing import Any, TypeVar
76

7+
import jax.core as jc
88
import jax.numpy as jnp
99
import jax.tree
1010
import numpy as np
1111
from jax import ShapeDtypeStruct, dtypes, extend
1212
from jax.core import ShapedArray
13-
from jax.interpreters import ad, batching, mlir, xla
13+
from jax.interpreters import ad, batching, mlir
1414
from jax.tree_util import PyTreeDef
1515
from jax.typing import ArrayLike
1616
from tesseract_core import Tesseract
@@ -21,9 +21,6 @@
2121

2222
tesseract_dispatch_p = extend.core.Primitive("tesseract_dispatch")
2323
tesseract_dispatch_p.multiple_results = True
24-
tesseract_dispatch_p.def_impl(
25-
functools.partial(xla.apply_primitive, tesseract_dispatch_p)
26-
)
2724

2825

2926
class _Hashable:
@@ -184,12 +181,48 @@ def tesseract_dispatch_transpose_rule(
184181
# I see it chokes on map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out),
185182
# where eqn.invars ends up being longer than cts_out.
186183

187-
return tuple([None] * len(args) + vjp)
184+
return tuple([None] * len(args) + list(vjp))
188185

189186

190187
ad.primitive_transposes[tesseract_dispatch_p] = tesseract_dispatch_transpose_rule
191188

192189

190+
def tesseract_dispatch(
191+
*array_args: ArrayLike | ShapedArray | Any,
192+
static_args: tuple[_Hashable, ...],
193+
input_pytreedef: PyTreeDef,
194+
output_pytreedef: PyTreeDef | None,
195+
output_avals: tuple[ShapeDtypeStruct, ...] | None,
196+
is_static_mask: tuple[bool, ...],
197+
client: Jaxeract,
198+
eval_func: str,
199+
) -> Any:
200+
"""Defines how to dispatch lowering the computation.
201+
202+
The dispatch that is not lowered is only called in cases where abstract eval is not needed.
203+
"""
204+
205+
def _dispatch(*args: ArrayLike) -> Any:
206+
static_args_ = tuple(_unpack_hashable(arg) for arg in static_args)
207+
out = getattr(client, eval_func)(
208+
args,
209+
static_args_,
210+
input_pytreedef,
211+
output_pytreedef,
212+
output_avals,
213+
is_static_mask,
214+
)
215+
if not isinstance(out, tuple) and output_avals is not None:
216+
out = (out,)
217+
return out
218+
219+
result = _dispatch(*array_args)
220+
return result
221+
222+
223+
tesseract_dispatch_p.def_impl(tesseract_dispatch)
224+
225+
193226
def tesseract_dispatch_lowering(
194227
ctx: Any,
195228
*array_args: ArrayLike | ShapedArray | Any,
@@ -344,10 +377,25 @@ def apply_tesseract(
344377
f"Got {type(tesseract_client)} instead."
345378
)
346379

347-
if "abstract_eval" not in tesseract_client.available_endpoints:
380+
has_func_transformation = False
381+
382+
# determine if any array in the input pytree is a tracer
383+
inputs_flat, _ = jax.tree.flatten(inputs)
384+
for inp in inputs_flat:
385+
if isinstance(inp, jc.Tracer):
386+
has_func_transformation = True
387+
break
388+
389+
if (
390+
has_func_transformation
391+
and "abstract_eval" not in tesseract_client.available_endpoints
392+
):
348393
raise ValueError(
349394
"Given Tesseract object does not support abstract_eval, "
350-
"which is required for compatibility with JAX."
395+
"it is however called in combination with a JAX transformation "
396+
"like jit, grad, vmap, or pmap. "
397+
"Either remove the transformation or add an abstract_eval endpoint "
398+
"to the Tesseract object."
351399
)
352400

353401
client = Jaxeract(tesseract_client)
@@ -357,40 +405,59 @@ def apply_tesseract(
357405
array_args, static_args = split_args(flat_args, is_static_mask)
358406
static_args = tuple(_make_hashable(arg) for arg in static_args)
359407

360-
# Get abstract values for outputs, so we can unflatten them later
361-
output_pytreedef, avals = None, None
362-
avals = client.abstract_eval(
363-
array_args,
364-
static_args,
365-
input_pytreedef,
366-
output_pytreedef,
367-
avals,
368-
is_static_mask,
369-
)
408+
if "abstract_eval" in tesseract_client.available_endpoints:
409+
# Get abstract values for outputs, so we can unflatten them later
410+
output_pytreedef, avals = None, None
411+
avals = client.abstract_eval(
412+
array_args,
413+
static_args,
414+
input_pytreedef,
415+
output_pytreedef,
416+
avals,
417+
is_static_mask,
418+
)
370419

371-
is_aval = lambda x: isinstance(x, dict) and "dtype" in x and "shape" in x
372-
flat_avals, output_pytreedef = jax.tree.flatten(avals, is_leaf=is_aval)
373-
for aval in flat_avals:
374-
if not is_aval(aval):
375-
continue
376-
_check_dtype(aval["dtype"])
420+
is_aval = lambda x: isinstance(x, dict) and "dtype" in x and "shape" in x
421+
flat_avals, output_pytreedef = jax.tree.flatten(avals, is_leaf=is_aval)
422+
for aval in flat_avals:
423+
if not is_aval(aval):
424+
continue
425+
_check_dtype(aval["dtype"])
377426

378-
flat_avals = tuple(
379-
jax.ShapeDtypeStruct(shape=tuple(aval["shape"]), dtype=aval["dtype"])
380-
for aval in flat_avals
381-
)
427+
flat_avals = tuple(
428+
jax.ShapeDtypeStruct(shape=tuple(aval["shape"]), dtype=aval["dtype"])
429+
for aval in flat_avals
430+
)
382431

383-
# Apply the primitive
384-
out = tesseract_dispatch_p.bind(
385-
*array_args,
386-
static_args=static_args,
387-
input_pytreedef=input_pytreedef,
388-
output_pytreedef=output_pytreedef,
389-
output_avals=flat_avals,
390-
is_static_mask=is_static_mask,
391-
client=client,
392-
eval_func="apply",
393-
)
432+
# Apply the primitive
433+
out = tesseract_dispatch_p.bind(
434+
*array_args,
435+
static_args=static_args,
436+
input_pytreedef=input_pytreedef,
437+
output_pytreedef=output_pytreedef,
438+
output_avals=flat_avals,
439+
is_static_mask=is_static_mask,
440+
client=client,
441+
eval_func="apply",
442+
)
443+
444+
# Unflatten the output
445+
return jax.tree.unflatten(output_pytreedef, out)
394446

395-
# Unflatten the output
396-
return jax.tree.unflatten(output_pytreedef, out)
447+
else:
448+
# If there is no abstract_eval endpoint, we cannot determine the output structure
449+
# In this case we send None for output_pytreedef and output_avals
450+
# and the primitive will return an unflattened output
451+
out = tesseract_dispatch_p.bind(
452+
*array_args,
453+
static_args=static_args,
454+
input_pytreedef=input_pytreedef,
455+
output_pytreedef=None,
456+
output_avals=None,
457+
is_static_mask=is_static_mask,
458+
client=client,
459+
eval_func="apply",
460+
)
461+
462+
# Unflatten the output
463+
return out

tesseract_jax/tesseract_compat.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def apply(
148148
array_args: tuple[ArrayLike, ...],
149149
static_args: tuple[Any, ...],
150150
input_pytreedef: PyTreeDef,
151-
output_pytreedef: PyTreeDef,
152-
output_avals: tuple[ShapeDtypeStruct, ...],
151+
output_pytreedef: PyTreeDef | None,
152+
output_avals: tuple[ShapeDtypeStruct, ...] | None,
153153
is_static_mask: tuple[bool, ...],
154154
) -> PyTree:
155155
"""Call the Tesseract's apply endpoint with the given arguments."""
@@ -159,9 +159,19 @@ def apply(
159159

160160
out_data = self.client.apply(inputs)
161161

162+
if output_avals is None:
163+
return out_data
164+
162165
out_data = tuple(jax.tree.flatten(out_data)[0])
163166
return out_data
164167

168+
def apply_pytree(
169+
self,
170+
inputs: PyTree,
171+
) -> PyTree:
172+
"""Call the Tesseract's apply endpoint with the given arguments."""
173+
return self.client.apply(inputs)
174+
165175
def jacobian_vector_product(
166176
self,
167177
array_args: tuple[ArrayLike, ...],

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,3 +85,4 @@ def served_tesseract():
8585

8686
served_univariate_tesseract_raw = make_tesseract_fixture("univariate_tesseract")
8787
served_nested_tesseract_raw = make_tesseract_fixture("nested_tesseract")
88+
served_non_abstract_tesseract = make_tesseract_fixture("non_abstract_tesseract")
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright 2025 Pasteur Labs. All Rights Reserved.
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
5+
from typing import Any
6+
7+
import jax.numpy as jnp
8+
from pydantic import BaseModel, Field
9+
from tesseract_core.runtime import Array, Differentiable, Float32
10+
11+
12+
class InputSchema(BaseModel):
13+
a: Differentiable[Array[(None,), Float32]] = Field(
14+
description="An arbitrary vector"
15+
)
16+
17+
18+
class OutputSchema(BaseModel):
19+
b: Differentiable[Array[(None,), Float32]] = Field(
20+
description="Vector s_a·a + s_b·b"
21+
)
22+
c: Array[(None,), Float32] = Field(description="Constant vector [1.0, 1.0, 1.0]")
23+
24+
25+
def apply(inputs: InputSchema) -> OutputSchema:
26+
"""Multiplies a vector `a` by `s`, and sums the result to `b`."""
27+
return OutputSchema(
28+
b=2.0 * inputs.a,
29+
c=jnp.array([1.0, 1.0, 1.0], dtype="float32"),
30+
)
31+
32+
33+
def vector_jacobian_product(
34+
inputs: InputSchema,
35+
vjp_inputs: set[str],
36+
vjp_outputs: set[str],
37+
cotangent_vector: dict[str, Any],
38+
):
39+
return {
40+
"a": 2.0 * cotangent_vector["b"],
41+
}
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
name: non_abstract_tesseract
2+
version: "2025-02-05"
3+
description: |
4+
Tesseract that adds/subtracts two vectors. Uses jax internally.
5+
6+
build_config:
7+
target_platform: "native"
8+
# package_data: []
9+
# custom_build_steps: []
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
jax[cpu]
2+
equinox

tests/test_endtoend.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -538,3 +538,42 @@ def f(x, y, tess):
538538
result = f(x, y, tess)
539539
result_ref = rosenbrock_impl(x, y)
540540
_assert_pytree_isequal(result, result_ref)
541+
542+
543+
@pytest.mark.parametrize("use_jit", [True, False])
544+
def test_non_abstract_tesseract_apply(served_non_abstract_tesseract, use_jit):
545+
non_abstract_tess = Tesseract(served_non_abstract_tesseract)
546+
a = np.array([0.0, 1.0, 2.0], dtype="float32")
547+
548+
def f(a):
549+
return apply_tesseract(non_abstract_tess, inputs=dict(a=a))
550+
551+
if use_jit:
552+
f = jax.jit(f)
553+
554+
# make sure value error is raised if input shape is incorrect
555+
with pytest.raises(ValueError):
556+
f(a)
557+
558+
else:
559+
# Test against Tesseract client
560+
result = f(a)
561+
result_ref = non_abstract_tess.apply(dict(a=a))
562+
_assert_pytree_isequal(result, result_ref)
563+
564+
565+
def test_non_abstract_tesseract_vjp(served_non_abstract_tesseract):
566+
non_abstract_tess = Tesseract(served_non_abstract_tesseract)
567+
568+
a = np.array([1.0, 2.0, 3.0], dtype="float32")
569+
570+
def f(a):
571+
return apply_tesseract(
572+
non_abstract_tess,
573+
inputs=dict(
574+
a=a,
575+
),
576+
)
577+
578+
with pytest.raises(ValueError):
579+
jax.vjp(f, a)

0 commit comments

Comments
 (0)