Skip to content

Commit 50971a9

Browse files
committed
address comments
1 parent 6442059 commit 50971a9

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

tesseract_jax/primitive.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -266,9 +266,6 @@ def _dispatch(*args: ArrayLike) -> Any:
266266
mlir.register_lowering(tesseract_dispatch_p, tesseract_dispatch_lowering)
267267

268268

269-
mlir.register_lowering(tesseract_dispatch_p, tesseract_dispatch_lowering)
270-
271-
272269
def tesseract_dispatch_batching(
273270
array_args: ArrayLike | ShapedArray | Any,
274271
axes: Sequence[Any],
@@ -380,16 +377,19 @@ def apply_tesseract(
380377
f"Got {type(tesseract_client)} instead."
381378
)
382379

383-
transformation = False
380+
has_func_transformation = False
384381

385382
# determine if any array in the input pytree is a tracer
386383
inputs_flat, _ = jax.tree.flatten(inputs)
387384
for inp in inputs_flat:
388385
if isinstance(inp, jc.Tracer):
389-
transformation = True
386+
has_func_transformation = True
390387
break
391388

392-
if transformation and "abstract_eval" not in tesseract_client.available_endpoints:
389+
if (
390+
has_func_transformation
391+
and "abstract_eval" not in tesseract_client.available_endpoints
392+
):
393393
raise ValueError(
394394
"Given Tesseract object does not support abstract_eval, "
395395
"it is however called in combination with a JAX transformation "
@@ -400,8 +400,6 @@ def apply_tesseract(
400400

401401
client = Jaxeract(tesseract_client)
402402

403-
# Get abstract values for outputs, so we can unflatten them later
404-
405403
flat_args, input_pytreedef = jax.tree.flatten(inputs)
406404
is_static_mask = tuple(_is_static(arg) for arg in flat_args)
407405
array_args, static_args = split_args(flat_args, is_static_mask)
@@ -447,7 +445,9 @@ def apply_tesseract(
447445
return jax.tree.unflatten(output_pytreedef, out)
448446

449447
else:
450-
# Apply the primitive
448+
# If there is no abstract_eval endpoint, we cannot determine the output structure
449+
# In this case we send None for output_pytreedef and output_avals
450+
# and the primitive will return an unflattened output
451451
out = tesseract_dispatch_p.bind(
452452
*array_args,
453453
static_args=static_args,

0 commit comments

Comments
 (0)