Commit f8a19d1
[TritonNVIDIAGPU] Revert MMAv5 write effect on barrier (triton-lang#6484)
This is a tiny partial revert of triton-lang#6476, removing the added `MemWrite<SharedMemory>` effect on the barrier operand of the MMAv5 ops. The effect wasn't present before, and it caused extra barriers to be inserted between MMA ops and mbarrier waits, producing large performance regressions across the board. The trade-off is that, as things stand, the compiler can't see interfering memory effects between the MMAv5 ops and barrier waits, so we will probably have to model the side effects here in a different way. Technically, the MMAv5 ops don't write or read the barrier; they push a request onto the MMAv5 pipeline.
Parent: face3d2

File tree: 2 files changed (+35, −4 lines)


lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp
Lines changed: 0 additions & 4 deletions

@@ -208,10 +208,6 @@ static void getMMAEffects(
     effects.emplace_back(MemoryEffects::Read::get(), &op.getAMutable(),
                          TensorMemory::get());
   }
-  if (op.getBarrier()) {
-    effects.emplace_back(MemoryEffects::Write::get(),
-                         op.getBarrierMutable().begin(), SharedMemory::get());
-  }
 
   effects.emplace_back(MemoryEffects::Read::get(), &op.getBMutable(),
                        SharedMemory::get());
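
For context on why that one effect mattered: membar-style analyses decide where to insert synchronization by querying each op's declared memory effects, so a `MemWrite<SharedMemory>` on the barrier operand made the MMA op look like a shared-memory writer that conflicts with the following wait on the same buffer. Below is a simplified sketch of that kind of query; it is not Triton's actual Membar implementation, and the helper name, the local stand-in resource, and the conservative fallback are assumptions for illustration only.

// Simplified sketch of an effect-based "does this op write shared memory?"
// check, in the style a membar analysis might use. Not Triton's Membar code.
#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

using namespace mlir;

// Stand-in for Triton's SharedMemory side-effect resource (the one referenced
// in the diff above); declared locally here to keep the sketch self-contained.
struct SharedMemory : SideEffects::Resource::Base<SharedMemory> {
  StringRef getName() final { return "<SharedMemory>"; }
};

// With the now-reverted MemWrite<SharedMemory> on the barrier operand,
// ttng.tc_gen5_mma would answer "true" here, so a barrier was conservatively
// placed between the MMA and the matching ttng.wait_barrier.
static bool writesSharedMemory(Operation *op) {
  auto iface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!iface)
    return true; // no effect info: assume the worst (assumption in this sketch)
  SmallVector<MemoryEffects::EffectInstance> effects;
  iface.getEffects(effects);
  for (const MemoryEffects::EffectInstance &effect : effects)
    if (isa<MemoryEffects::Write>(effect.getEffect()) &&
        effect.getResource() == SharedMemory::get())
      return true;
  return false;
}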

test/TritonNvidiaGPU/membar.mlir
Lines changed: 35 additions & 0 deletions

@@ -120,3 +120,38 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
   tt.return
 }
 }
+
+// -----
+
+#mma = #ttg.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0], instrShape = [16, 256, 32]}>
+#shared = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}>
+#shared1 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = true, elementBitWidth = 16}>
+#shared2 = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
+#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
+#smem = #ttg.shared_memory
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
+
+// CHECK-LABEL: @wait_after_mma
+tt.func @wait_after_mma(
+    %a: !ttg.memdesc<128x128xf16, #shared, #smem>,
+    %b: !ttg.memdesc<128x128xf16, #shared1, #smem>,
+    %c: !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
+    %useAcc: i1,
+    %pred: i1
+) {
+  %phase = arith.constant 0 : i32
+  %barrier = ttg.local_alloc : () -> !ttg.memdesc<1xi64, #shared2, #smem, mutable>
+  // CHECK: ttng.tc_gen5_mma
+  ttng.tc_gen5_mma %a, %b, %c, %useAcc, %pred, %barrier :
+      (!ttg.memdesc<128x128xf16, #shared, #smem>,
+       !ttg.memdesc<128x128xf16, #shared1, #smem>,
+       !ttg.memdesc<128x128xf32, #tmem, #ttng.tensor_memory, mutable>,
+       i1, i1,
+       !ttg.memdesc<1xi64, #shared2, #smem, mutable>) -> ()
+  // CHECK-NEXT: ttng.wait_barrier
+  ttng.wait_barrier %barrier, %phase : !ttg.memdesc<1xi64, #shared2, #smem, mutable>
+  tt.return
+}
+
+}
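
The closing point of the commit message, that MMAv5 ops don't really read or write the barrier but push a request onto the MMAv5 pipeline, hints at modeling the interaction through something other than a SharedMemory write. The sketch below is one purely hypothetical way such a model could look, using a dedicated MLIR side-effect resource; the name MMAv5Pipeline and the wiring shown are assumptions for illustration, not part of this commit or of Triton.

// Purely hypothetical sketch: record the MMA's barrier "commit" against a
// dedicated side-effect resource so it no longer aliases ordinary
// shared-memory traffic. MMAv5Pipeline is an assumption, not Triton code.
#include "mlir/Interfaces/SideEffectInterfaces.h"

using namespace mlir;

struct MMAv5Pipeline : public SideEffects::Resource::Base<MMAv5Pipeline> {
  StringRef getName() final { return "<MMAv5Pipeline>"; }
};

// Inside a getMMAEffects-style helper, the reverted block could then read:
//
//   if (op.getBarrier()) {
//     effects.emplace_back(MemoryEffects::Write::get(),
//                          op.getBarrierMutable().begin(),
//                          MMAv5Pipeline::get());
//   }
//
// An analysis that only tracks the SharedMemory resource would then stop
// inserting barriers before the matching ttng.wait_barrier, while passes
// that care about MMA completion could still query this resource explicitly.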
