File tree Expand file tree Collapse file tree 1 file changed +1
-8
lines changed Expand file tree Collapse file tree 1 file changed +1
-8
lines changed Original file line number Diff line number Diff line change @@ -250,12 +250,6 @@ def tesseract_dispatch_batching(
250250 for arg , ax in zip (array_args , axes , strict = True )
251251 ]
252252
253- # if output_pytreedef is not None:
254- # output_pytreedef_expanded = tuple(
255- # None if layout is None else tuple(n + 1 for n in layout) + (0,)
256- # for layout in output_pytreedef
257- # )
258-
259253 is_batched_mask = [d is not batching .not_mapped for d in axes ]
260254 unbatched_args , batched_args = split_args (new_args , is_batched_mask )
261255
@@ -272,8 +266,7 @@ def _batch_fun(batched_args: tuple):
272266 eval_func = eval_func ,
273267 )
274268
275- g = lambda _ , x : ((), _batch_fun (x ))
276- _ , outvals = jax .lax .scan (g , (), batched_args )
269+ outvals = jax .lax .map (_batch_fun , batched_args )
277270
278271 return tuple (outvals ), (0 ,) * len (outvals )
279272
You can’t perform that action at this time.
0 commit comments