@@ -1196,9 +1196,8 @@ def pipeline(
11961196 schedule = map_brefs (
11971197 lambda _ , x : get_pipeline_schedule (x ), allocations , schedule )
11981198
1199- def loop_body (step , indices ):
1200- nonlocal allocations
1201- scheduler = Scheduler (
1199+ def make_scheduler (step , indices ):
1200+ return Scheduler (
12021201 step ,
12031202 indices ,
12041203 grid ,
@@ -1208,13 +1207,15 @@ def loop_body(step, indices):
12081207 init_accumulators = init_accumulators ,
12091208 trace_scopes = trace_scopes ,
12101209 )
1210+
1211+ def loop_body (step , indices ):
1212+ scheduler = make_scheduler (step , indices )
12111213 with scheduler .grid_env ():
12121214
12131215 # prepare any local VMEM aliases
12141216 brefs = map_brefs (scheduler .alias_local_refs , allocations , refs )
12151217
12161218 # loop input handling phase
1217- map_brefs (scheduler .initialize , brefs , refs , schedule )
12181219 map_brefs (scheduler .copy_in , brefs , refs , schedule )
12191220 map_brefs (scheduler .wait_in , brefs , refs , schedule )
12201221
@@ -1243,12 +1244,24 @@ def loop_body(step, indices):
12431244 lambda : None )
12441245
12451246 map_brefs (scheduler .swap_slots , brefs , refs , schedule )
1246- map_brefs (scheduler .finalize , brefs , refs , schedule )
1247-
12481247 return _next_index (indices , grid )
12491248
1250- # run pipeline
1251- lax .fori_loop (0 , num_steps , loop_body , (0 ,) * len (grid ))
1249+ @pl .when (num_steps > 0 )
1250+ def _ ():
1251+ # pipeline prologue
1252+ initial_indices = (0 ,) * len (grid )
1253+ scheduler = make_scheduler (0 , initial_indices )
1254+ brefs = map_brefs (scheduler .alias_local_refs , allocations , refs )
1255+ map_brefs (scheduler .initialize , brefs , refs , schedule )
1256+
1257+ # pipeline loop
1258+ next_indices = lax .fori_loop (0 , num_steps , loop_body , initial_indices )
1259+
1260+ # pipeline epilogue
1261+ final_indices = _prev_index (next_indices , grid )
1262+ scheduler = make_scheduler (num_steps - 1 , final_indices )
1263+ brefs = map_brefs (scheduler .alias_local_refs , allocations , refs )
1264+ map_brefs (scheduler .finalize , brefs , refs , schedule )
12521265
12531266 return pipeline
12541267
0 commit comments