
Commit 8dfa7be

[AMD] Improve matmul detection in reorder instructions pass (#5393)
Previously the matmul problem check looked for a for loop containing a single dot in the function. This doesn't work well for nested loops, used for example in persistent matmul kernels. The check is updated to also consider nested for loops that contain a single tl.dot operation and at least two global loads. The `scheduleGlobalLoadLocalStore` transformation is then applied to the whole function if the function is a pure matmul problem; otherwise it is applied to each qualifying leaf for loop with limited scope. The transformation now also captures, in addition to the loop body, global loads that the pipeliner has peeled out into a loop prologue.
1 parent 1f8966b commit 8dfa7be
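
For illustration, the loop nest the updated check targets looks roughly like the following persistent-GEMM kernel. This is a minimal sketch, not code from this repository: the pointer, stride, and tile-size parameters are hypothetical, bounds masking is omitted, and the launch is assumed to use one program per compute unit. The relevant property is the shape of the loop nest: an outer tile loop wrapping a leaf K loop that contains exactly one tl.dot and two global tl.load operations.

import triton
import triton.language as tl

@triton.jit
def persistent_matmul(  # hypothetical example kernel, not part of this change
        a_ptr, b_ptr, c_ptr, M, N, K,
        stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,
        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
        BLOCK_K: tl.constexpr, NUM_SMS: tl.constexpr):
    num_tiles = tl.cdiv(M, BLOCK_M) * tl.cdiv(N, BLOCK_N)
    # Outer "persistent" loop: each program walks over several output tiles.
    for tile_id in range(tl.program_id(0), num_tiles, NUM_SMS):
        pid_m = tile_id // tl.cdiv(N, BLOCK_N)
        pid_n = tile_id % tl.cdiv(N, BLOCK_N)
        offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
        offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
        offs_k = tl.arange(0, BLOCK_K)
        a_ptrs = a_ptr + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak
        b_ptrs = b_ptr + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn
        acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
        # Inner (leaf) K loop: exactly one tl.dot plus two global loads.
        for _ in range(0, tl.cdiv(K, BLOCK_K)):
            a = tl.load(a_ptrs)
            b = tl.load(b_ptrs)
            acc = tl.dot(a, b, acc)
            a_ptrs += BLOCK_K * stride_ak
            b_ptrs += BLOCK_K * stride_bk
        c_ptrs = c_ptr + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn
        tl.store(c_ptrs, acc)

After lowering, the inner loop becomes a leaf scf.for holding a single tt.dot and two tt.load ops, so the pass can schedule it even though it is nested inside the persistent outer loop.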

File tree: 2 files changed, +147 -45 lines

test/TritonGPU/amd/amd-sched-2nd-load.mlir

Lines changed: 42 additions & 0 deletions
@@ -61,6 +61,48 @@ module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32}
 #dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
 #smem = #ttg.shared_memory

+// Should apply: tile size 256x256x128 with nested single dot
+// CHECK-LABEL: nested_sink_2nd_load_256x256x128
+// CHECK: %[[tileA:.*]] = tt.load
+// CHECK-NEXT: local_load
+// CHECK-NEXT: local_load
+// CHECK-NEXT: %[[tileB:.*]] = tt.load
+// CHECK-NEXT: tt.dot
+// CHECK-NEXT: ttg.local_store %[[tileA]]
+// CHECK-NEXT: ttg.local_store %[[tileB]]
+module attributes {"ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 64 : i32} {
+  tt.func public @nested_sink_2nd_load_256x256x128(%A_ptr: tensor<256x128x!tt.ptr<f16>, #blocked>, %B_ptr: tensor<128x256x!tt.ptr<f16>, #blocked1>, %C_ptr: tensor<256x256x!tt.ptr<f32>, #mma>, %A_LDS: !ttg.memdesc<256x128xf16, #shared, #smem, mutable>, %B_LDS: !ttg.memdesc<128x256xf16, #shared1, #smem, mutable>) {
+    %c0 = arith.constant 0 : i32
+    %c1 = arith.constant 1 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
+    scf.for %arg2 = %c0 to %c1 step %c1 : i32 {
+      %0:1 = scf.for %arg0 = %c0 to %c1 step %c1 iter_args(%arg1 = %cst) -> (tensor<256x256xf32, #mma>) : i32 {
+        %4 = tt.load %A_ptr : tensor<256x128x!tt.ptr<f16>, #blocked>
+        %1 = ttg.local_load %A_LDS : !ttg.memdesc<256x128xf16, #shared, #smem, mutable> -> tensor<256x128xf16, #dotOp0>
+        %5 = tt.load %B_ptr : tensor<128x256x!tt.ptr<f16>, #blocked1>
+        %2 = ttg.local_load %B_LDS : !ttg.memdesc<128x256xf16, #shared1, #smem, mutable> -> tensor<128x256xf16, #dotOp1>
+        %3 = tt.dot %1, %2, %arg1 : tensor<256x128xf16, #dotOp0> * tensor<128x256xf16, #dotOp1> -> tensor<256x256xf32, #mma>
+        ttg.local_store %4, %A_LDS : tensor<256x128xf16, #blocked> -> !ttg.memdesc<256x128xf16, #shared, #smem, mutable>
+        ttg.local_store %5, %B_LDS : tensor<128x256xf16, #blocked1> -> !ttg.memdesc<128x256xf16, #shared1, #smem, mutable>
+        scf.yield %3 : tensor<256x256xf32, #mma>
+      }
+      tt.store %C_ptr, %0#0 : tensor<256x256x!tt.ptr<f32>, #mma>
+    }
+    tt.return
+  }
+}
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
+#blocked1 = #ttg.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
+#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
+#shared = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
+#shared1 = #ttg.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
+#dotOp0 = #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
+#dotOp1 = #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
+#smem = #ttg.shared_memory
+
 // Should apply: tile size 256x256x64 with single dot
 // CHECK-LABEL: sink_2nd_load_256x256x64
 // CHECK: %[[tileA:.*]] = tt.load

third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp

Lines changed: 105 additions & 45 deletions
@@ -3,7 +3,6 @@
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/Verifier.h"
-#include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -17,9 +16,23 @@ namespace ttg = mlir::triton::gpu;
 // Utility functions
 //===----------------------------------------------------------------------===//

-// Return true if the given moduleOp contains a pure matmul problem; i.e.,
-// single dot in the main loop.
-static bool isPureMatmulProblem(triton::FuncOp funcOp) {
+static SmallVector<scf::ForOp> getLeafForOps(triton::FuncOp funcOp) {
+  SmallVector<scf::ForOp> allOps;
+  funcOp->walk([&](scf::ForOp forOp) { allOps.push_back(forOp); });
+
+  SmallVector<scf::ForOp> leafOps;
+  for (scf::ForOp forOp : allOps) {
+    auto searchResult = forOp.getBody()->walk(
+        [](scf::ForOp) { return WalkResult::interrupt(); });
+    if (!searchResult.wasInterrupted())
+      leafOps.push_back(forOp);
+  }
+  return leafOps;
+}
+
+// Return true if the given funcOp is a pure matmul problem; i.e.,
+// a single main loop with a single dot.
+static bool isPureMatmulFunc(triton::FuncOp funcOp) {
   bool isMatmul = true;
   bool foundLoop = false;
   funcOp.walk([&](scf::ForOp forOp) -> void {
@@ -31,6 +44,20 @@ static bool isPureMatmulProblem(triton::FuncOp funcOp) {
   return foundLoop && isMatmul;
 }

+// Return true if the given ForOp contains a pure matmul problem; i.e.,
+// single dot and at least 2 global loads in the main loop.
+static bool isPureMatmulLoop(scf::ForOp forOp) {
+  int dotCounter = 0;
+  int loadCounter = 0;
+  forOp.walk([&](Operation *op) {
+    if (isa<triton::DotOp>(op))
+      ++dotCounter;
+    else if (isa<triton::LoadOp>(op))
+      ++loadCounter;
+  });
+  return dotCounter == 1 && loadCounter >= 2;
+}
+
 // Search through block to find earliest insertion point for move op. This can
 // be either an atomic op or last usage of source pointer. Search ends when move
 // op is encountered.
@@ -214,14 +241,41 @@ static void moveUpTranspose(triton::FuncOp funcOp) {
 }

 // Schedule global load and local store ops for better GEMM performance.
-static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) {
+static void scheduleGlobalLoadLocalStore(Operation *parentOp) {
   SmallVector<Operation *> moveOps;
-  // Move local_stores early if dependence distance greater than one iteration.
-  // Best perf on GEMM when these precede global loads.
-  funcOp.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); });
-  // Move global loads early to prefetch. This may increase register pressure
-  // but it enables issuing global loads early.
-  funcOp.walk([&](triton::LoadOp op) { moveOps.push_back(op); });
+
+  // Search through the forOp initArgs to find global loads for a GEMM that
+  // the pipeliner may have peeled into a loop prologue.
+  if (auto forOp = dyn_cast<scf::ForOp>(parentOp)) {
+    SmallVector<Value> vals = forOp.getInitArgs();
+    while (!vals.empty()) {
+      SmallVector<Value> nextVals; // Next set of values to search via BFS.
+      for (size_t i = 0; i < vals.size(); ++i) {
+        Operation *defOp = vals[i].getDefiningOp();
+        if (isa_and_nonnull<triton::LoadOp>(defOp)) {
+          moveOps.push_back(defOp);
+          continue;
+        }
+
+        // Find uses of the op that are local_store
+        for (Operation *op : vals[i].getUsers()) {
+          if (auto storeOp = dyn_cast<ttg::LocalStoreOp>(op)) {
+            // Recurse on operands of the local_store (to find a global_load).
+            nextVals.push_back(storeOp.getSrc());
+          }
+        }
+      }
+      vals.swap(nextVals);
+    }
+  }
+
+  // Move local_store ops inside the loop early if dependence distance greater
+  // than one iteration (i.e., num_stages > 2). For such case, better perf on
+  // GEMM when local_store ops precede global loads.
+  parentOp->walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); });
+  // Move global_load ops inside the loop early to prefetch. This may increase
+  // register pressure but it enables issuing global loads early.
+  parentOp->walk([&](triton::LoadOp op) { moveOps.push_back(op); });

   for (auto op : llvm::reverse(moveOps)) {
     // Gather use-def chain in block.
@@ -314,38 +368,36 @@ static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) {
 // are experimenting how to better control instruction scheduling and enable
 // such optimizations.
 //===-------------------------------------------------------------------===//
-static void sinkSecondLoad(triton::FuncOp funcOp) {
-  funcOp.walk([&](scf::ForOp forOp) -> void {
-    SetVector<triton::LoadOp> loadOps;
-    triton::DotOp dotOp;
-    for (Operation &op : forOp) {
-      if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
-        loadOps.insert(loadOp);
-      if (auto curOp = dyn_cast<triton::DotOp>(&op))
-        dotOp = curOp;
-    }
-    // Only apply the optimization when there are 2 load's in the loop
-    if (loadOps.size() != 2)
-      return;
-    // Only apply the optimization when tile size is large enough
-    // 1. nonKDim >= 128
-    // 2. kDim >= 64
-    auto ldAOp = loadOps[0];
-    auto tileAShape = cast<RankedTensorType>(ldAOp.getType()).getShape();
-    auto ldBOp = loadOps[1];
-    auto tileBShape = cast<RankedTensorType>(ldBOp.getType()).getShape();
-    if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128))
-      return;
-    // Only apply the optimization when the moving is legal
-    // 1. Make sure the 2nd loadOp is before the dot
-    // 2. Make sure the first user of the 2nd loadOp is after the dot.
-    bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp);
-    auto firstUser = *ldBOp.getResult().getUsers().begin();
-    bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser);
-    if (isBeforeDotOp && firstUserAfterDotOp)
-      // move ldBOp right before tt.dot
-      ldBOp->moveBefore(dotOp);
-  });
+static void sinkSecondLoad(scf::ForOp forOp) {
+  SetVector<triton::LoadOp> loadOps;
+  triton::DotOp dotOp;
+  for (Operation &op : forOp) {
+    if (auto loadOp = dyn_cast<triton::LoadOp>(&op))
+      loadOps.insert(loadOp);
+    if (auto curOp = dyn_cast<triton::DotOp>(&op))
+      dotOp = curOp;
+  }
+  // Only apply the optimization when there are 2 load's in the loop
+  if (loadOps.size() != 2)
+    return;
+  // Only apply the optimization when tile size is large enough
+  // 1. nonKDim >= 128
+  // 2. kDim >= 64
+  auto ldAOp = loadOps[0];
+  auto tileAShape = cast<RankedTensorType>(ldAOp.getType()).getShape();
+  auto ldBOp = loadOps[1];
+  auto tileBShape = cast<RankedTensorType>(ldBOp.getType()).getShape();
+  if (!(tileAShape[0] >= 128 && tileAShape[1] >= 64 && tileBShape[1] >= 128))
+    return;
+  // Only apply the optimization when the moving is legal
+  // 1. Make sure the 2nd loadOp is before the dot
+  // 2. Make sure the first user of the 2nd loadOp is after the dot.
+  bool isBeforeDotOp = ldBOp->isBeforeInBlock(dotOp);
+  auto firstUser = *ldBOp.getResult().getUsers().begin();
+  bool firstUserAfterDotOp = dotOp->isBeforeInBlock(firstUser);
+  if (isBeforeDotOp && firstUserAfterDotOp)
+    // move ldBOp right before tt.dot
+    ldBOp->moveBefore(dotOp);
 }

 //===----------------------------------------------------------------------===//
@@ -369,9 +421,17 @@ struct TritonAMDGPUReorderInstructionsPass

     moveUpTranspose(funcOp);

-    if (isPureMatmulProblem(funcOp)) {
+    if (isPureMatmulFunc(funcOp)) {
       scheduleGlobalLoadLocalStore(funcOp);
-      sinkSecondLoad(funcOp);
+      funcOp.walk([&](scf::ForOp forOp) -> void { sinkSecondLoad(forOp); });
+    } else {
+      SmallVector<scf::ForOp> leafForOps = getLeafForOps(funcOp);
+      for (auto forOp : leafForOps) {
+        if (isPureMatmulLoop(forOp)) {
+          scheduleGlobalLoadLocalStore(forOp);
+          sinkSecondLoad(forOp);
+        }
+      }
     }
   }
 }
