@@ -266,9 +266,6 @@ def _dispatch(*args: ArrayLike) -> Any:
266266mlir .register_lowering (tesseract_dispatch_p , tesseract_dispatch_lowering )
267267
268268
269- mlir .register_lowering (tesseract_dispatch_p , tesseract_dispatch_lowering )
270-
271-
272269def tesseract_dispatch_batching (
273270 array_args : ArrayLike | ShapedArray | Any ,
274271 axes : Sequence [Any ],
@@ -380,16 +377,19 @@ def apply_tesseract(
380377 f"Got { type (tesseract_client )} instead."
381378 )
382379
383- transformation = False
380+ has_func_transformation = False
384381
385382 # determine if any array in the input pytree is a tracer
386383 inputs_flat , _ = jax .tree .flatten (inputs )
387384 for inp in inputs_flat :
388385 if isinstance (inp , jc .Tracer ):
389- transformation = True
386+ has_func_transformation = True
390387 break
391388
392- if transformation and "abstract_eval" not in tesseract_client .available_endpoints :
389+ if (
390+ has_func_transformation
391+ and "abstract_eval" not in tesseract_client .available_endpoints
392+ ):
393393 raise ValueError (
394394 "Given Tesseract object does not support abstract_eval, "
395395 "it is however called in combination with a JAX transformation "
@@ -400,8 +400,6 @@ def apply_tesseract(
400400
401401 client = Jaxeract (tesseract_client )
402402
403- # Get abstract values for outputs, so we can unflatten them later
404-
405403 flat_args , input_pytreedef = jax .tree .flatten (inputs )
406404 is_static_mask = tuple (_is_static (arg ) for arg in flat_args )
407405 array_args , static_args = split_args (flat_args , is_static_mask )
@@ -447,7 +445,9 @@ def apply_tesseract(
447445 return jax .tree .unflatten (output_pytreedef , out )
448446
449447 else :
450- # Apply the primitive
448+ # If there is no abstract_eval endpoint, we cannot determine the output structure
449+ # In this case we send None for output_pytreedef and output_avals
450+ # and the primitive will return an unflattened output
451451 out = tesseract_dispatch_p .bind (
452452 * array_args ,
453453 static_args = static_args ,
0 commit comments