Skip to content

Commit c965ffb

Browse files
apaszke authored and Google-ML-Automation committed
[Mosaic GPU] Remove expect_wait from Barrier.wait
It looks like LLVM already moves the wait loops to the end of the program, so the whole optimization is no longer necessary and only adds unnecessary operations.

PiperOrigin-RevId: 703052393
1 parent 7214a3a commit c965ffb

File tree

1 file changed

+5
-21
lines changed

1 file changed

+5
-21
lines changed

jax/experimental/mosaic/gpu/utils.py

Lines changed: 5 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -673,33 +673,17 @@ def __getitem__(self, offset: ir.Value | int) -> "BarrierRef":
673673
1,
674674
)
675675

676-
def wait_parity(self, parity, expect_wait=False):
677-
i1 = ir.IntegerType.get_signless(1)
676+
def wait_parity(self, parity):
678677
i32 = ir.IntegerType.get_signless(32)
679-
ticks = c(10000000, i32)
680-
address = self.get_ptr()
678+
ticks = arith.constant(i32, 10000000)
681679
parity = arith.extui(i32, parity)
682-
if expect_wait:
683-
nvvm.mbarrier_try_wait_parity_shared(address, parity, ticks)
684-
return
685-
barrier_ready = llvm.inline_asm(
686-
i1,
687-
[address, parity],
688-
"mbarrier.test_wait.parity.shared.b64 $0, [$1], $2;",
689-
"=b,l,r",
690-
has_side_effects=True,
691-
)
692-
should_wait = arith.xori(barrier_ready, c(1, i1))
693-
should_wait = llvm.intr_expect(should_wait, c(0, i1))
694-
with ir.InsertionPoint(scf.IfOp(should_wait).then_block):
695-
nvvm.mbarrier_try_wait_parity_shared(address, parity, ticks)
696-
scf.yield_([])
680+
nvvm.mbarrier_try_wait_parity_shared(self.get_ptr(), parity, ticks)
697681

698-
def wait(self, expect_wait=False):
682+
def wait(self):
699683
parities = memref.load(self.phases, [])
700684
parity, new_parities = self.update_parities(parities)
701685
memref.store(new_parities, self.phases, [])
702-
self.wait_parity(parity, expect_wait=expect_wait)
686+
self.wait_parity(parity)
703687

704688
def update_parities(self, parities: ir.Value) -> tuple[ir.Value, ir.Value]:
705689
i32 = ir.IntegerType.get_signless(32)

0 commit comments

Comments
 (0)