Skip to content

Commit 070c72e

Browse files
committed
revert to map instead of scan
1 parent b16537b commit 070c72e

File tree

1 file changed

+1
-8
lines changed

1 file changed

+1
-8
lines changed

tesseract_jax/primitive.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -250,12 +250,6 @@ def tesseract_dispatch_batching(
250250
for arg, ax in zip(array_args, axes, strict=True)
251251
]
252252

253-
# if output_pytreedef is not None:
254-
# output_pytreedef_expanded = tuple(
255-
# None if layout is None else tuple(n + 1 for n in layout) + (0,)
256-
# for layout in output_pytreedef
257-
# )
258-
259253
is_batched_mask = [d is not batching.not_mapped for d in axes]
260254
unbatched_args, batched_args = split_args(new_args, is_batched_mask)
261255

@@ -272,8 +266,7 @@ def _batch_fun(batched_args: tuple):
272266
eval_func=eval_func,
273267
)
274268

275-
g = lambda _, x: ((), _batch_fun(x))
276-
_, outvals = jax.lax.scan(g, (), batched_args)
269+
outvals = jax.lax.map(_batch_fun, batched_args)
277270

278271
return tuple(outvals), (0,) * len(outvals)
279272

0 commit comments

Comments
 (0)