
Commit 2b2a872

[Blackwell] Fallback to MMAv2 for numWarps other than 4 or 8 (#5978)
Currently we allow MMAv5 for any `numWarps` that is a multiple of 4, but in practice only 4 or 8 are supported, according to https://github.com/triton-lang/triton/blob/4f302822d25047a8853b12ced682aeb8b20c90f9/lib/Dialect/TritonNvidiaGPU/IR/Dialect.cpp#L102
1 parent 14d7bcc commit 2b2a872

2 files changed: 21 additions, 1 deletion

lib/Analysis/Utility.cpp

Lines changed: 5 additions & 1 deletion
@@ -638,7 +638,11 @@ bool supportMMA(triton::DotOp op, int version) {
       return false;
     if (op.getType().getRank() != 2)
       return false;
-    if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
+    if (numWarps != 4 && numWarps != 8) {
+      // Currently only support numWarps 4 or 8 for TMEM load and store.
+      return false;
+    }
+    if (!(retShapePerCTA[rank - 2] % 64 == 0 &&
           retShapePerCTA[rank - 1] % 8 == 0))
       return false;
     return true;
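
Taken on its own, the new guard is just a predicate on the warp count: MMAv5 results live in TMEM, and the current TMEM load/store lowering only handles 4 or 8 warps, so every other warp count falls back to the MMAv2 path. The sketch below is illustrative only (the function name and driver are ours, not Triton's); it mirrors that decision and prints which path a few warp counts would take, including the 16-warp configuration exercised by the new test.

// Minimal, illustrative C++ sketch of the gating decision; names are
// stand-ins, not the actual Triton API.
#include <cstdio>

static bool warpCountSupportsMMAv5(int numWarps) {
  // Mirrors the new check in supportMMA: only 4 or 8 warps keep the
  // MMAv5 (TMEM-based) path.
  return numWarps == 4 || numWarps == 8;
}

int main() {
  const int warpCounts[] = {4, 8, 12, 16};
  for (int numWarps : warpCounts) {
    std::printf("numWarps=%d -> %s\n", numWarps,
                warpCountSupportsMMAv5(numWarps) ? "MMAv5"
                                                 : "fallback to MMAv2");
  }
  return 0;
}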

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 16 additions & 0 deletions
@@ -217,6 +217,22 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
 
 // -----
 
+// CHECK: #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 8], instrShape = [16, 8]}>
+#blocked = #ttg.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [16, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [16, 1], order = [1, 0]}>
+#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [4, 4], order = [1, 0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 16 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-label: mmav5_fallback_v2_num_warps
+  tt.func public @mmav5_fallback_v2_num_warps(%a: tensor<128x64xf16, #blocked2>, %b: tensor<64x256xf16, #blocked1>, %c: tensor<128x256xf32, #blocked>) -> tensor<128x256xf32, #blocked> {
+    %ad = ttg.convert_layout %a : tensor<128x64xf16, #blocked2> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>>
+    %bd = ttg.convert_layout %b : tensor<64x256xf16, #blocked1> -> tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>>
+    %d = tt.dot %ad, %bd, %c, inputPrecision = tf32 : tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<64x256xf16, #ttg.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<128x256xf32, #blocked>
+    tt.return %d : tensor<128x256xf32, #blocked>
+  }
+}
+
+// -----
+
 #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
 #blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
