Skip to content

Commit a6bcc1a

Browse files
committed
feat: support simple sequential batching rule
1 parent 7a6db32 commit a6bcc1a

File tree

2 files changed

+79
-34
lines changed

2 files changed

+79
-34
lines changed

tesseract_jax/primitive.py

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,20 @@
33

44
import functools
55
from collections.abc import Sequence
6-
from typing import Any
6+
from typing import Any, TypeVar
77

88
import jax.tree
99
import numpy as np
1010
from jax import ShapeDtypeStruct, dtypes, extend
1111
from jax.core import ShapedArray
12-
from jax.interpreters import ad, mlir, xla
12+
from jax.interpreters import ad, batching, mlir, xla
1313
from jax.tree_util import PyTreeDef
1414
from jax.typing import ArrayLike
1515
from tesseract_core import Tesseract
1616

17-
from tesseract_jax.tesseract_compat import Jaxeract
17+
from tesseract_jax.tesseract_compat import Jaxeract, combine_args
18+
19+
T = TypeVar("T")
1820

1921
tesseract_dispatch_p = extend.core.Primitive("tesseract_dispatch")
2022
tesseract_dispatch_p.multiple_results = True
@@ -35,21 +37,13 @@ def __hash__(self) -> int:
3537

3638

3739
def split_args(
38-
flat_args: Sequence[Any], is_static_mask: Sequence[bool]
39-
) -> tuple[tuple[ArrayLike, ...], tuple[_Hashable, ...]]:
40-
"""Split a flat argument list into a tuple (array_args, static_args)."""
41-
static_args = tuple(
42-
_make_hashable(arg)
43-
for arg, is_static in zip(flat_args, is_static_mask, strict=True)
44-
if is_static
45-
)
46-
array_args = tuple(
47-
arg
48-
for arg, is_static in zip(flat_args, is_static_mask, strict=True) # fmt: skip
49-
if not is_static
50-
)
51-
52-
return array_args, static_args
40+
flat_args: Sequence[T], mask: Sequence[bool]
41+
) -> tuple[tuple[T, ...], tuple[T, ...]]:
42+
"""Split a flat argument tuple according to mask (mask_False, mask_True)."""
43+
lists = ([], [])
44+
for a, m in zip(flat_args, mask, strict=True):
45+
lists[m].append(a)
46+
return tuple(tuple(args) for args in lists)
5347

5448

5549
@tesseract_dispatch_p.def_abstract_eval
@@ -238,6 +232,55 @@ def _dispatch(*args: ArrayLike) -> Any:
238232
mlir.register_lowering(tesseract_dispatch_p, tesseract_dispatch_lowering)
239233

240234

235+
def tesseract_dispatch_batching(
236+
array_args: ArrayLike | ShapedArray | Any,
237+
axes: Sequence[Any],
238+
*,
239+
static_args: tuple[_Hashable, ...],
240+
input_pytreedef: PyTreeDef,
241+
output_pytreedef: PyTreeDef,
242+
output_avals: tuple[ShapeDtypeStruct, ...],
243+
is_static_mask: tuple[bool, ...],
244+
client: Jaxeract,
245+
eval_func: str,
246+
) -> Any:
247+
"""Defines how to dispatch batch operations such as vmap (which is used by jax.jacobian)."""
248+
new_args = [
249+
arg if ax is batching.not_mapped else batching.moveaxis(arg, ax, 0)
250+
for arg, ax in zip(array_args, axes, strict=True)
251+
]
252+
253+
# if output_pytreedef is not None:
254+
# output_pytreedef_expanded = tuple(
255+
# None if layout is None else tuple(n + 1 for n in layout) + (0,)
256+
# for layout in output_pytreedef
257+
# )
258+
259+
is_batched_mask = [d is not batching.not_mapped for d in axes]
260+
unbatched_args, batched_args = split_args(new_args, is_batched_mask)
261+
262+
def _batch_fun(batched_args: tuple):
263+
combined_args = combine_args(unbatched_args, batched_args, is_batched_mask)
264+
return tesseract_dispatch_p.bind(
265+
*combined_args,
266+
static_args=static_args,
267+
input_pytreedef=input_pytreedef,
268+
output_pytreedef=output_pytreedef,
269+
output_avals=output_avals,
270+
is_static_mask=is_static_mask,
271+
client=client,
272+
eval_func=eval_func,
273+
)
274+
275+
g = lambda _, x: ((), _batch_fun(x))
276+
_, outvals = jax.lax.scan(g, (), batched_args)
277+
278+
return tuple(outvals), (0,) * len(outvals)
279+
280+
281+
batching.primitive_batchers[tesseract_dispatch_p] = tesseract_dispatch_batching
282+
283+
241284
def _check_dtype(dtype: Any) -> None:
242285
dt = np.dtype(dtype)
243286
if dtypes.canonicalize_dtype(dt) != dt:
@@ -316,8 +359,10 @@ def apply_tesseract(
316359
client = Jaxeract(tesseract_client)
317360

318361
flat_args, input_pytreedef = jax.tree.flatten(inputs)
362+
319363
is_static_mask = tuple(_is_static(arg) for arg in flat_args)
320364
array_args, static_args = split_args(flat_args, is_static_mask)
365+
static_args = tuple(_make_hashable(arg) for arg in static_args)
321366

322367
# Get abstract values for outputs, so we can unflatten them later
323368
output_pytreedef, avals = None, None

tesseract_jax/tesseract_compat.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright 2025 Pasteur Labs. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from collections.abc import Sequence
45
from typing import Any, TypeAlias
56

67
import jax.tree
@@ -12,6 +13,14 @@
1213
PyTree: TypeAlias = Any
1314

1415

16+
def combine_args(args0: Sequence, args1: Sequence, mask: Sequence[bool]) -> tuple:
17+
"""Merge the elements of two lists based on a mask."""
18+
assert sum(mask) == len(args1) and len(mask) - sum(mask) == len(args0)
19+
args0_iter, args1_iter = iter(args0), iter(args1)
20+
combined_args = [next(args1_iter) if m else next(args0_iter) for m in mask]
21+
return tuple(combined_args)
22+
23+
1524
def unflatten_args(
1625
array_args: tuple[ArrayLike, ...],
1726
static_args: tuple[Any, ...],
@@ -20,23 +29,14 @@ def unflatten_args(
2029
remove_static_args: bool = False,
2130
) -> PyTree:
2231
"""Unflatten lists of arguments (static and not) into a pytree."""
23-
combined_args = []
24-
static_iter = iter(static_args)
25-
array_iter = iter(array_args)
26-
27-
for is_static in is_static_mask:
28-
if is_static:
29-
elem = next(static_iter)
30-
elem = elem.wrapped if hasattr(elem, "wrapped") else elem
31-
32-
if remove_static_args:
33-
combined_args.append(None)
34-
else:
35-
combined_args.append(elem)
36-
37-
else:
38-
combined_args.append(next(array_iter))
32+
if remove_static_args:
33+
static_args_converted = [None] * len(static_args)
34+
else:
35+
static_args_converted = [
36+
elem.wrapped if hasattr(elem, "wrapped") else elem for elem in static_args
37+
]
3938

39+
combined_args = combine_args(array_args, static_args_converted, is_static_mask)
4040
result = jax.tree.unflatten(input_pytreedef, combined_args)
4141

4242
if remove_static_args:

0 commit comments

Comments
 (0)