Skip to content

Commit 02505fa

Browse files
[Pallas TPU] Remove next_slot SMEM tensor from pipeline emitter
PiperOrigin-RevId: 735564365
1 parent 988a120 commit 02505fa

File tree

1 file changed

+87
-28
lines changed

1 file changed

+87
-28
lines changed

jax/_src/pallas/mosaic/pipeline.py

Lines changed: 87 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -207,16 +207,23 @@ class BufferedRef:
207207
is_accumulator: whether this BufferedRef is an accumulator.
208208
is_input_output: whether this BufferedRef is an input/output without
209209
automatic accumulation.
210+
swap: Tracks whether the BufferedRef slots need to be swapped before the next
211+
copy.
210212
"""
211213
spec: pl.BlockSpec # static metadata
212214
dtype: Any # static metadata
213215
buffer_type: BufferType # static metadata
214216
window_ref: REF | None
215217
accum_ref: REF | None
216218
current_slot: ArrayRef | None
219+
# TODO(ramiroleal): next_slot is unused by the class. Remove it from
220+
# BufferedRef instantiations.
217221
next_slot: ArrayRef | None
218222
sem_recvs: SemaphoreTuple | None
219223
sem_sends: SemaphoreTuple | None
224+
# TODO(ramiroleal): Improve prefetch/postyeet interface to avoid
225+
# using this ref.
226+
swap: ArrayRef | None
220227

221228
def tree_flatten(self):
222229
return (
@@ -227,6 +234,7 @@ def tree_flatten(self):
227234
self.next_slot,
228235
self.sem_recvs,
229236
self.sem_sends,
237+
self.swap,
230238
),
231239
(self.spec, self.dtype, self.buffer_type),
232240
)
@@ -240,14 +248,15 @@ def buffer_types() -> type[BufferType]:
240248
return BufferType
241249

242250
@classmethod
243-
def create(cls, spec, dtype, buffer_type) -> BufferedRef:
251+
def create(cls, spec, dtype, buffer_type, needs_swap_ref=True) -> BufferedRef:
244252
"""Create a BufferedRef.
245253
246254
Args:
247255
spec: pallas blockspec.
248256
dtype: dtype for buffers.
249257
buffer_type: enum indicating whether this is an input, output, or in/out
250258
accumulator buffered reference.
259+
needs_swap_ref: whether to allocate the `swap` ref that tracks slot swaps.
251260
252261
Returns:
253262
Initialized BufferedRef
@@ -271,6 +280,7 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef:
271280
next_slot=None,
272281
sem_recvs=None,
273282
sem_sends=None,
283+
swap=None,
274284
)
275285
else:
276286
memory_space = SMEM if spec.memory_space == SMEM else VMEM
@@ -281,7 +291,7 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef:
281291
window_ref=memory_space((2,) + block_shape, dtype),
282292
accum_ref=accum_ref,
283293
current_slot=SMEM((1,), jnp.int32),
284-
next_slot=SMEM((1,), jnp.int32),
294+
next_slot=None,
285295
sem_recvs=(
286296
None
287297
if buffer_type is BufferType.OUTPUT
@@ -292,23 +302,24 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef:
292302
if buffer_type is BufferType.INPUT
293303
else SemaphoreType.DMA((2,))
294304
),
305+
swap=SMEM((1,), jnp.bool) if needs_swap_ref else None,
295306
)
296307

297308
@classmethod
298-
def input(cls, spec, dtype):
299-
return cls.create(spec, dtype, BufferType.INPUT)
309+
def input(cls, spec, dtype, needs_swap_ref=True):
310+
return cls.create(spec, dtype, BufferType.INPUT, needs_swap_ref)
300311

301312
@classmethod
302-
def output(cls, spec, dtype):
303-
return cls.create(spec, dtype, BufferType.OUTPUT)
313+
def output(cls, spec, dtype, needs_swap_ref=True):
314+
return cls.create(spec, dtype, BufferType.OUTPUT, needs_swap_ref)
304315

305316
@classmethod
306-
def accumulator(cls, spec, dtype):
307-
return cls.create(spec, dtype, BufferType.ACCUMULATOR)
317+
def accumulator(cls, spec, dtype, needs_swap_ref=True):
318+
return cls.create(spec, dtype, BufferType.ACCUMULATOR, needs_swap_ref)
308319

309320
@classmethod
310-
def input_output(cls, spec, dtype):
311-
return cls.create(spec, dtype, BufferType.INPUT_OUTPUT)
321+
def input_output(cls, spec, dtype, needs_swap_ref=True):
322+
return cls.create(spec, dtype, BufferType.INPUT_OUTPUT, needs_swap_ref)
312323

313324
@property
314325
def block_shape(self):
@@ -329,7 +340,7 @@ def current_ref(self):
329340
if self.memory_space == VMEM:
330341
return self.window_ref.at[buffer_slice]
331342
else:
332-
return self.window_ref.at[(self.current_slot[0], *buffer_slice)]
343+
return self.window_ref.at[(self.current_slot_index, *buffer_slice)]
333344

334345
@property
335346
def is_input(self):
@@ -355,6 +366,14 @@ def is_accumulator(self):
355366
def is_input_output(self):
356367
return self.buffer_type == BufferType.INPUT_OUTPUT
357368

369+
@property
370+
def current_slot_index(self):
371+
return self.current_slot[0]
372+
373+
@property
374+
def next_slot_index(self):
375+
return lax.rem(self.current_slot_index + 1, 2)
376+
358377
def bind_existing_ref(self, window_ref, indices):
359378
"""For handling VMEM references, the pipeline aliases the existing ref."""
360379
if self.memory_space == VMEM:
@@ -373,12 +392,15 @@ def init_slots(self):
373392
"""Initialize slot indices."""
374393
if self.memory_space == VMEM: return
375394
self.current_slot[0] = 0
376-
self.next_slot[0] = 0
395+
if self.swap is not None:
396+
self.swap[0] = False
377397

378398
def swap_slots(self):
379399
"""Switch to the next slot."""
380400
if self.memory_space == VMEM: return
381-
self.current_slot[0] = self.next_slot[0]
401+
self.current_slot[0] = self.next_slot_index
402+
if self.swap is not None:
403+
self.swap[0] = False
382404

383405
def get_dma_slice(self, src_shape, src_dtype, grid_indices):
384406
# We need to handle blocks that might go OOB in the src array. An in bounds
@@ -441,8 +463,9 @@ def copy_in(self, src_ref, grid_indices):
441463
"""Starts copy of HBM dma slice into the current slot."""
442464
assert self.is_input
443465
if self.memory_space == VMEM: return
444-
next_slot = lax.rem(self.current_slot[0] + 1, 2)
445-
self.next_slot[0] = next_slot
466+
if self.swap is not None:
467+
self.swap[0] = True
468+
next_slot = self.next_slot_index
446469
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
447470
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
448471
tpu_primitives.make_async_copy(
@@ -455,8 +478,9 @@ def copy_out(self, dst_ref, grid_indices):
455478
"""Starts copy of HBM dma slice from the current slot."""
456479
assert self.is_output
457480
if self.memory_space == VMEM: return
458-
slot = self.current_slot[0]
459-
self.next_slot[0] = lax.rem(slot + 1, 2)
481+
if self.swap is not None:
482+
self.swap[0] = True
483+
slot = self.current_slot_index
460484
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
461485
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
462486
tpu_primitives.make_async_copy(
@@ -471,7 +495,7 @@ def wait_in(self, src_ref, grid_indices):
471495
if self.memory_space == VMEM: return
472496
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
473497
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
474-
current_slot = self.current_slot[0]
498+
current_slot = self.current_slot_index
475499
tpu_primitives.make_async_copy(
476500
src_ref.at[src_slice], # nb: doesn't matter
477501
self.window_ref.at[current_slot].at[
@@ -484,7 +508,8 @@ def wait_out(self, dst_ref, grid_indices):
484508
"""Waits for output copy to finish."""
485509
assert self.is_output
486510
if self.memory_space == VMEM: return
487-
prev_slot = lax.rem(self.current_slot[0] + 1, 2)
511+
# With two slots (double buffering), the previous slot is the next slot.
512+
prev_slot = self.next_slot_index
488513
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
489514
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
490515
tpu_primitives.make_async_copy(
@@ -671,10 +696,7 @@ def _init_slots():
671696
def _start():
672697
if buffered_ref.is_input:
673698
buffered_ref.copy_in(src_ref, self.indices)
674-
675-
# In the prologue this makes it so we wait on the prologue copy to finish.
676-
# In other iterations this is the regular swap.
677-
buffered_ref.swap_slots()
699+
buffered_ref.swap_slots()
678700

679701
def wait_in(self, buffered_ref, src_ref, schedule=None):
680702
if schedule is None:
@@ -780,9 +802,32 @@ def finalize(self, buffered_ref, dst_ref, schedule=None):
780802
@self._named_scope("ep_finalize")
781803
def _end():
782804
if buffered_ref.is_output:
783-
buffered_ref.swap_slots() # formally correct, not actually necessary.
784805
buffered_ref.wait_out(dst_ref, self.indices)
785806

807+
def swap_slots(self, buffered_ref, hbm_ref, schedule=None):
808+
if buffered_ref.swap is not None:
809+
swap = buffered_ref.swap[0]
810+
else:
811+
# If we are not using an SMEM `swap` tensor to keep track of
812+
# swaps needed, then all the copies into and out of BufferedRefs
813+
# are done by direct calls to the `copy_in` and `copy_out`
814+
# methods in the pipeline loop. To determine if the BufferedRef
815+
# needs a swap of slots, we recalculate the copy-in/copy-out
816+
# conditions.
817+
if schedule is None:
818+
schedule = _default_schedule
819+
pred_in = schedule["copy_in"](self, buffered_ref, hbm_ref)
820+
pred_out = schedule["copy_out"](self, buffered_ref, hbm_ref)
821+
822+
copied_in = pred_in & buffered_ref.is_input & ~self.last_step
823+
copied_out = pred_out & buffered_ref.is_output
824+
swap = copied_in | copied_out
825+
826+
@pl.when(swap)
827+
@self._named_scope("ep_swap")
828+
def _swap():
829+
buffered_ref.swap_slots()
830+
786831
# END SCHEDULE --------------------------------------------------------------
787832

788833

@@ -875,6 +920,7 @@ def make_pipeline_allocations(
875920
in_specs=None,
876921
out_specs=None,
877922
should_accumulate_out=False,
923+
needs_swap_ref=True,
878924
):
879925
"""Create BufferedRefs for the pipeline.
880926
@@ -887,6 +933,7 @@ def make_pipeline_allocations(
887933
out_specs: output pallas block specs
888934
should_accumulate_out: booleans to indicate which outputs should be treated
889935
as accumulators.
936+
needs_swap_ref: whether to allocate the `swap` ref that tracks slot swaps.
890937
891938
Returns:
892939
A list of BufferedRefs, one corresponding to each ref specified in the
@@ -905,12 +952,12 @@ def make_pipeline_allocations(
905952
in_refs = refs[:num_in_specs]
906953
out_refs = refs[num_in_specs:]
907954
def make_input_bref(in_spec, in_ref):
908-
return BufferedRef.input(in_spec, in_ref.dtype)
955+
return BufferedRef.input(in_spec, in_ref.dtype, needs_swap_ref)
909956
in_brefs = jax.tree.map(make_input_bref, in_specs, in_refs)
910957
def make_output_bref(out_spec, out_ref, accumulate):
911958
if accumulate:
912-
return BufferedRef.accumulator(out_spec, out_ref.dtype)
913-
return BufferedRef.output(out_spec, out_ref.dtype)
959+
return BufferedRef.accumulator(out_spec, out_ref.dtype, needs_swap_ref)
960+
return BufferedRef.output(out_spec, out_ref.dtype, needs_swap_ref)
914961
out_brefs = jax.tree.map(
915962
make_output_bref, out_specs, out_refs, should_accumulate_out)
916963
return (*in_brefs, *out_brefs)
@@ -1109,6 +1156,14 @@ def pipeline(
11091156
scratches = ()
11101157
if allocations is None:
11111158
# run with inline scoped allocations
1159+
1160+
# Prefetch and postyeet are arbitrary functions that can copy
1161+
# into or out of any of the BufferedRefs. Thus, we need a ref
1162+
# for the scheduler to mark when the prefetch or postyeet
1163+
# functions perform a copy and the slots need to be
1164+
# swapped. Without prefetch and postyeet, the swapping logic can
1165+
# be performed without the need for state.
1166+
needs_swap_ref = prefetch is not None or postyeet is not None
11121167
return primitives.run_scoped(
11131168
lambda allocations: pipeline(
11141169
*refs,
@@ -1125,7 +1180,9 @@ def pipeline(
11251180
*refs,
11261181
in_specs=in_specs,
11271182
out_specs=out_specs,
1128-
should_accumulate_out=should_accumulate_out),
1183+
should_accumulate_out=should_accumulate_out,
1184+
needs_swap_ref=needs_swap_ref,
1185+
),
11291186
)
11301187
if isinstance(allocations, list):
11311188
allocations = tuple(allocations)
@@ -1184,6 +1241,8 @@ def loop_body(step, indices):
11841241
lax.cond(step == 0,
11851242
lambda: postyeet(*brefs, scheduler),
11861243
lambda: None)
1244+
1245+
map_brefs(scheduler.swap_slots, brefs, refs, schedule)
11871246
map_brefs(scheduler.finalize, brefs, refs, schedule)
11881247

11891248
return _next_index(indices, grid)

0 commit comments

Comments
 (0)