@@ -125,20 +125,41 @@ def __eq__(self, other: _Slice) -> jax.Array: # type: ignore
 
 
 def emit_pipeline(
-    body,
+    body: Callable[..., None],
     *,
     grid: pallas_core.StaticGrid,
     in_specs: Sequence[pallas_core.BlockSpec] = (),
     out_specs: Sequence[pallas_core.BlockSpec] = (),
     max_concurrent_steps: int = 1,
+    delay_release: int = 0,
 ):
-  """Creates a function to emit a manual pipeline within a Pallas kernel."""
+  """Creates a function to emit a manual pipeline within a Pallas kernel.
+
+  Args:
+    body: The pipeline body.
+    grid: The grid to use for the pipeline.
+    in_specs: The block specs for the inputs.
+    out_specs: The block specs for the outputs.
+    max_concurrent_steps: The maximum number of sequential stages that are
+      active concurrently. Defaults to 1.
+    delay_release: The number of steps to wait before reusing the input/output
+      references. Defaults to 0, and must be strictly smaller than
+      ``max_concurrent_steps``. Generally, you'll want to set it to 1 if you
+      don't await the WGMMA in the body.
+  """
   num_steps = math.prod(grid)
 
+  if max_concurrent_steps <= delay_release:
+    raise ValueError(
+        "max_concurrent_steps must be greater than delay_release, but"
+        f" {max_concurrent_steps=}, {delay_release=}"
+    )
+
   # Shrink ``max_concurrent_steps`` if the total number of steps is lower to
-  # reduce the size of the allocated buffers below.
+  # reduce the size of the refs allocated in SMEM.
   if max_concurrent_steps > num_steps:
     max_concurrent_steps = num_steps
+    delay_release = 0  # No need to delay anything.
 
   def pipeline(*gmem_refs: pallas_core.AbstractMemoryRef):
     for gmem_ref, spec in zip(gmem_refs, it.chain(in_specs, out_specs)):
@@ -208,7 +229,7 @@ def loop_body(step, carry):
       gpu_primitives.barrier_wait(barrier_ref.at[slot])
       # Wait for the previous output SMEM->GMEM copy to complete.
       gpu_primitives.wait_smem_to_gmem(
-          max_concurrent_steps - 1, wait_read_only=True
+          max_concurrent_steps - (1 + delay_release), wait_read_only=True
       )
 
       with pallas_core.grid_env(map(pallas_core.GridAxis, indices, grid)):
@@ -245,10 +266,10 @@ def loop_body(step, carry):
             predicate=lax.bitwise_or(slices_changed, is_last_step),
         )
 
-      fetch_step = step + max_concurrent_steps
+      fetch_step = step + (max_concurrent_steps - delay_release)
       fetch_slot = slot  # (x + y) % y == x % y
       jax.lax.cond(
-          fetch_step < num_steps,
+          lax.bitwise_and(fetch_step >= delay_release, fetch_step < num_steps),
           lambda: map(
               lambda bref: bref.copy_in(fetch_slot, fetch_indices, barrier_ref),
               in_brefs,
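
A minimal usage sketch (not part of the diff), assuming `emit_pipeline` is re-exported as `plgpu.emit_pipeline`; the block shapes, grid size, and trivial body are illustrative stand-ins for a kernel whose body issues WGMMAs it does not await, which is the case where `delay_release=1` matters:

```python
from jax.experimental import pallas as pl
from jax.experimental.pallas import mosaic_gpu as plgpu  # assumed export path

def body(x_smem, o_smem):
  # Illustrative body. In a real kernel this might start a WGMMA that keeps
  # reading x_smem after the step returns; delay_release=1 then stops the
  # pipeline from overwriting x_smem for one extra step.
  o_smem[...] = x_smem[...] + 1.0

def kernel(x_gmem, o_gmem):
  pipeline = plgpu.emit_pipeline(
      body,
      grid=(4,),  # four sequential pipeline steps
      in_specs=[pl.BlockSpec((128, 64), lambda i: (i, 0))],
      out_specs=[pl.BlockSpec((128, 64), lambda i: (i, 0))],
      max_concurrent_steps=2,  # double buffering
      delay_release=1,         # must stay strictly below max_concurrent_steps
  )
  pipeline(x_gmem, o_gmem)
```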