Skip to content

Commit 92b6e26

Browse files
[AMD] Add a block ping-pong scheduling pass (#5018)
This change introduces a new pass, `tritonamdgpu-block-pingpong`. Its main target is the GEMM kernel. In the ideal case, the pass produces a schedule in which two warps run in parallel on one SIMD and alternately execute a section of `mfma` instructions and a section of `memory` instructions, so that the GPU can keep `mfma` busy while hiding the latency of the `memory` instructions. For now, the pass is gated behind the env var `TRITON_HIP_USE_BLOCK_PINGPONG=1`. --------- Co-authored-by: Lei Zhang <[email protected]>
1 parent 125c165 commit 92b6e26

File tree

12 files changed

+739
-2
lines changed

12 files changed

+739
-2
lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
6262
mlir::registerTritonAMDGPUAccelerateMatmul();
6363
mlir::registerTritonAMDGPUOptimizeEpilogue();
6464
mlir::registerTritonAMDGPUReorderInstructions();
65+
mlir::registerTritonAMDGPUBlockPingpong();
6566
mlir::registerTritonAMDGPUStreamPipeline();
6667
mlir::registerTritonAMDGPUCanonicalizePointers();
6768
mlir::registerTritonAMDGPUConvertToBufferOps();

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
2929
"TRITON_DISABLE_RESHAPE_ENCODING_INFERENCE",
3030
"TRITON_ENABLE_LLVM_DEBUG",
3131
"TRITON_HIP_STREAM_PREFETCH",
32+
"TRITON_HIP_USE_BLOCK_PINGPONG",
3233
"TRITON_LLVM_DEBUG_ONLY",
3334
"USE_IR_LOC",
3435
"NVPTX_ENABLE_DUMP",

test/TritonGPU/amd/amd-block-pingpong.mlir

Lines changed: 257 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// RUN: triton-opt %s --convert-triton-amdgpu-to-llvm='arch=gfx942' | FileCheck %s
2+
3+
// CHECK-LABEL: llvm.func @sink_setprio
4+
// CHECK: rocdl.mfma
5+
// CHECK-NOT: rocdl.mfma
6+
// CHECK: rocdl.s.setprio 1
7+
// CHECK-COUNT-15: rocdl.mfma
8+
// CHECK-NOT: rocdl.mfma
9+
// CHECK: rocdl.s.setprio 0
10+
11+
#blocked = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 8], order = [0, 1]}>
12+
#blocked1 = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [8, 1], order = [1, 0]}>
13+
#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 4], instrShape = [16, 16], isTransposed = true}>
14+
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} {
15+
// Test input for the setprio placement check: a single tt.dot on mfma-encoded
// operands, bracketed by `rocdl.s.setprio 1` before it and `rocdl.s.setprio 0`
// after it. The FileCheck directives at the top of this file expect that,
// after conversion to LLVM, one rocdl.mfma precedes `rocdl.s.setprio 1` and
// the remaining 15 rocdl.mfma ops sit between the two setprio ops.
tt.func public @sink_setprio(
16+
%arg0: tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>,
17+
%arg1: tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>>) {
18+
%cst_0 = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #mma>
19+
rocdl.s.setprio 1
20+
%dot = tt.dot %arg0, %arg1, %cst_0 :
21+
tensor<64x128xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<128x64xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<64x64xf32, #mma>
22+
rocdl.s.setprio 0
23+
tt.return
24+
}
25+
}

third_party/amd/backend/compiler.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,9 @@ def make_ttgir(mod, metadata, options):
255255
passes.ttgpuir.add_reduce_data_duplication(pm)
256256
if amd.has_matrix_core_feature(options.arch):
257257
amd.passes.ttgpuir.add_reorder_instructions(pm)
258+
use_block_pingpong = os.getenv("TRITON_HIP_USE_BLOCK_PINGPONG", "0") == "1"
259+
if use_block_pingpong and options.num_stages == 2:
260+
amd.passes.ttgpuir.add_block_pingpong(pm)
258261

259262
if use_buffer_ops:
260263
amd.passes.ttgpuir.add_canonicalize_pointers(pm)

third_party/amd/include/TritonAMDGPUTransforms/Passes.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#ifndef TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_PASSES_H_
22
#define TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTRANSFORMS_PASSES_H_
33

4+
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
45
#include "mlir/Pass/Pass.h"
56
#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
67
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
@@ -27,6 +28,8 @@ std::unique_ptr<Pass> createTritonAMDGPUCanonicalizePointersPass();
2728

2829
std::unique_ptr<Pass> createTritonAMDGPUConvertToBufferOpsPass();
2930

31+
std::unique_ptr<Pass> createTritonAMDGPUBlockPingpongPass();
32+
3033
/// Generate the code for registering passes.
3134
#define GEN_PASS_REGISTRATION
3235
#include "TritonAMDGPUTransforms/Passes.h.inc"

third_party/amd/include/TritonAMDGPUTransforms/Passes.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,4 +124,19 @@ def TritonAMDGPUConvertToBufferOps : Pass<"tritonamdgpu-convert-buffer-ops", "ml
124124
let dependentDialects = ["mlir::triton::amdgpu::TritonAMDGPUDialect"];
125125
}
126126

127+
def TritonAMDGPUBlockPingpong: Pass<"tritonamdgpu-block-pingpong", "mlir::ModuleOp"> {
128+
let summary = "Interleaving instructions from two warps on the same SIMD to better utilize matrix core";
129+
130+
let description = [{
131+
This pass reorders instructions to interleave instructions from two warps on the same SIMD unit.
132+
We call this a ping-pong scheduling pattern, where two warps run concurrently in a synchronized fashion.
133+
This block ping-pong pattern could be beneficial under a few conditions, including
134+
occupancy and number of warps.
135+
}];
136+
137+
let constructor = "mlir::createTritonAMDGPUBlockPingpongPass()";
138+
139+
let dependentDialects = ["mlir::ROCDL::ROCDLDialect", "mlir::triton::amdgpu::TritonAMDGPUDialect"];
140+
}
141+
127142
#endif

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM/MFMA.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,9 @@ struct DotOpMFMAConversionHelper {
164164

165165
// Conduct the Dot conversion.
166166
LogicalResult convertDot(DotOp op, DotOpAdaptor adaptor) const {
167+
// Check if this dot has come with priority set by setprio.
168+
auto setPrioOp = dyn_cast_or_null<ROCDL::SetPrioOp>(op->getPrevNode());
169+
167170
auto warpsPerCTA = mfmaLayout.getWarpsPerCTA();
168171
auto mDim = mfmaLayout.getMDim();
169172
auto nDim = mfmaLayout.getNDim();
@@ -226,6 +229,12 @@ struct DotOpMFMAConversionHelper {
226229
getNumSubmatrices(aTensorTy.getElementType(), mDim, nDim);
227230
auto elemsPerVec = mDim * nDim * subBlocks / warpSize;
228231

232+
Value firstMfma;
233+
auto setFirstMfma = [&](Value mfma) {
234+
if (!firstMfma)
235+
firstMfma = mfma;
236+
};
237+
229238
auto vecTy = vec_ty(dstElemTy, elemsPerVec);
230239
for (int b = 0; b < numRepB; ++b) {
231240
for (int m = 0; m < numRepM; ++m) {
@@ -240,13 +249,15 @@ struct DotOpMFMAConversionHelper {
240249
}
241250
acc = zeroAuxiliarBlocks(subBlocks, acc);
242251
for (int k = 0; k < numRepK; k++) {
243-
for (int kPack = 0; kPack < kWidth / kBase; ++kPack)
252+
for (int kPack = 0; kPack < kWidth / kBase; ++kPack) {
244253
acc =
245254
mfmaLayout.getIsTransposed()
246255
? generateMFMAOp(mfmaInsnName, operandB[kPack][{b, n, k}],
247256
operandA[kPack][{b, m, k}], acc)
248257
: generateMFMAOp(mfmaInsnName, operandA[kPack][{b, m, k}],
249258
operandB[kPack][{b, n, k}], acc);
259+
setFirstMfma(acc);
260+
}
250261
}
251262
acc = reduceSubBlocks(subBlocks, acc);
252263
for (unsigned v = 0; v < elemsPerVec; ++v) {
@@ -257,6 +268,16 @@ struct DotOpMFMAConversionHelper {
257268
}
258269
}
259270
}
271+
272+
// Originally, setprio (high) is set to the high-level dot op. After dot is
273+
// being lowered to the series of mfma operations, it should be moved next
274+
// to the first mfma leaving the first mfma staying at the low priority. In
275+
// this way, incoming warp can be effectively waiting on the first mfma
276+
// instruction (low priority) while the other warp is executing mfma with
277+
// high priority. Otherwise, incoming warp can break the cluster.
278+
if (setPrioOp && firstMfma)
279+
setPrioOp->moveAfter(firstMfma.getDefiningOp());
280+
260281
// replace with new packed result
261282
Type structTy = LLVM::LLVMStructType::getLiteral(
262283
ctx, SmallVector<Type>(fc.size(), dstElemTy));

third_party/amd/lib/TritonAMDGPUToLLVM/SPMDOpToLLVM.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ struct CondBarrierOpConversion
4747
rewriter.setInsertionPointToStart(trueBlock);
4848
rewriter.create<ROCDL::SBarrierOp>(loc);
4949
rewriter.create<LLVM::BrOp>(loc, afterCondBarBlock);
50-
5150
rewriter.eraseOp(op);
5251
return success();
5352
}

0 commit comments

Comments
 (0)