Skip to content

Commit 0d2a7c8

Browse files
authored
[BACKEND] Don't use mmav5 with num warps < 4 (#7928)
1 parent 9d5ca6f commit 0d2a7c8

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,8 @@ class ScaledBlockedToMMAv5
645645
auto CTALayout = getCTALayout(oldRetType.getEncoding());
646646
if ((computeCapability) / 10 != 10)
647647
return failure();
648+
if (numWarps != 4 && numWarps != 8)
649+
return failure();
648650
if (retShapePerCTA[0] < 128 || retShapePerCTA[1] < 8)
649651
return failure();
650652
Location loc = dotOp.getLoc();

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,23 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
371371

372372
// -----
373373

374+
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 2], order = [1, 0]}>
375+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [2, 1], order = [1, 0]}>
376+
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [2, 1], order = [1, 0]}>
377+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, ttg.target = "cuda:100", "ttg.threads-per-warp" = 32 : i32} {
378+
// Make sure we fall back to mmav2 when num warps < 4
379+
// CHECK-LABEL: block_scaled_2_warps
380+
// CHECK: tt.dot
381+
// CHECK: tt.return
382+
tt.func public @block_scaled_2_warps(%a: tensor<128x64xf8E4M3FN, #blocked2>, %scale_a: tensor<128x2xi8, #blocked1>, %b: tensor<64x128xf8E4M3FN, #blocked>, %scale_b: tensor<128x2xi8, #blocked1>) -> tensor<128x128xf32, #blocked> {
383+
%cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked>
384+
%d = tt.dot_scaled %a scale %scale_a, %b scale %scale_b, %cst lhs = e4m3 rhs = e4m3 {fastMath = false} : tensor<128x64xf8E4M3FN, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xf8E4M3FN, #blocked>, tensor<128x2xi8, #blocked1> -> tensor<128x128xf32, #blocked>
385+
tt.return %d : tensor<128x128xf32, #blocked>
386+
}
387+
}
388+
389+
// -----
390+
374391
// Verify that dot_scaled (mxfp4 x {bf16,fp8}) decomposes to mmav3 if it's bf16, otherwise it falls back to mmav2
375392
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0]}>
376393
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 2], warpsPerCTA = [4, 1], order = [1, 0]}>

0 commit comments

Comments
 (0)