3030from jax ._src .pallas .mosaic_gpu import core as gpu_core
3131from jax ._src .pallas .mosaic_gpu import primitives as gpu_primitives
3232from jax .experimental import pallas as pl
33+ import jax .numpy as jnp
3334
3435
3536map = util .safe_map
@@ -72,15 +73,16 @@ def copy_out(self, slot, grid_indices):
7273)
7374
7475
75- def make_grid_indices (
76- step : jax .typing . ArrayLike , grid : Sequence [int ]
76+ def _inc_grid_by_1 (
77+ indices : tuple [ jax .Array , ...] , grid : Sequence [int ]
7778) -> tuple [jax .Array , ...]:
78- # TODO(slebedev): Maintain the grid index through the fori_loop instead.
79- indices = []
80- for size in reversed (grid ):
81- indices .append (lax .rem (step , size ))
82- step = lax .div (step , size )
83- return tuple (reversed (indices ))
79+ next_indices = []
80+ carry : bool | jax .Array = True
81+ for idx , size in reversed (list (zip (indices , grid ))):
82+ next_idx = lax .select (carry , idx + 1 , idx )
83+ carry = next_idx == size
84+ next_indices .append (lax .select (carry , 0 , next_idx ).astype (idx .dtype ))
85+ return tuple (reversed (next_indices ))
8486
8587
8688def emit_pipeline (
@@ -143,15 +145,15 @@ def scoped_pipeline(
143145 ):
144146 map (lambda bref : bref .copy_in (step , indices , barrier_ref ), in_brefs )
145147
146- def loop_body (step , _ ):
148+ def loop_body (step , carry ):
147149 slot = step % max_concurrent_steps
150+ indices , fetch_indices = carry
148151
149152 # Wait for the current GMEM->SMEM copy to complete.
150153 gpu_primitives .barrier_wait (barrier_ref .at [slot ])
151154 # Wait for the previous output SMEM->GMEM copy to complete.
152155 gpu_primitives .wait_smem_to_gmem (max_concurrent_steps - 1 )
153156
154- indices = make_grid_indices (step , grid )
155157 with pallas_core .grid_env (map (pallas_core .GridAxis , indices , grid )):
156158 body (
157159 * (bref .smem_ref .at [slot ] for bref in it .chain (in_brefs , out_brefs ))
@@ -166,17 +168,19 @@ def loop_body(step, _):
166168 jax .lax .cond (
167169 fetch_step < num_steps ,
168170 lambda : map (
169- lambda bref : bref .copy_in (
170- fetch_slot , make_grid_indices (fetch_step , grid ), barrier_ref
171- ),
171+ lambda bref : bref .copy_in (fetch_slot , fetch_indices , barrier_ref ),
172172 in_brefs ,
173173 ),
174174 lambda : [None ] * len (in_brefs ),
175175 )
176176
177- return ( )
177+ return _inc_grid_by_1 ( indices , grid ), _inc_grid_by_1 ( fetch_indices , grid )
178178
179- lax .fori_loop (0 , num_steps , loop_body , ())
179+ indices = (jnp .asarray (0 , dtype = lax .dtype (0 )),) * len (grid )
180+ fetch_indices = indices
181+ for _ in range (max_concurrent_steps ):
182+ fetch_indices = _inc_grid_by_1 (fetch_indices , grid )
183+ lax .fori_loop (0 , num_steps , loop_body , (indices , fetch_indices ))
180184
181185 # Finalize the pipeline.
182186 gpu_primitives .wait_smem_to_gmem (0 )
0 commit comments