33
44import functools
55from collections .abc import Sequence
6- from typing import Any
6+ from typing import Any , TypeVar
77
88import jax .tree
99import numpy as np
1010from jax import ShapeDtypeStruct , dtypes , extend
1111from jax .core import ShapedArray
12- from jax .interpreters import ad , mlir , xla
12+ from jax .interpreters import ad , batching , mlir , xla
1313from jax .tree_util import PyTreeDef
1414from jax .typing import ArrayLike
1515from tesseract_core import Tesseract
1616
17- from tesseract_jax .tesseract_compat import Jaxeract
17+ from tesseract_jax .tesseract_compat import Jaxeract , combine_args
18+
19+ T = TypeVar ("T" )
1820
1921tesseract_dispatch_p = extend .core .Primitive ("tesseract_dispatch" )
2022tesseract_dispatch_p .multiple_results = True
@@ -35,21 +37,13 @@ def __hash__(self) -> int:
3537
3638
3739def split_args (
38- flat_args : Sequence [Any ], is_static_mask : Sequence [bool ]
39- ) -> tuple [tuple [ArrayLike , ...], tuple [_Hashable , ...]]:
40- """Split a flat argument list into a tuple (array_args, static_args)."""
41- static_args = tuple (
42- _make_hashable (arg )
43- for arg , is_static in zip (flat_args , is_static_mask , strict = True )
44- if is_static
45- )
46- array_args = tuple (
47- arg
48- for arg , is_static in zip (flat_args , is_static_mask , strict = True ) # fmt: skip
49- if not is_static
50- )
51-
52- return array_args , static_args
40+ flat_args : Sequence [T ], mask : Sequence [bool ]
41+ ) -> tuple [tuple [T , ...], tuple [T , ...]]:
42+ """Split a flat argument tuple according to mask (mask_False, mask_True)."""
43+ lists = ([], [])
44+ for a , m in zip (flat_args , mask , strict = True ):
45+ lists [m ].append (a )
46+ return tuple (tuple (args ) for args in lists )
5347
5448
5549@tesseract_dispatch_p .def_abstract_eval
@@ -238,6 +232,55 @@ def _dispatch(*args: ArrayLike) -> Any:
238232mlir .register_lowering (tesseract_dispatch_p , tesseract_dispatch_lowering )
239233
240234
235+ def tesseract_dispatch_batching (
236+ array_args : ArrayLike | ShapedArray | Any ,
237+ axes : Sequence [Any ],
238+ * ,
239+ static_args : tuple [_Hashable , ...],
240+ input_pytreedef : PyTreeDef ,
241+ output_pytreedef : PyTreeDef ,
242+ output_avals : tuple [ShapeDtypeStruct , ...],
243+ is_static_mask : tuple [bool , ...],
244+ client : Jaxeract ,
245+ eval_func : str ,
246+ ) -> Any :
247+ """Defines how to dispatch batch operations such as vmap (which is used by jax.jacobian)."""
248+ new_args = [
249+ arg if ax is batching .not_mapped else batching .moveaxis (arg , ax , 0 )
250+ for arg , ax in zip (array_args , axes , strict = True )
251+ ]
252+
253+ # if output_pytreedef is not None:
254+ # output_pytreedef_expanded = tuple(
255+ # None if layout is None else tuple(n + 1 for n in layout) + (0,)
256+ # for layout in output_pytreedef
257+ # )
258+
259+ is_batched_mask = [d is not batching .not_mapped for d in axes ]
260+ unbatched_args , batched_args = split_args (new_args , is_batched_mask )
261+
262+ def _batch_fun (batched_args : tuple ):
263+ combined_args = combine_args (unbatched_args , batched_args , is_batched_mask )
264+ return tesseract_dispatch_p .bind (
265+ * combined_args ,
266+ static_args = static_args ,
267+ input_pytreedef = input_pytreedef ,
268+ output_pytreedef = output_pytreedef ,
269+ output_avals = output_avals ,
270+ is_static_mask = is_static_mask ,
271+ client = client ,
272+ eval_func = eval_func ,
273+ )
274+
275+ g = lambda _ , x : ((), _batch_fun (x ))
276+ _ , outvals = jax .lax .scan (g , (), batched_args )
277+
278+ return tuple (outvals ), (0 ,) * len (outvals )
279+
280+
281+ batching .primitive_batchers [tesseract_dispatch_p ] = tesseract_dispatch_batching
282+
283+
241284def _check_dtype (dtype : Any ) -> None :
242285 dt = np .dtype (dtype )
243286 if dtypes .canonicalize_dtype (dt ) != dt :
@@ -316,8 +359,10 @@ def apply_tesseract(
316359 client = Jaxeract (tesseract_client )
317360
318361 flat_args , input_pytreedef = jax .tree .flatten (inputs )
362+
319363 is_static_mask = tuple (_is_static (arg ) for arg in flat_args )
320364 array_args , static_args = split_args (flat_args , is_static_mask )
365+ static_args = tuple (_make_hashable (arg ) for arg in static_args )
321366
322367 # Get abstract values for outputs, so we can unflatten them later
323368 output_pytreedef , avals = None , None
0 commit comments