2121
2222tesseract_dispatch_p = extend .core .Primitive ("tesseract_dispatch" )
2323tesseract_dispatch_p .multiple_results = True
24- tesseract_dispatch_p .def_impl (
25- functools .partial (xla .apply_primitive , tesseract_dispatch_p )
26- )
24+ # tesseract_dispatch_p.def_impl(
25+ # functools.partial(xla.apply_primitive, tesseract_dispatch_p)
26+ # )
2727
2828
2929class _Hashable :
@@ -58,6 +58,7 @@ def tesseract_dispatch_abstract_eval(
5858 client : Jaxeract ,
5959 eval_func : str ,
6060) -> tuple :
61+
6162 """Define how to dispatch evals and pipe arguments."""
6263 if eval_func not in (
6364 "apply" ,
@@ -78,6 +79,7 @@ def tesseract_dispatch_abstract_eval(
7879 return tuple (jax .core .ShapedArray (aval .shape , aval .dtype ) for aval in output_avals )
7980
8081
82+
8183def tesseract_dispatch_jvp_rule (
8284 in_args : tuple [ArrayLike , ...],
8385 tan_args : tuple [ArrayLike , ...],
@@ -190,6 +192,38 @@ def tesseract_dispatch_transpose_rule(
190192ad .primitive_transposes [tesseract_dispatch_p ] = tesseract_dispatch_transpose_rule
191193
192194
195+ def tesseract_dispatch (
196+ * array_args : ArrayLike | ShapedArray | Any ,
197+ static_args : tuple [_Hashable , ...],
198+ input_pytreedef : PyTreeDef ,
199+ output_pytreedef : PyTreeDef | None ,
200+ output_avals : tuple [ShapeDtypeStruct , ...] | None ,
201+ is_static_mask : tuple [bool , ...],
202+ client : Jaxeract ,
203+ eval_func : str ,
204+ ) -> Any :
205+ """Defines how to dispatch lowering the computation."""
206+
207+ def _dispatch (* args : ArrayLike ) -> Any :
208+ static_args_ = tuple (_unpack_hashable (arg ) for arg in static_args )
209+ out = getattr (client , eval_func )(
210+ args ,
211+ static_args_ ,
212+ input_pytreedef ,
213+ output_pytreedef ,
214+ output_avals ,
215+ is_static_mask ,
216+ )
217+ if not isinstance (out , tuple ):
218+ out = (out ,)
219+ return out
220+
221+ result = _dispatch (* array_args )
222+
223+ return result
224+
225+ tesseract_dispatch_p .def_impl (tesseract_dispatch )
226+
193227def tesseract_dispatch_lowering (
194228 ctx : Any ,
195229 * array_args : ArrayLike | ShapedArray | Any ,
@@ -406,4 +440,4 @@ def apply_tesseract(
406440 # )
407441
408442 # Unflatten the output
409- return jax .tree .unflatten (output_pytreedef , out )
443+ return out # jax.tree.unflatten(output_pytreedef, out)
0 commit comments