@@ -213,8 +213,8 @@ class BufferedRef:
213213 spec : pl .BlockSpec # static metadata
214214 dtype : Any # static metadata
215215 buffer_type : BufferType # static metadata
216- window_ref : REF | None
217- accum_ref : REF | None
216+ window_ref : ArrayRef | None
217+ accum_ref : ArrayRef | None
218218 current_slot : ArrayRef | None
219219 # TODO(ramiroleal): Unused by class. Remove argument from
220220 # BufferedRef instantiations.
@@ -337,6 +337,7 @@ def memory_space(self):
337337 def current_ref (self ):
338338 buffer_slice = tuple (
339339 0 if x is None else slice (None ) for x in self .block_shape )
340+ assert not (self .window_ref is None or isinstance (self .window_ref , REF ))
340341 if self .memory_space == VMEM :
341342 return self .window_ref .at [buffer_slice ]
342343 else :
@@ -368,10 +369,12 @@ def is_input_output(self):
368369
369370 @property
370371 def current_slot_index (self ):
372+ """Index in double buffer corresponding to the current slot."""
371373 return self .current_slot [0 ]
372374
373375 @property
374376 def next_slot_index (self ):
377+ """Index in double buffer corresponding to the next slot."""
375378 return lax .rem (self .current_slot_index + 1 , 2 )
376379
377380 def bind_existing_ref (self , window_ref , indices ):
@@ -463,28 +466,32 @@ def copy_in(self, src_ref, grid_indices):
463466 """Starts copy of HBM dma slice into the next slot."""
464467 assert self .is_input
465468 if self .memory_space == VMEM : return
469+ assert not (self .window_ref is None or isinstance (self .window_ref , REF ))
470+ assert self .sem_recvs is not None
466471 if self .swap is not None :
467472 self .swap [0 ] = True
468473 next_slot = self .next_slot_index
469474 src_slice = self .get_dma_slice (src_ref .shape , src_ref .dtype , grid_indices )
470475 dst_slice = tuple (pl .ds (0 , s .size ) for s in src_slice )
471476 tpu_primitives .make_async_copy (
472477 src_ref .at [src_slice ],
473- self .window_ref .at [next_slot ]. at [ dst_slice ],
478+ self .window_ref .at [( next_slot , * dst_slice ) ],
474479 self .sem_recvs .at [next_slot ],
475480 ).start ()
476481
477482 def copy_out (self , dst_ref , grid_indices ):
478483 """Starts copy of HBM dma slice from the current slot."""
479484 assert self .is_output
480485 if self .memory_space == VMEM : return
486+ assert not (self .window_ref is None or isinstance (self .window_ref , REF ))
487+ assert self .sem_sends is not None
481488 if self .swap is not None :
482489 self .swap [0 ] = True
483490 slot = self .current_slot_index
484491 dst_slice = self .get_dma_slice (dst_ref .shape , dst_ref .dtype , grid_indices )
485492 src_slice = tuple (pl .ds (0 , s .size ) for s in dst_slice )
486493 tpu_primitives .make_async_copy (
487- self .window_ref .at [slot ]. at [ src_slice ],
494+ self .window_ref .at [( slot , * src_slice ) ],
488495 dst_ref .at [dst_slice ],
489496 self .sem_sends .at [slot ],
490497 ).start ()
@@ -493,13 +500,15 @@ def wait_in(self, src_ref, grid_indices):
493500 """Waits for input copy to finish."""
494501 assert self .is_input
495502 if self .memory_space == VMEM : return
503+ assert not (self .window_ref is None or isinstance (self .window_ref , REF ))
504+ assert self .sem_recvs is not None
496505 src_slice = self .get_dma_slice (src_ref .shape , src_ref .dtype , grid_indices )
497506 dst_slice = tuple (pl .ds (0 , s .size ) for s in src_slice )
498507 current_slot = self .current_slot_index
499508 tpu_primitives .make_async_copy (
500509 src_ref .at [src_slice ], # nb: doesn't matter
501- self .window_ref .at [current_slot ]. at [
502- dst_slice
510+ self .window_ref .at [
511+ ( current_slot , * dst_slice )
503512 ], # only dst shape is important
504513 self .sem_recvs .at [current_slot ],
505514 ).wait ()
@@ -508,12 +517,14 @@ def wait_out(self, dst_ref, grid_indices):
508517 """Waits for output copy to finish."""
509518 assert self .is_output
510519 if self .memory_space == VMEM : return
520+ assert not (self .window_ref is None or isinstance (self .window_ref , REF ))
521+ assert self .sem_sends is not None
511522 # In a double buffer, previous slot is the same as next slot.
512523 prev_slot = self .next_slot_index
513524 dst_slice = self .get_dma_slice (dst_ref .shape , dst_ref .dtype , grid_indices )
514525 src_slice = tuple (pl .ds (0 , s .size ) for s in dst_slice )
515526 tpu_primitives .make_async_copy (
516- self .window_ref .at [prev_slot ]. at [ src_slice ], # nb: doesn't matter
527+ self .window_ref .at [( prev_slot , * src_slice ) ], # nb: doesn't matter
517528 dst_ref .at [dst_slice ], # only dst shape is important
518529 self .sem_sends .at [prev_slot ],
519530 ).wait ()
@@ -533,16 +544,18 @@ def set_accumulator(self, init=False):
533544 """Set accumulator or zero it out to initialize."""
534545 assert self .is_accumulator
535546 if self .accum_ref is not None :
547+ accum_dtype = self .accum_ref .dtype
536548 def _init ():
537549 self .accum_ref [...] = jnp .zeros_like (self .accum_ref [...])
538550 def _set ():
539- self .accum_ref [...] = self .current_ref [...].astype (self . accum_ref . dtype )
551+ self .accum_ref [...] = self .current_ref [...].astype (accum_dtype )
540552 lax .cond (init , _init , _set )
541553
542554 def accumulate (self ):
543555 """Add into the current slot."""
544556 assert self .is_accumulator
545557 if self .accum_ref is not None :
558+ assert self .window_ref is not None
546559 accum_dtype = jnp .float32
547560 if self .window_ref .dtype == jnp .int32 :
548561 accum_dtype = jnp .int32
0 commit comments