Skip to content

Commit d52ef6b

Browse files
feat: support simple sequential batching rule (support for jax.vmap and jax.jacobian) (#47)
#### Relevant issue or PR Fixes #1 #### Description of changes Implement a simple sequential batching rule enabling the use of `vmap` and `jax.jacobian` by simply calling the relevant endpoint multiple times. The implementation was directly lifted from [jax._src.ffi.ffi_batching_rule](https://github.com/jax-ml/jax/blob/main/jax/_src/ffi.py#L638) with refactoring of variable names and eliminating dead code. No private JAX API is used; instead `split_args` and `unflatten_args` were refactored to serve additional purposes. #### Testing done - [x] CI passes - [x] `jax.jacfwd` and `jax.jacrev` runs and produces expected results with `vectoradd_jax` - [x] Added CI tests of `jacobian` - [x] Added wide range of `vmap` tests to CI --------- Co-authored-by: Dion Häfner <[email protected]>
1 parent 42fc068 commit d52ef6b

File tree

4 files changed

+306
-36
lines changed

4 files changed

+306
-36
lines changed

tesseract_jax/primitive.py

Lines changed: 55 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,20 @@
33

44
import functools
55
from collections.abc import Sequence
6-
from typing import Any
6+
from typing import Any, TypeVar
77

88
import jax.tree
99
import numpy as np
1010
from jax import ShapeDtypeStruct, dtypes, extend
1111
from jax.core import ShapedArray
12-
from jax.interpreters import ad, mlir, xla
12+
from jax.interpreters import ad, batching, mlir, xla
1313
from jax.tree_util import PyTreeDef
1414
from jax.typing import ArrayLike
1515
from tesseract_core import Tesseract
1616

17-
from tesseract_jax.tesseract_compat import Jaxeract
17+
from tesseract_jax.tesseract_compat import Jaxeract, combine_args
18+
19+
T = TypeVar("T")
1820

1921
tesseract_dispatch_p = extend.core.Primitive("tesseract_dispatch")
2022
tesseract_dispatch_p.multiple_results = True
@@ -35,21 +37,13 @@ def __hash__(self) -> int:
3537

3638

3739
def split_args(
38-
flat_args: Sequence[Any], is_static_mask: Sequence[bool]
39-
) -> tuple[tuple[ArrayLike, ...], tuple[_Hashable, ...]]:
40-
"""Split a flat argument list into a tuple (array_args, static_args)."""
41-
static_args = tuple(
42-
_make_hashable(arg)
43-
for arg, is_static in zip(flat_args, is_static_mask, strict=True)
44-
if is_static
45-
)
46-
array_args = tuple(
47-
arg
48-
for arg, is_static in zip(flat_args, is_static_mask, strict=True)
49-
if not is_static
50-
)
51-
52-
return array_args, static_args
40+
flat_args: Sequence[T], mask: Sequence[bool]
41+
) -> tuple[tuple[T, ...], tuple[T, ...]]:
42+
"""Split a flat argument tuple according to mask (mask_False, mask_True)."""
43+
lists = ([], [])
44+
for a, m in zip(flat_args, mask, strict=True):
45+
lists[m].append(a)
46+
return tuple(tuple(args) for args in lists)
5347

5448

5549
@tesseract_dispatch_p.def_abstract_eval
@@ -238,6 +232,48 @@ def _dispatch(*args: ArrayLike) -> Any:
238232
mlir.register_lowering(tesseract_dispatch_p, tesseract_dispatch_lowering)
239233

240234

235+
def tesseract_dispatch_batching(
236+
array_args: ArrayLike | ShapedArray | Any,
237+
axes: Sequence[Any],
238+
*,
239+
static_args: tuple[_Hashable, ...],
240+
input_pytreedef: PyTreeDef,
241+
output_pytreedef: PyTreeDef,
242+
output_avals: tuple[ShapeDtypeStruct, ...],
243+
is_static_mask: tuple[bool, ...],
244+
client: Jaxeract,
245+
eval_func: str,
246+
) -> Any:
247+
"""Defines how to dispatch batch operations such as vmap (which is used by jax.jacobian)."""
248+
new_args = [
249+
arg if ax is batching.not_mapped else batching.moveaxis(arg, ax, 0)
250+
for arg, ax in zip(array_args, axes, strict=True)
251+
]
252+
253+
is_batched_mask = [d is not batching.not_mapped for d in axes]
254+
unbatched_args, batched_args = split_args(new_args, is_batched_mask)
255+
256+
def _batch_fun(batched_args: tuple):
257+
combined_args = combine_args(unbatched_args, batched_args, is_batched_mask)
258+
return tesseract_dispatch_p.bind(
259+
*combined_args,
260+
static_args=static_args,
261+
input_pytreedef=input_pytreedef,
262+
output_pytreedef=output_pytreedef,
263+
output_avals=output_avals,
264+
is_static_mask=is_static_mask,
265+
client=client,
266+
eval_func=eval_func,
267+
)
268+
269+
outvals = jax.lax.map(_batch_fun, batched_args)
270+
271+
return tuple(outvals), (0,) * len(outvals)
272+
273+
274+
batching.primitive_batchers[tesseract_dispatch_p] = tesseract_dispatch_batching
275+
276+
241277
def _check_dtype(dtype: Any) -> None:
242278
dt = np.dtype(dtype)
243279
if dtypes.canonicalize_dtype(dt) != dt:
@@ -318,6 +354,7 @@ def apply_tesseract(
318354
flat_args, input_pytreedef = jax.tree.flatten(inputs)
319355
is_static_mask = tuple(_is_static(arg) for arg in flat_args)
320356
array_args, static_args = split_args(flat_args, is_static_mask)
357+
static_args = tuple(_make_hashable(arg) for arg in static_args)
321358

322359
# Get abstract values for outputs, so we can unflatten them later
323360
output_pytreedef, avals = None, None

tesseract_jax/tesseract_compat.py

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Copyright 2025 Pasteur Labs. All Rights Reserved.
22
# SPDX-License-Identifier: Apache-2.0
33

4+
from collections.abc import Sequence
45
from typing import Any, TypeAlias
56

67
import jax.tree
@@ -12,6 +13,24 @@
1213
PyTree: TypeAlias = Any
1314

1415

16+
def combine_args(args0: Sequence, args1: Sequence, mask: Sequence[bool]) -> tuple:
17+
"""Merge the elements of two lists based on a mask.
18+
19+
The length of the two lists is required to be equal to the length of the mask.
20+
`combine_args` will populate the new list according to the mask: if the mask evaluates
21+
to `False` it will take the next item of the first list, if it evaluate to `True` it will
22+
take from the second list.
23+
24+
Example:
25+
>>> combine_args(["foo", "bar"], [0, 1, 2], [1, 0, 0, 1, 1])
26+
[0, "foo", "bar", 1, 2]
27+
"""
28+
assert sum(mask) == len(args1) and len(mask) - sum(mask) == len(args0)
29+
args0_iter, args1_iter = iter(args0), iter(args1)
30+
combined_args = [next(args1_iter) if m else next(args0_iter) for m in mask]
31+
return tuple(combined_args)
32+
33+
1534
def unflatten_args(
1635
array_args: tuple[ArrayLike, ...],
1736
static_args: tuple[Any, ...],
@@ -20,23 +39,14 @@ def unflatten_args(
2039
remove_static_args: bool = False,
2140
) -> PyTree:
2241
"""Unflatten lists of arguments (static and not) into a pytree."""
23-
combined_args = []
24-
static_iter = iter(static_args)
25-
array_iter = iter(array_args)
26-
27-
for is_static in is_static_mask:
28-
if is_static:
29-
elem = next(static_iter)
30-
elem = elem.wrapped if hasattr(elem, "wrapped") else elem
31-
32-
if remove_static_args:
33-
combined_args.append(None)
34-
else:
35-
combined_args.append(elem)
36-
37-
else:
38-
combined_args.append(next(array_iter))
42+
if remove_static_args:
43+
static_args_converted = [None] * len(static_args)
44+
else:
45+
static_args_converted = [
46+
elem.wrapped if hasattr(elem, "wrapped") else elem for elem in static_args
47+
]
3948

49+
combined_args = combine_args(array_args, static_args_converted, is_static_mask)
4050
result = jax.tree.unflatten(input_pytreedef, combined_args)
4151

4252
if remove_static_args:

tests/nested_tesseract/tesseract_api.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,17 @@ def vector_jacobian_product(
8181
return out
8282

8383

84+
def jacobian(inputs: InputSchema, jac_inputs: set[str], jac_outputs: set[str]):
85+
jac = {dy: {dx: [0.0, 0.0, 0.0] for dx in jac_inputs} for dy in jac_outputs}
86+
87+
if "scalars.a" in jac_inputs and "scalars.a" in jac_outputs:
88+
jac["scalars.a"]["scalars.a"] = 10.0
89+
if "vectors.v" in jac_inputs and "vectors.v" in jac_outputs:
90+
jac["vectors.v"]["vectors.v"] = [[10.0, 0, 0], [0, 10.0, 0], [0, 0, 10.0]]
91+
92+
return jac
93+
94+
8495
def abstract_eval(abstract_inputs):
8596
"""Calculate output shape of apply from the shape of its inputs."""
8697
return {

0 commit comments

Comments
 (0)