Skip to content

Commit 227bead

Browse files
committed
add concrete eval
1 parent 5a7807d commit 227bead

File tree

4 files changed

+192
-44
lines changed

4 files changed

+192
-44
lines changed

examples/simple/partial/tesseract_api.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,23 +31,23 @@ def apply(inputs: InputSchema) -> OutputSchema:
3131
}
3232

3333

34-
def vector_jacobian_product(
35-
inputs: InputSchema,
36-
vjp_inputs: set[str],
37-
vjp_outputs: set[str],
38-
cotangent_vector: dict[str, Any],
39-
):
34+
# def vector_jacobian_product(
35+
# inputs: InputSchema,
36+
# vjp_inputs: set[str],
37+
# vjp_outputs: set[str],
38+
# cotangent_vector: dict[str, Any],
39+
# ):
4040

41-
return {
42-
"a": 2.0 * cotangent_vector["b"],
43-
}
41+
# return {
42+
# "a": 2.0 * cotangent_vector["b"],
43+
# }
4444

4545

4646

47-
# def abstract_eval(abstract_inputs):
48-
# """Calculate output shape of apply from the shape of its inputs."""
49-
# return {
50-
# "b": ShapeDType(shape=(abstract_inputs.a.shape[0],), dtype="float32"),
51-
# # "c": ShapeDType(shape=(3,), dtype="float32"),
52-
# }
47+
def abstract_eval(abstract_inputs):
48+
"""Calculate output shape of apply from the shape of its inputs."""
49+
return {
50+
"b": ShapeDType(shape=(None,), dtype="float32"),
51+
"c": ShapeDType(shape=(None,), dtype="float32"),
52+
}
5353

tesseract_jax/primitive.py

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121

2222
tesseract_dispatch_p = extend.core.Primitive("tesseract_dispatch")
2323
tesseract_dispatch_p.multiple_results = True
24-
tesseract_dispatch_p.def_impl(
25-
functools.partial(xla.apply_primitive, tesseract_dispatch_p)
26-
)
24+
# tesseract_dispatch_p.def_impl(
25+
# functools.partial(xla.apply_primitive, tesseract_dispatch_p)
26+
# )
2727

2828

2929
class _Hashable:
@@ -58,6 +58,7 @@ def tesseract_dispatch_abstract_eval(
5858
client: Jaxeract,
5959
eval_func: str,
6060
) -> tuple:
61+
6162
"""Define how to dispatch evals and pipe arguments."""
6263
if eval_func not in (
6364
"apply",
@@ -78,6 +79,7 @@ def tesseract_dispatch_abstract_eval(
7879
return tuple(jax.core.ShapedArray(aval.shape, aval.dtype) for aval in output_avals)
7980

8081

82+
8183
def tesseract_dispatch_jvp_rule(
8284
in_args: tuple[ArrayLike, ...],
8385
tan_args: tuple[ArrayLike, ...],
@@ -190,6 +192,38 @@ def tesseract_dispatch_transpose_rule(
190192
ad.primitive_transposes[tesseract_dispatch_p] = tesseract_dispatch_transpose_rule
191193

192194

195+
def tesseract_dispatch(
196+
*array_args: ArrayLike | ShapedArray | Any,
197+
static_args: tuple[_Hashable, ...],
198+
input_pytreedef: PyTreeDef,
199+
output_pytreedef: PyTreeDef | None,
200+
output_avals: tuple[ShapeDtypeStruct, ...] | None,
201+
is_static_mask: tuple[bool, ...],
202+
client: Jaxeract,
203+
eval_func: str,
204+
) -> Any:
205+
"""Defines how to dispatch lowering the computation."""
206+
207+
def _dispatch(*args: ArrayLike) -> Any:
208+
static_args_ = tuple(_unpack_hashable(arg) for arg in static_args)
209+
out = getattr(client, eval_func)(
210+
args,
211+
static_args_,
212+
input_pytreedef,
213+
output_pytreedef,
214+
output_avals,
215+
is_static_mask,
216+
)
217+
if not isinstance(out, tuple):
218+
out = (out,)
219+
return out
220+
221+
result = _dispatch(*array_args)
222+
223+
return result
224+
225+
tesseract_dispatch_p.def_impl(tesseract_dispatch)
226+
193227
def tesseract_dispatch_lowering(
194228
ctx: Any,
195229
*array_args: ArrayLike | ShapedArray | Any,
@@ -406,4 +440,4 @@ def apply_tesseract(
406440
# )
407441

408442
# Unflatten the output
409-
return jax.tree.unflatten(output_pytreedef, out)
443+
return out#jax.tree.unflatten(output_pytreedef, out)

0 commit comments

Comments
 (0)