
Commit e74f027

bertmaher and Jokeren authored

[BACKEND] Fix and document logic for creating warp shapes in MMAv3 (triton-lang#5441)

Cherry pick of triton-lang#5277 for the 3.2 release.
Co-authored-by: Keren Zhou <[email protected]>

1 parent 8af9311 · commit e74f027

File tree

2 files changed: +33 -2 lines

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp
Lines changed: 6 additions & 2 deletions

@@ -106,8 +106,12 @@ warpsPerTileV3(DotOp dotOp, const ArrayRef<int64_t> shape, int numWarps,
                const SmallVector<unsigned, 3> &instrShape) {
   SetVector<Operation *> slices;
   mlir::getForwardSlice(dotOp.getResult(), &slices);
-  if (llvm::find_if(slices, [](Operation *op) { return isa<DotOp>(op); }) !=
-      slices.end())
+  // Contains a chained dot. We prefer to assign warps to one axis
+  // to facilitate use cases like flash attention, allowing reductions within
+  // the same warp.
+  if (llvm::find_if(slices, [](Operation *op) {
+        return op->hasTrait<OpTrait::DotLike>();
+      }) != slices.end())
     return {(unsigned)numWarps, 1};
 
   // For MMAv3, the smallest indivisible unit of warp shape is (4, 1).
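
As a reading aid for the hunk above: the new comment documents why a chained dot (one whose result feeds another dot-like op, as in flash attention) gets all warps assigned to a single axis so the intermediate reduction stays inside a warp, and the check now matches any op with the DotLike trait rather than only tt.dot. The sketch below is a minimal, hypothetical model of that decision, not the actual warpsPerTileV3 implementation; the helper name pickWarpShape, the greedy split, and the default instruction sizes are illustrative assumptions.

// Hypothetical standalone model of the warp-shape heuristic -- NOT the
// actual Triton warpsPerTileV3 code.
#include <cstdint>
#include <cstdio>
#include <utility>

std::pair<unsigned, unsigned>
pickWarpShape(bool hasChainedDot, int64_t M, int64_t N, unsigned numWarps,
              int64_t instrM = 16, int64_t instrN = 8) {
  // Chained dot (e.g. flash attention): keep every warp on one axis so the
  // reduction feeding the second dot never crosses warps.
  if (hasChainedDot)
    return {numWarps, 1};
  // Otherwise, greedily give warps to whichever axis still has more
  // instruction tiles left to cover (numWarps assumed to be a power of two).
  unsigned warpsM = 1, warpsN = 1;
  while (warpsM * warpsN < numWarps) {
    if (M / (warpsM * instrM) >= N / (warpsN * instrN))
      warpsM *= 2;
    else
      warpsN *= 2;
  }
  return {warpsM, warpsN};
}

int main() {
  // 8 warps, a 64x64 dot that feeds another dot: all warps on M -> [8, 1].
  auto chained = pickWarpShape(/*hasChainedDot=*/true, 64, 64, 8);
  // 8 warps, a 64x128 dot with no downstream dot: split across axes -> [4, 2].
  auto plain = pickWarpShape(/*hasChainedDot=*/false, 64, 128, 8, 16, 64);
  std::printf("chained: [%u, %u]  plain: [%u, %u]\n", chained.first,
              chained.second, plain.first, plain.second);
}

For the shapes used in the new test below (8 warps), this toy model happens to yield [8, 1] for the chained first dot and [4, 2] for the trailing dot, matching the warpsPerCTA values in the CHECK lines; the real pass additionally derives its split from the MMA instruction shape it selects.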

test/TritonGPU/accelerate-matmul.mlir
Lines changed: 27 additions & 0 deletions

@@ -73,6 +73,33 @@ module attributes {"triton_gpu.target" = "cuda:80", "triton_gpu.num-ctas" = 1 :
 
 // -----
 
+// CHECK: #mma = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [16, 32, 16]}>
+// CHECK: #mma1 = #triton_gpu.nvidia_mma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 64, 16]}>
+#blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: chained_dot
+  tt.func public @chained_dot_wgmma(
+    %arg0: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>>,
+    %arg1: tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>>,
+    %arg2: tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>>) -> tensor<64x128xf32, #blocked1> {
+    %cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked>
+    %cst_1 = arith.constant dense<0.000000e+00> : tensor<64x128xf32, #blocked1>
+    // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<64x64xf32, #mma>
+    %d = tt.dot %arg0, %arg1, %cst_0 :
+      tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked}>> * tensor<128x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked}>> -> tensor<64x64xf32, #blocked>
+    %t = arith.truncf %d : tensor<64x64xf32, #blocked> to tensor<64x64xf16, #blocked>
+    %c = triton_gpu.convert_layout %t : tensor<64x64xf16, #blocked> -> tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>>
+    // CHECK: triton_nvidia_gpu.warp_group_dot {{.*}} -> tensor<64x128xf32, #mma1>
+    %r = tt.dot %c, %arg2, %cst_1 :
+      tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #blocked1}>> * tensor<64x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #blocked1}>> -> tensor<64x128xf32, #blocked1>
+    tt.return %r : tensor<64x128xf32, #blocked1>
+  }
+}
+
+// -----
+
 // CHECK: #[[$MMA:.+]] = #triton_gpu.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [4, 2], instrShape = [16, 8]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [2, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [4, 4], threadsPerWarp = [1, 32], warpsPerCTA = [8, 1], order = [1, 0]}>
