@@ -207,16 +207,23 @@ class BufferedRef:
207207 is_accumulator: whether this BufferedRef is an accumulator.
208208 is_input_output: whether this BufferedRef is an input/output without
209209 automatic accumulation.
210+ swap: Tracks whether the BufferedRef slots need to be swapped before next
211+ copy.
210212 """
211213 spec : pl .BlockSpec # static metadata
212214 dtype : Any # static metadata
213215 buffer_type : BufferType # static metadata
214216 window_ref : REF | None
215217 accum_ref : REF | None
216218 current_slot : ArrayRef | None
219+ # TODO(ramiroleal): Unused by class. Remove argument from
220+ # BufferedRef instantiations.
217221 next_slot : ArrayRef | None
218222 sem_recvs : SemaphoreTuple | None
219223 sem_sends : SemaphoreTuple | None
224+ # TODO(ramiroleal): Improve prefetch/postyeet interface to avoid
225+ # using this ref.
226+ swap : ArrayRef | None
220227
221228 def tree_flatten (self ):
222229 return (
@@ -227,6 +234,7 @@ def tree_flatten(self):
227234 self .next_slot ,
228235 self .sem_recvs ,
229236 self .sem_sends ,
237+ self .swap ,
230238 ),
231239 (self .spec , self .dtype , self .buffer_type ),
232240 )
@@ -240,14 +248,15 @@ def buffer_types() -> type[BufferType]:
240248 return BufferType
241249
242250 @classmethod
243- def create (cls , spec , dtype , buffer_type ) -> BufferedRef :
251+ def create (cls , spec , dtype , buffer_type , needs_swap_ref = True ) -> BufferedRef :
244252 """Create a BufferedRef.
245253
246254 Args:
247255 spec: pallas blockspec.
248256 dtype: dtype for buffers.
249257 buffer_type: enum indicating whether this is an input, output, or in/out
250258 accumulator buffered reference.
259+ needs_swap_ref: whether a swap slots tracker needs to be allocated.
251260
252261 Returns:
253262 Initialized BufferedRef
@@ -271,6 +280,7 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef:
271280 next_slot = None ,
272281 sem_recvs = None ,
273282 sem_sends = None ,
283+ swap = None ,
274284 )
275285 else :
276286 memory_space = SMEM if spec .memory_space == SMEM else VMEM
@@ -281,7 +291,7 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef:
281291 window_ref = memory_space ((2 ,) + block_shape , dtype ),
282292 accum_ref = accum_ref ,
283293 current_slot = SMEM ((1 ,), jnp .int32 ),
284- next_slot = SMEM (( 1 ,), jnp . int32 ) ,
294+ next_slot = None ,
285295 sem_recvs = (
286296 None
287297 if buffer_type is BufferType .OUTPUT
@@ -292,23 +302,24 @@ def create(cls, spec, dtype, buffer_type) -> BufferedRef:
292302 if buffer_type is BufferType .INPUT
293303 else SemaphoreType .DMA ((2 ,))
294304 ),
305+ swap = SMEM ((1 ,), jnp .bool ) if needs_swap_ref else None ,
295306 )
296307
297308 @classmethod
298- def input (cls , spec , dtype ):
299- return cls .create (spec , dtype , BufferType .INPUT )
309+ def input (cls , spec , dtype , needs_swap_ref = True ):
310+ return cls .create (spec , dtype , BufferType .INPUT , needs_swap_ref )
300311
301312 @classmethod
302- def output (cls , spec , dtype ):
303- return cls .create (spec , dtype , BufferType .OUTPUT )
313+ def output (cls , spec , dtype , needs_swap_ref = True ):
314+ return cls .create (spec , dtype , BufferType .OUTPUT , needs_swap_ref )
304315
305316 @classmethod
306- def accumulator (cls , spec , dtype ):
307- return cls .create (spec , dtype , BufferType .ACCUMULATOR )
317+ def accumulator (cls , spec , dtype , needs_swap_ref = True ):
318+ return cls .create (spec , dtype , BufferType .ACCUMULATOR , needs_swap_ref )
308319
309320 @classmethod
310- def input_output (cls , spec , dtype ):
311- return cls .create (spec , dtype , BufferType .INPUT_OUTPUT )
321+ def input_output (cls , spec , dtype , needs_swap_ref = True ):
322+ return cls .create (spec , dtype , BufferType .INPUT_OUTPUT , needs_swap_ref )
312323
313324 @property
314325 def block_shape (self ):
@@ -329,7 +340,7 @@ def current_ref(self):
329340 if self .memory_space == VMEM :
330341 return self .window_ref .at [buffer_slice ]
331342 else :
332- return self .window_ref .at [(self .current_slot [ 0 ] , * buffer_slice )]
343+ return self .window_ref .at [(self .current_slot_index , * buffer_slice )]
333344
334345 @property
335346 def is_input (self ):
@@ -355,6 +366,14 @@ def is_accumulator(self):
355366 def is_input_output (self ):
356367 return self .buffer_type == BufferType .INPUT_OUTPUT
357368
369+ @property
370+ def current_slot_index (self ):
371+ return self .current_slot [0 ]
372+
373+ @property
374+ def next_slot_index (self ):
375+ return lax .rem (self .current_slot_index + 1 , 2 )
376+
358377 def bind_existing_ref (self , window_ref , indices ):
359378 """For handling VMEM references, the pipeline aliases the existing ref."""
360379 if self .memory_space == VMEM :
@@ -373,12 +392,15 @@ def init_slots(self):
373392 """Initialize slot indices."""
374393 if self .memory_space == VMEM : return
375394 self .current_slot [0 ] = 0
376- self .next_slot [0 ] = 0
395+ if self .swap is not None :
396+ self .swap [0 ] = False
377397
378398 def swap_slots (self ):
379399 """Switch to the next slot."""
380400 if self .memory_space == VMEM : return
381- self .current_slot [0 ] = self .next_slot [0 ]
401+ self .current_slot [0 ] = self .next_slot_index
402+ if self .swap is not None :
403+ self .swap [0 ] = False
382404
383405 def get_dma_slice (self , src_shape , src_dtype , grid_indices ):
384406 # We need to handle blocks that might go OOB in the src array. An in bounds
@@ -441,8 +463,9 @@ def copy_in(self, src_ref, grid_indices):
441463 """Starts copy of HBM dma slice into the current slot."""
442464 assert self .is_input
443465 if self .memory_space == VMEM : return
444- next_slot = lax .rem (self .current_slot [0 ] + 1 , 2 )
445- self .next_slot [0 ] = next_slot
466+ if self .swap is not None :
467+ self .swap [0 ] = True
468+ next_slot = self .next_slot_index
446469 src_slice = self .get_dma_slice (src_ref .shape , src_ref .dtype , grid_indices )
447470 dst_slice = tuple (pl .ds (0 , s .size ) for s in src_slice )
448471 tpu_primitives .make_async_copy (
@@ -455,8 +478,9 @@ def copy_out(self, dst_ref, grid_indices):
455478 """Starts copy of HBM dma slice from the current slot."""
456479 assert self .is_output
457480 if self .memory_space == VMEM : return
458- slot = self .current_slot [0 ]
459- self .next_slot [0 ] = lax .rem (slot + 1 , 2 )
481+ if self .swap is not None :
482+ self .swap [0 ] = True
483+ slot = self .current_slot_index
460484 dst_slice = self .get_dma_slice (dst_ref .shape , dst_ref .dtype , grid_indices )
461485 src_slice = tuple (pl .ds (0 , s .size ) for s in dst_slice )
462486 tpu_primitives .make_async_copy (
@@ -471,7 +495,7 @@ def wait_in(self, src_ref, grid_indices):
471495 if self .memory_space == VMEM : return
472496 src_slice = self .get_dma_slice (src_ref .shape , src_ref .dtype , grid_indices )
473497 dst_slice = tuple (pl .ds (0 , s .size ) for s in src_slice )
474- current_slot = self .current_slot [ 0 ]
498+ current_slot = self .current_slot_index
475499 tpu_primitives .make_async_copy (
476500 src_ref .at [src_slice ], # nb: doesn't matter
477501 self .window_ref .at [current_slot ].at [
@@ -484,7 +508,8 @@ def wait_out(self, dst_ref, grid_indices):
484508 """Waits for output copy to finish."""
485509 assert self .is_output
486510 if self .memory_space == VMEM : return
487- prev_slot = lax .rem (self .current_slot [0 ] + 1 , 2 )
511+ # In a double buffer, previous slot is the same as next slot.
512+ prev_slot = self .next_slot_index
488513 dst_slice = self .get_dma_slice (dst_ref .shape , dst_ref .dtype , grid_indices )
489514 src_slice = tuple (pl .ds (0 , s .size ) for s in dst_slice )
490515 tpu_primitives .make_async_copy (
@@ -671,10 +696,7 @@ def _init_slots():
671696 def _start ():
672697 if buffered_ref .is_input :
673698 buffered_ref .copy_in (src_ref , self .indices )
674-
675- # In the prologue this makes it so we wait on the prologue copy to finish.
676- # In other iterations this is the regular swap.
677- buffered_ref .swap_slots ()
699+ buffered_ref .swap_slots ()
678700
679701 def wait_in (self , buffered_ref , src_ref , schedule = None ):
680702 if schedule is None :
@@ -780,9 +802,32 @@ def finalize(self, buffered_ref, dst_ref, schedule=None):
780802 @self ._named_scope ("ep_finalize" )
781803 def _end ():
782804 if buffered_ref .is_output :
783- buffered_ref .swap_slots () # formally correct, not actually necessary.
784805 buffered_ref .wait_out (dst_ref , self .indices )
785806
807+ def swap_slots (self , buffered_ref , hbm_ref , schedule = None ):
808+ if buffered_ref .swap is not None :
809+ swap = buffered_ref .swap [0 ]
810+ else :
811+ # If we are not using an SMEM `swap` tensor to keep track of
812+ # swaps needed, then all the copies into and out of BufferedRefs
813+ # are done by direct calls to the `copy_in` and `copy_out`
814+ # methods in the pipeline loop. To determine if the BufferedRef
815+ # needs a swap of slots, we recalculate the copy-in/copy-out
816+ # conditions.
817+ if schedule is None :
818+ schedule = _default_schedule
819+ pred_in = schedule ["copy_in" ](self , buffered_ref , hbm_ref )
820+ pred_out = schedule ["copy_out" ](self , buffered_ref , hbm_ref )
821+
822+ copied_in = pred_in & buffered_ref .is_input & ~ self .last_step
823+ copied_out = pred_out & buffered_ref .is_output
824+ swap = copied_in | copied_out
825+
826+ @pl .when (swap )
827+ @self ._named_scope ("ep_swap" )
828+ def _swap ():
829+ buffered_ref .swap_slots ()
830+
786831 # END SCHEDULE --------------------------------------------------------------
787832
788833
@@ -875,6 +920,7 @@ def make_pipeline_allocations(
875920 in_specs = None ,
876921 out_specs = None ,
877922 should_accumulate_out = False ,
923+ needs_swap_ref = True ,
878924):
879925 """Create BufferedRefs for the pipeline.
880926
@@ -887,6 +933,7 @@ def make_pipeline_allocations(
887933 out_specs: output pallas block specs
888934 should_accumulate_out: booleans to indicate which outputs should be treated
889935 as accumulators.
936+ needs_swap_ref: whether a swap slots tracker needs to be allocated.
890937
891938 Returns:
892939 A list of BufferedRefs, one corresponding to each ref specified in the
@@ -905,12 +952,12 @@ def make_pipeline_allocations(
905952 in_refs = refs [:num_in_specs ]
906953 out_refs = refs [num_in_specs :]
907954 def make_input_bref (in_spec , in_ref ):
908- return BufferedRef .input (in_spec , in_ref .dtype )
955+ return BufferedRef .input (in_spec , in_ref .dtype , needs_swap_ref )
909956 in_brefs = jax .tree .map (make_input_bref , in_specs , in_refs )
910957 def make_output_bref (out_spec , out_ref , accumulate ):
911958 if accumulate :
912- return BufferedRef .accumulator (out_spec , out_ref .dtype )
913- return BufferedRef .output (out_spec , out_ref .dtype )
959+ return BufferedRef .accumulator (out_spec , out_ref .dtype , needs_swap_ref )
960+ return BufferedRef .output (out_spec , out_ref .dtype , needs_swap_ref )
914961 out_brefs = jax .tree .map (
915962 make_output_bref , out_specs , out_refs , should_accumulate_out )
916963 return (* in_brefs , * out_brefs )
@@ -1109,6 +1156,14 @@ def pipeline(
11091156 scratches = ()
11101157 if allocations is None :
11111158 # run with inline scoped allocations
1159+
1160+ # Prefetch and postyeet are arbitrary functions that can copy
1161+ # into or out of any of the BufferedRefs. Thus, we need a ref
1162+ # for the scheduler to mark when the prefetch or postyeet
1163+ # functions perform a copy and the slots need to be
1164+ # swapped. Without prefetch and postyeet, the swapping logic can
1165+ # be performed without the need for state.
1166+ needs_swap_ref = prefetch is not None or postyeet is not None
11121167 return primitives .run_scoped (
11131168 lambda allocations : pipeline (
11141169 * refs ,
@@ -1125,7 +1180,9 @@ def pipeline(
11251180 * refs ,
11261181 in_specs = in_specs ,
11271182 out_specs = out_specs ,
1128- should_accumulate_out = should_accumulate_out ),
1183+ should_accumulate_out = should_accumulate_out ,
1184+ needs_swap_ref = needs_swap_ref ,
1185+ ),
11291186 )
11301187 if isinstance (allocations , list ):
11311188 allocations = tuple (allocations )
@@ -1184,6 +1241,8 @@ def loop_body(step, indices):
11841241 lax .cond (step == 0 ,
11851242 lambda : postyeet (* brefs , scheduler ),
11861243 lambda : None )
1244+
1245+ map_brefs (scheduler .swap_slots , brefs , refs , schedule )
11871246 map_brefs (scheduler .finalize , brefs , refs , schedule )
11881247
11891248 return _next_index (indices , grid )
0 commit comments