
Commit dc5286d

lezcanoguacamoleo authored and committed
[BACKEND] Simplify and comment warp allocation logic in mmav2 (triton-lang#5041)
It's not entirely clear to me whether the previous logic was equivalent to this one, as it was rather obtuse. I think the new one is optimal but I'm happy to run benchmarks to make sure we don't regress.
1 parent: 1219b01

2 files changed: +29, -23 lines


lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 26 additions & 20 deletions
@@ -7,6 +7,7 @@
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/Triton/IR/Utility.h"
 #include "triton/Dialect/TritonGPU/IR/Attributes.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/Transforms/Passes.h"
@@ -77,28 +78,33 @@ SmallVector<unsigned> warpsPerTileV2(DotOp dotOp, const ArrayRef<int64_t> shape,
     }
   }
 
-  SmallVector<unsigned> ret(rank, 1);
-  SmallVector<int64_t> shapePerWarp(rank, 1);
-  shapePerWarp[rank - 1] = 8;
-  shapePerWarp[rank - 2] = 16;
-  // TODO (@daadaada): double-check.
-  // original logic in
-  // https://github.com/triton-lang/triton/blob/master/lib/codegen/analysis/layout.cc#L252
-  // seems buggy for shape = [32, 16] ?
-  do {
-    if (ret[0] * ret[1] >= numWarps)
-      break;
-    if (shape[0] / shapePerWarp[0] / ret[0] >=
-        shape[1] / (shapePerWarp[1] * 2) / ret[1]) {
-      if (ret[0] < shape[0] / shapePerWarp[0]) {
-        ret[0] *= 2;
-      } else
-        ret[1] *= 2;
+  assert(rank == 2);
+  SmallVector<int64_t> shapePerWarp = {16, 8};
+  SmallVector<int64_t> warps = {1, 1};
+  // Compute repM and repN
+  SmallVector<int64_t> reps = {ceil(shape[0], shapePerWarp[0]),
+                               ceil(shape[1], shapePerWarp[1])};
+  // The formula for the number of registers given the reps is
+  // repM * 4 * repK + repN * 2 * repK + regsC
+  // where regsC = repM * repN * 4, which does not depend on the warp shape
+  //
+  // As such, to minimize the register pressure, we need to balance
+  // repM and repN. We then untie towards M, as the lhs tile has 4 elements,
+  // and the rhs tile has just 2.
+  while (product(warps) < numWarps) {
+    if (reps[0] >= reps[1]) {
+      warps[0] *= 2;
+      // Too many warps for this mma (repM == repN == 1).
+      // We allocate the remaining warps to the left (arbitrary choice).
+      if (reps[0] != 1) {
+        reps[0] /= 2;
+      }
     } else {
-      ret[1] *= 2;
+      warps[1] *= 2;
+      reps[1] /= 2;
     }
-  } while (true);
-  return ret;
+  }
+  return {(unsigned)warps[0], (unsigned)warps[1]};
 }
 
 SmallVector<unsigned, 2>
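
The comment in the new code gives the per-warp register count as repM * 4 * repK + repN * 2 * repK + regsC, where regsC = repM * repN * 4 does not depend on how the warps are laid out, so the heuristic balances repM and repN and unties towards M (the lhs tile has 4 elements, the rhs tile only 2). Below is a minimal standalone C++ sketch of that loop for experimenting outside of MLIR; the ceilDiv/product helpers and the warpsPerTileV2Sketch name are stand-ins for illustration, not Triton's actual utilities.

// Minimal standalone sketch of the new warp-allocation loop (illustrative
// only; ceilDiv, product and warpsPerTileV2Sketch are made-up names).
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

static int64_t ceilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

static int64_t product(const std::vector<int64_t> &v) {
  int64_t p = 1;
  for (int64_t x : v)
    p *= x;
  return p;
}

// shape = {M, N}; numWarps is a power of two, as Triton requires.
static std::vector<int64_t> warpsPerTileV2Sketch(std::vector<int64_t> shape,
                                                 int numWarps) {
  assert(shape.size() == 2);
  std::vector<int64_t> shapePerWarp = {16, 8}; // m16n8 mma instruction tile
  std::vector<int64_t> warps = {1, 1};
  // How many instruction tiles cover the output along M and N.
  std::vector<int64_t> reps = {ceilDiv(shape[0], shapePerWarp[0]),
                               ceilDiv(shape[1], shapePerWarp[1])};
  while (product(warps) < numWarps) {
    if (reps[0] >= reps[1]) {
      warps[0] *= 2; // split along M; ties also go to M
      if (reps[0] != 1)
        reps[0] /= 2;
    } else {
      warps[1] *= 2; // split along N
      reps[1] /= 2;
    }
  }
  return warps;
}

int main() {
  struct Case { int64_t m, n; int numWarps; };
  // Hypothetical shapes, chosen only to show the balancing behaviour.
  for (Case c : {Case{128, 128, 8}, Case{64, 64, 4}, Case{32, 16, 4}}) {
    auto w = warpsPerTileV2Sketch({c.m, c.n}, c.numWarps);
    std::cout << c.m << "x" << c.n << " with " << c.numWarps << " warps -> ["
              << w[0] << ", " << w[1] << "]\n";
  }
}

With these assumed shapes, the sketch prints [2, 4] for 128x128 with 8 warps and [2, 2] for 64x64 with 4 warps, alternating between M and N until product(warps) reaches numWarps; for 32x16 with 4 warps (the shape the removed TODO flagged as possibly buggy in the old logic) it gives [2, 2].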

test/TritonGPU/accelerate-matmul.mlir

Lines changed: 3 additions & 3 deletions
@@ -73,7 +73,7 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
 
 // -----
 
-// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}>
+// CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
@@ -93,7 +93,7 @@ module attributes {"triton_gpu.target" = "cuda:89", "triton_gpu.num-ctas" = 1 :
 
 // -----
 
-// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 4], instrShape = [16, 8]}>
+// CHECK-DAG: #[[MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [2, 2], instrShape = [16, 8]}>
 // CHECK-DAG: #[[MMA1:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 1, 1], instrShape = [1, 16, 8]}>
 
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [1, 2, 16], warpsPerCTA = [1, 4, 1], order = [2, 1, 0]}>
@@ -148,7 +148,7 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
 // -----
 
 // Verify that we use mmav2 when the k dim is too small for mmav3.
-// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 4], instrShape = [16, 8]}>
+// CHECK: #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 8], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [32, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
 module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
 // CHECK-LABEL: small_k_size
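
As a hedged, hypothetical worked example of how the new loop arrives at a warpsPerCTA like the [2, 4] expected above (the concrete dot shapes used by these tests are not reproduced in this diff): suppose the output needs reps = [8, 16] mma tiles and num-warps is 8. The loop first splits N (warps = [1, 2], reps = [8, 8]), then the tie goes to M (warps = [2, 2], reps = [4, 8]), then N again (warps = [2, 4]), at which point product(warps) == 8 and it stops.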
