Skip to content

Commit b3a2c53

Browse files
[NFC] Fix linter errors in pipeline file
PiperOrigin-RevId: 741644574
1 parent 47876bb commit b3a2c53

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

jax/_src/pallas/mosaic/pipeline.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,8 @@ class BufferedRef:
213213
spec: pl.BlockSpec # static metadata
214214
dtype: Any # static metadata
215215
buffer_type: BufferType # static metadata
216-
window_ref: REF | None
217-
accum_ref: REF | None
216+
window_ref: ArrayRef | None
217+
accum_ref: ArrayRef | None
218218
current_slot: ArrayRef | None
219219
# TODO(ramiroleal): Unused by class. Remove argument from
220220
# BufferedRef instantiations.
@@ -337,6 +337,7 @@ def memory_space(self):
337337
def current_ref(self):
338338
buffer_slice = tuple(
339339
0 if x is None else slice(None) for x in self.block_shape)
340+
assert not (self.window_ref is None or isinstance(self.window_ref, REF))
340341
if self.memory_space == VMEM:
341342
return self.window_ref.at[buffer_slice]
342343
else:
@@ -368,10 +369,12 @@ def is_input_output(self):
368369

369370
@property
370371
def current_slot_index(self):
372+
"""Index in double buffer corresponding to the current slot."""
371373
return self.current_slot[0]
372374

373375
@property
374376
def next_slot_index(self):
377+
"""Index in double buffer corresponding to the next slot."""
375378
return lax.rem(self.current_slot_index + 1, 2)
376379

377380
def bind_existing_ref(self, window_ref, indices):
@@ -463,28 +466,32 @@ def copy_in(self, src_ref, grid_indices):
463466
"""Starts copy of HBM dma slice into the current slot."""
464467
assert self.is_input
465468
if self.memory_space == VMEM: return
469+
assert not (self.window_ref is None or isinstance(self.window_ref, REF))
470+
assert self.sem_recvs is not None
466471
if self.swap is not None:
467472
self.swap[0] = True
468473
next_slot = self.next_slot_index
469474
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
470475
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
471476
tpu_primitives.make_async_copy(
472477
src_ref.at[src_slice],
473-
self.window_ref.at[next_slot].at[dst_slice],
478+
self.window_ref.at[(next_slot, *dst_slice)],
474479
self.sem_recvs.at[next_slot],
475480
).start()
476481

477482
def copy_out(self, dst_ref, grid_indices):
478483
"""Starts copy of HBM dma slice from the current slot."""
479484
assert self.is_output
480485
if self.memory_space == VMEM: return
486+
assert not (self.window_ref is None or isinstance(self.window_ref, REF))
487+
assert self.sem_sends is not None
481488
if self.swap is not None:
482489
self.swap[0] = True
483490
slot = self.current_slot_index
484491
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
485492
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
486493
tpu_primitives.make_async_copy(
487-
self.window_ref.at[slot].at[src_slice],
494+
self.window_ref.at[(slot, *src_slice)],
488495
dst_ref.at[dst_slice],
489496
self.sem_sends.at[slot],
490497
).start()
@@ -493,13 +500,15 @@ def wait_in(self, src_ref, grid_indices):
493500
"""Waits for input copy to finish."""
494501
assert self.is_input
495502
if self.memory_space == VMEM: return
503+
assert not (self.window_ref is None or isinstance(self.window_ref, REF))
504+
assert self.sem_recvs is not None
496505
src_slice = self.get_dma_slice(src_ref.shape, src_ref.dtype, grid_indices)
497506
dst_slice = tuple(pl.ds(0, s.size) for s in src_slice)
498507
current_slot = self.current_slot_index
499508
tpu_primitives.make_async_copy(
500509
src_ref.at[src_slice], # nb: doesn't matter
501-
self.window_ref.at[current_slot].at[
502-
dst_slice
510+
self.window_ref.at[
511+
(current_slot, *dst_slice)
503512
], # only dst shape is important
504513
self.sem_recvs.at[current_slot],
505514
).wait()
@@ -508,12 +517,14 @@ def wait_out(self, dst_ref, grid_indices):
508517
"""Waits for output copy to finish."""
509518
assert self.is_output
510519
if self.memory_space == VMEM: return
520+
assert not (self.window_ref is None or isinstance(self.window_ref, REF))
521+
assert self.sem_sends is not None
511522
# In a double buffer, previous slot is the same as next slot.
512523
prev_slot = self.next_slot_index
513524
dst_slice = self.get_dma_slice(dst_ref.shape, dst_ref.dtype, grid_indices)
514525
src_slice = tuple(pl.ds(0, s.size) for s in dst_slice)
515526
tpu_primitives.make_async_copy(
516-
self.window_ref.at[prev_slot].at[src_slice], # nb: doesn't matter
527+
self.window_ref.at[(prev_slot, *src_slice)], # nb: doesn't matter
517528
dst_ref.at[dst_slice], # only dst shape is important
518529
self.sem_sends.at[prev_slot],
519530
).wait()
@@ -533,16 +544,18 @@ def set_accumulator(self, init=False):
533544
"""Set accumulator or zero it out to initialize."""
534545
assert self.is_accumulator
535546
if self.accum_ref is not None:
547+
accum_dtype = self.accum_ref.dtype
536548
def _init():
537549
self.accum_ref[...] = jnp.zeros_like(self.accum_ref[...])
538550
def _set():
539-
self.accum_ref[...] = self.current_ref[...].astype(self.accum_ref.dtype)
551+
self.accum_ref[...] = self.current_ref[...].astype(accum_dtype)
540552
lax.cond(init, _init, _set)
541553

542554
def accumulate(self):
543555
"""Add into the current slot."""
544556
assert self.is_accumulator
545557
if self.accum_ref is not None:
558+
assert self.window_ref is not None
546559
accum_dtype = jnp.float32
547560
if self.window_ref.dtype == jnp.int32:
548561
accum_dtype = jnp.int32

0 commit comments

Comments
 (0)