Skip to content

Commit 7c5871f

Browse files
[Pallas TPU] Hoist prologue and epilogue outside of pipeline loop
PiperOrigin-RevId: 738038138
1 parent 3094148 commit 7c5871f

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

jax/_src/pallas/mosaic/pipeline.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,9 +1196,8 @@ def pipeline(
11961196
schedule = map_brefs(
11971197
lambda _, x: get_pipeline_schedule(x), allocations, schedule)
11981198

1199-
def loop_body(step, indices):
1200-
nonlocal allocations
1201-
scheduler = Scheduler(
1199+
def make_scheduler(step, indices):
1200+
return Scheduler(
12021201
step,
12031202
indices,
12041203
grid,
@@ -1208,13 +1207,15 @@ def loop_body(step, indices):
12081207
init_accumulators=init_accumulators,
12091208
trace_scopes=trace_scopes,
12101209
)
1210+
1211+
def loop_body(step, indices):
1212+
scheduler = make_scheduler(step, indices)
12111213
with scheduler.grid_env():
12121214

12131215
# prepare any local VMEM aliases
12141216
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
12151217

12161218
# loop input handling phase
1217-
map_brefs(scheduler.initialize, brefs, refs, schedule)
12181219
map_brefs(scheduler.copy_in, brefs, refs, schedule)
12191220
map_brefs(scheduler.wait_in, brefs, refs, schedule)
12201221

@@ -1243,12 +1244,24 @@ def loop_body(step, indices):
12431244
lambda: None)
12441245

12451246
map_brefs(scheduler.swap_slots, brefs, refs, schedule)
1246-
map_brefs(scheduler.finalize, brefs, refs, schedule)
1247-
12481247
return _next_index(indices, grid)
12491248

1250-
# run pipeline
1251-
lax.fori_loop(0, num_steps, loop_body, (0,) * len(grid))
1249+
@pl.when(num_steps > 0)
1250+
def _():
1251+
# pipeline prologue
1252+
initial_indices = (0,) * len(grid)
1253+
scheduler = make_scheduler(0, initial_indices)
1254+
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
1255+
map_brefs(scheduler.initialize, brefs, refs, schedule)
1256+
1257+
# pipeline loop
1258+
next_indices = lax.fori_loop(0, num_steps, loop_body, initial_indices)
1259+
1260+
# pipeline epilogue
1261+
final_indices = _prev_index(next_indices, grid)
1262+
scheduler = make_scheduler(num_steps - 1, final_indices)
1263+
brefs = map_brefs(scheduler.alias_local_refs, allocations, refs)
1264+
map_brefs(scheduler.finalize, brefs, refs, schedule)
12521265

12531266
return pipeline
12541267

0 commit comments

Comments
 (0)