Skip to content

Commit 5a7807d

Browse files
committed
optional abstract eval
1 parent 89d3616 commit 5a7807d

File tree

4 files changed

+61
-45
lines changed

4 files changed

+61
-45
lines changed

examples/simple/partial/tesseract_api.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,10 @@ def vector_jacobian_product(
4444

4545

4646

47-
def abstract_eval(abstract_inputs):
48-
"""Calculate output shape of apply from the shape of its inputs."""
49-
return {
50-
"b": ShapeDType(shape=(abstract_inputs.a.shape[0],), dtype="float32"),
51-
# "c": ShapeDType(shape=(3,), dtype="float32"),
52-
}
47+
# def abstract_eval(abstract_inputs):
48+
# """Calculate output shape of apply from the shape of its inputs."""
49+
# return {
50+
# "b": ShapeDType(shape=(abstract_inputs.a.shape[0],), dtype="float32"),
51+
# # "c": ShapeDType(shape=(3,), dtype="float32"),
52+
# }
5353

tesseract_jax/primitive.py

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -195,8 +195,8 @@ def tesseract_dispatch_lowering(
195195
*array_args: ArrayLike | ShapedArray | Any,
196196
static_args: tuple[_Hashable, ...],
197197
input_pytreedef: PyTreeDef,
198-
output_pytreedef: PyTreeDef,
199-
output_avals: tuple[ShapeDtypeStruct, ...],
198+
output_pytreedef: PyTreeDef | None,
199+
output_avals: tuple[ShapeDtypeStruct, ...] | None,
200200
is_static_mask: tuple[bool, ...],
201201
client: Jaxeract,
202202
eval_func: str,
@@ -344,11 +344,11 @@ def apply_tesseract(
344344
f"Got {type(tesseract_client)} instead."
345345
)
346346

347-
if "abstract_eval" not in tesseract_client.available_endpoints:
348-
raise ValueError(
349-
"Given Tesseract object does not support abstract_eval, "
350-
"which is required for compatibility with JAX."
351-
)
347+
# if "abstract_eval" not in tesseract_client.available_endpoints:
348+
# raise ValueError(
349+
# "Given Tesseract object does not support abstract_eval, "
350+
# "which is required for compatibility with JAX."
351+
# )
352352

353353
client = Jaxeract(tesseract_client)
354354

@@ -357,28 +357,30 @@ def apply_tesseract(
357357
array_args, static_args = split_args(flat_args, is_static_mask)
358358
static_args = tuple(_make_hashable(arg) for arg in static_args)
359359

360+
360361
# Get abstract values for outputs, so we can unflatten them later
361-
output_pytreedef, avals = None, None
362-
avals = client.abstract_eval(
363-
array_args,
364-
static_args,
365-
input_pytreedef,
366-
output_pytreedef,
367-
avals,
368-
is_static_mask,
369-
)
362+
output_pytreedef, avals, flat_avals = None, None, None
363+
if "abstract_eval" in tesseract_client.available_endpoints:
364+
avals = client.abstract_eval(
365+
array_args,
366+
static_args,
367+
input_pytreedef,
368+
output_pytreedef,
369+
avals,
370+
is_static_mask,
371+
)
370372

371-
is_aval = lambda x: isinstance(x, dict) and "dtype" in x and "shape" in x
372-
flat_avals, output_pytreedef = jax.tree.flatten(avals, is_leaf=is_aval)
373-
for aval in flat_avals:
374-
if not is_aval(aval):
375-
continue
376-
_check_dtype(aval["dtype"])
373+
is_aval = lambda x: isinstance(x, dict) and "dtype" in x and "shape" in x
374+
flat_avals, output_pytreedef = jax.tree.flatten(avals, is_leaf=is_aval)
375+
for aval in flat_avals:
376+
if not is_aval(aval):
377+
continue
378+
_check_dtype(aval["dtype"])
377379

378-
flat_avals = tuple(
379-
jax.ShapeDtypeStruct(shape=tuple(aval["shape"]), dtype=aval["dtype"])
380-
for aval in flat_avals
381-
)
380+
flat_avals = tuple(
381+
jax.ShapeDtypeStruct(shape=tuple(aval["shape"]), dtype=aval["dtype"])
382+
for aval in flat_avals
383+
)
382384

383385
# Apply the primitive
384386
out = tesseract_dispatch_p.bind(
@@ -392,5 +394,16 @@ def apply_tesseract(
392394
eval_func="apply",
393395
)
394396

397+
# else:
398+
# # Apply the primitive
399+
# out = tesseract_dispatch_p.bind(
400+
# *array_args,
401+
# static_args=static_args,
402+
# input_pytreedef=input_pytreedef,
403+
# is_static_mask=is_static_mask,
404+
# client=client,
405+
# eval_func="apply",
406+
# )
407+
395408
# Unflatten the output
396409
return jax.tree.unflatten(output_pytreedef, out)

tesseract_jax/tesseract_compat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,8 +148,8 @@ def apply(
148148
array_args: tuple[ArrayLike, ...],
149149
static_args: tuple[Any, ...],
150150
input_pytreedef: PyTreeDef,
151-
output_pytreedef: PyTreeDef,
152-
output_avals: tuple[ShapeDtypeStruct, ...],
151+
output_pytreedef: PyTreeDef | None,
152+
output_avals: tuple[ShapeDtypeStruct, ...] | None,
153153
is_static_mask: tuple[bool, ...],
154154
) -> PyTree:
155155
"""Call the Tesseract's apply endpoint with the given arguments."""

0 commit comments

Comments
 (0)