11# Copyright 2025 Pasteur Labs. All Rights Reserved.
22# SPDX-License-Identifier: Apache-2.0
33
4- import functools
54from collections .abc import Sequence
65from typing import Any , TypeVar
76
7+ import jax .core as jc
88import jax .numpy as jnp
99import jax .tree
1010import numpy as np
1111from jax import ShapeDtypeStruct , dtypes , extend
1212from jax .core import ShapedArray
13- from jax .interpreters import ad , batching , mlir , xla
13+ from jax .interpreters import ad , batching , mlir
1414from jax .tree_util import PyTreeDef
1515from jax .typing import ArrayLike
1616from tesseract_core import Tesseract
2121
2222tesseract_dispatch_p = extend .core .Primitive ("tesseract_dispatch" )
2323tesseract_dispatch_p .multiple_results = True
24- tesseract_dispatch_p .def_impl (
25- functools .partial (xla .apply_primitive , tesseract_dispatch_p )
26- )
2724
2825
2926class _Hashable :
@@ -184,12 +181,48 @@ def tesseract_dispatch_transpose_rule(
184181 # I see it chokes on map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out),
185182 # where eqn.invars ends up being longer than cts_out.
186183
187- return tuple ([None ] * len (args ) + vjp )
184+ return tuple ([None ] * len (args ) + list ( vjp ) )
188185
189186
190187ad .primitive_transposes [tesseract_dispatch_p ] = tesseract_dispatch_transpose_rule
191188
192189
190+ def tesseract_dispatch (
191+ * array_args : ArrayLike | ShapedArray | Any ,
192+ static_args : tuple [_Hashable , ...],
193+ input_pytreedef : PyTreeDef ,
194+ output_pytreedef : PyTreeDef | None ,
195+ output_avals : tuple [ShapeDtypeStruct , ...] | None ,
196+ is_static_mask : tuple [bool , ...],
197+ client : Jaxeract ,
198+ eval_func : str ,
199+ ) -> Any :
200+ """Defines how to dispatch lowering the computation.
201+
202+ The dispatch that is not lowered is only called in cases where abstract eval is not needed.
203+ """
204+
205+ def _dispatch (* args : ArrayLike ) -> Any :
206+ static_args_ = tuple (_unpack_hashable (arg ) for arg in static_args )
207+ out = getattr (client , eval_func )(
208+ args ,
209+ static_args_ ,
210+ input_pytreedef ,
211+ output_pytreedef ,
212+ output_avals ,
213+ is_static_mask ,
214+ )
215+ if not isinstance (out , tuple ) and output_avals is not None :
216+ out = (out ,)
217+ return out
218+
219+ result = _dispatch (* array_args )
220+ return result
221+
222+
223+ tesseract_dispatch_p .def_impl (tesseract_dispatch )
224+
225+
193226def tesseract_dispatch_lowering (
194227 ctx : Any ,
195228 * array_args : ArrayLike | ShapedArray | Any ,
@@ -344,10 +377,25 @@ def apply_tesseract(
344377 f"Got { type (tesseract_client )} instead."
345378 )
346379
347- if "abstract_eval" not in tesseract_client .available_endpoints :
380+ has_func_transformation = False
381+
382+ # determine if any array in the input pytree is a tracer
383+ inputs_flat , _ = jax .tree .flatten (inputs )
384+ for inp in inputs_flat :
385+ if isinstance (inp , jc .Tracer ):
386+ has_func_transformation = True
387+ break
388+
389+ if (
390+ has_func_transformation
391+ and "abstract_eval" not in tesseract_client .available_endpoints
392+ ):
348393 raise ValueError (
349394 "Given Tesseract object does not support abstract_eval, "
350- "which is required for compatibility with JAX."
395+ "it is however called in combination with a JAX transformation "
396+ "like jit, grad, vmap, or pmap. "
397+ "Either remove the transformation or add an abstract_eval endpoint "
398+ "to the Tesseract object."
351399 )
352400
353401 client = Jaxeract (tesseract_client )
@@ -357,40 +405,59 @@ def apply_tesseract(
357405 array_args , static_args = split_args (flat_args , is_static_mask )
358406 static_args = tuple (_make_hashable (arg ) for arg in static_args )
359407
360- # Get abstract values for outputs, so we can unflatten them later
361- output_pytreedef , avals = None , None
362- avals = client .abstract_eval (
363- array_args ,
364- static_args ,
365- input_pytreedef ,
366- output_pytreedef ,
367- avals ,
368- is_static_mask ,
369- )
408+ if "abstract_eval" in tesseract_client .available_endpoints :
409+ # Get abstract values for outputs, so we can unflatten them later
410+ output_pytreedef , avals = None , None
411+ avals = client .abstract_eval (
412+ array_args ,
413+ static_args ,
414+ input_pytreedef ,
415+ output_pytreedef ,
416+ avals ,
417+ is_static_mask ,
418+ )
370419
371- is_aval = lambda x : isinstance (x , dict ) and "dtype" in x and "shape" in x
372- flat_avals , output_pytreedef = jax .tree .flatten (avals , is_leaf = is_aval )
373- for aval in flat_avals :
374- if not is_aval (aval ):
375- continue
376- _check_dtype (aval ["dtype" ])
420+ is_aval = lambda x : isinstance (x , dict ) and "dtype" in x and "shape" in x
421+ flat_avals , output_pytreedef = jax .tree .flatten (avals , is_leaf = is_aval )
422+ for aval in flat_avals :
423+ if not is_aval (aval ):
424+ continue
425+ _check_dtype (aval ["dtype" ])
377426
378- flat_avals = tuple (
379- jax .ShapeDtypeStruct (shape = tuple (aval ["shape" ]), dtype = aval ["dtype" ])
380- for aval in flat_avals
381- )
427+ flat_avals = tuple (
428+ jax .ShapeDtypeStruct (shape = tuple (aval ["shape" ]), dtype = aval ["dtype" ])
429+ for aval in flat_avals
430+ )
382431
383- # Apply the primitive
384- out = tesseract_dispatch_p .bind (
385- * array_args ,
386- static_args = static_args ,
387- input_pytreedef = input_pytreedef ,
388- output_pytreedef = output_pytreedef ,
389- output_avals = flat_avals ,
390- is_static_mask = is_static_mask ,
391- client = client ,
392- eval_func = "apply" ,
393- )
432+ # Apply the primitive
433+ out = tesseract_dispatch_p .bind (
434+ * array_args ,
435+ static_args = static_args ,
436+ input_pytreedef = input_pytreedef ,
437+ output_pytreedef = output_pytreedef ,
438+ output_avals = flat_avals ,
439+ is_static_mask = is_static_mask ,
440+ client = client ,
441+ eval_func = "apply" ,
442+ )
443+
444+ # Unflatten the output
445+ return jax .tree .unflatten (output_pytreedef , out )
394446
395- # Unflatten the output
396- return jax .tree .unflatten (output_pytreedef , out )
447+ else :
448+ # If there is no abstract_eval endpoint, we cannot determine the output structure
449+ # In this case we send None for output_pytreedef and output_avals
450+ # and the primitive will return an unflattened output
451+ out = tesseract_dispatch_p .bind (
452+ * array_args ,
453+ static_args = static_args ,
454+ input_pytreedef = input_pytreedef ,
455+ output_pytreedef = None ,
456+ output_avals = None ,
457+ is_static_mask = is_static_mask ,
458+ client = client ,
459+ eval_func = "apply" ,
460+ )
461+
462+ # Unflatten the output
463+ return out
0 commit comments