
Commit 7873637

[AMD] NFC: change to func walk in ReorderInstructions (#5131)
This avoids scanning across functions when considering optimizations. Right now we have only one function in the module, so this is NFC; but it should be cleaner this way.
1 parent 915c149 commit 7873637
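
A minimal sketch of the per-function walk pattern this commit adopts (illustrative only; countLoopsInFunc and runPerFunction are hypothetical names, not part of the pass): the driver iterates the module's triton::FuncOp ops and each helper walks only inside the function it receives, instead of walking the whole ModuleOp.

// Minimal, assumed sketch of the per-function walk pattern (not the pass code itself).
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

// Hypothetical helper: walks a single function, not the whole module.
static unsigned countLoopsInFunc(mlir::triton::FuncOp funcOp) {
  unsigned numLoops = 0;
  funcOp.walk([&](mlir::scf::ForOp) { ++numLoops; });
  return numLoops;
}

// Hypothetical driver mirroring the new runOnOperation() structure: iterate
// the module's triton::FuncOp ops so every walk stays scoped to one function.
static void runPerFunction(mlir::ModuleOp moduleOp) {
  for (auto funcOp : moduleOp.getOps<mlir::triton::FuncOp>())
    (void)countLoopsInFunc(funcOp);
}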

File tree

2 files changed: +64 −24 lines

test/TritonGPU/amd/amd-sched-2nd-load.mlir

Lines changed: 39 additions & 0 deletions
@@ -49,6 +49,16 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
   }
 }
 
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
+#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
+#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
+#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
+#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
+
 // Should apply: tile size 256x256x64 with single dot
 // CHECK-LABEL: sink_2nd_load_256x256x64
 // CHECK: %[[tileA:.*]] = tt.load
@@ -78,6 +88,16 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
   }
 }
 
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
+#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
+#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
+#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
+#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
+
 // Should NOT apply: tile size 256x64x128 with single dot
 // CHECK-LABEL: sink_2nd_load_256x64x128
 // CHECK: %[[tileA:.*]] = tt.load
@@ -107,6 +127,16 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
   }
 }
 
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
+#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
+#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
+#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
+#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
+
 // Should NOT apply: tile size 256x256x32 with single dot
 // CHECK-LABEL: sink_2nd_load_256x256x32
 // CHECK: %[[tileA:.*]] = tt.load
@@ -136,6 +166,15 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
   }
 }
 
+// -----
+
+#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
+#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
+#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
+#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
+#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
+#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
 
 // Category 2: single dot with two loads and tile size is large enough (128x128x128).
 // We make sure the move is legal.

third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp

Lines changed: 25 additions & 24 deletions
@@ -19,10 +19,10 @@ namespace ttg = mlir::triton::gpu;
 
 // Return true if the given moduleOp contains a pure matmul problem; i.e.,
 // single dot in the main loop.
-static bool isPureMatmulProblem(ModuleOp moduleOp) {
+static bool isPureMatmulProblem(triton::FuncOp funcOp) {
   bool isMatmul = true;
   bool foundLoop = false;
-  moduleOp.walk([&](scf::ForOp forOp) -> void {
+  funcOp.walk([&](scf::ForOp forOp) -> void {
     int counter = 0;
     forOp.walk([&counter](triton::DotOp dotOp) { ++counter; });
     isMatmul = (isMatmul && (counter == 1));
@@ -98,9 +98,9 @@ static bool isCrossLoopBoundary(mlir::Operation *opInsideLoop,
 
 // Sink dot layout conversions into loops to decrease register pressure when
 // possible.
-static void sinkDotConversion(ModuleOp moduleOp) {
+static void sinkDotConversion(triton::FuncOp funcOp) {
   DenseMap<Operation *, Operation *> opToMove;
-  moduleOp.walk([&](ttg::ConvertLayoutOp op) {
+  funcOp.walk([&](ttg::ConvertLayoutOp op) {
     Attribute encoding = op.getType().getEncoding();
     if (!isa_and_nonnull<ttg::DotOperandEncodingAttr>(encoding))
       return;
@@ -139,8 +139,8 @@ static void sinkDotConversion(ModuleOp moduleOp) {
 // %2 = local_alloc
 // %3 = local_store %1, %2
 // %4 = local_load %2
-static void hoistLocalLoad(ModuleOp moduleOp) {
-  moduleOp.walk([&](ttg::LocalLoadOp localLoad) {
+static void hoistLocalLoad(triton::FuncOp funcOp) {
+  funcOp.walk([&](ttg::LocalLoadOp localLoad) {
     auto localAlloc = localLoad.getSrc().getDefiningOp<ttg::LocalAllocOp>();
     if (!localAlloc)
       return;
@@ -190,9 +190,9 @@ static void hoistLocalLoad(ModuleOp moduleOp) {
 
 // Sink conversion after the last dealloc but before the first use in its block.
 // This helps to avoid unnecessary shared memory allocation.
-static void moveDownCoversion(ModuleOp moduleOp) {
+static void moveDownCoversion(triton::FuncOp funcOp) {
   SmallVector<ttg::ConvertLayoutOp> convertOps;
-  moduleOp.walk([&](ttg::ConvertLayoutOp op) { convertOps.push_back(op); });
+  funcOp.walk([&](ttg::ConvertLayoutOp op) { convertOps.push_back(op); });
 
   for (auto op : convertOps) {
     Operation *user = getFirstUseInSameBlock(op);
@@ -204,24 +204,24 @@ static void moveDownCoversion(ModuleOp moduleOp) {
 }
 
 // Move transpositions just after their definition.
-static void moveUpTranspose(ModuleOp moduleOp) {
+static void moveUpTranspose(triton::FuncOp funcOp) {
   SmallVector<triton::TransOp> transOps;
-  moduleOp.walk([&](triton::TransOp op) { transOps.push_back(op); });
+  funcOp.walk([&](triton::TransOp op) { transOps.push_back(op); });
 
   for (auto op : transOps)
     if (Operation *argOp = op.getSrc().getDefiningOp())
       op->moveAfter(argOp);
 }
 
 // Schedule global load and local store ops for better GEMM performance.
-static void scheduleGlobalLoadLocalStore(ModuleOp m) {
+static void scheduleGlobalLoadLocalStore(triton::FuncOp funcOp) {
   SmallVector<Operation *> moveOps;
   // Move global loads early to prefetch. This may increase register pressure
   // but it enables issuing global loads early.
-  m.walk([&](triton::LoadOp op) { moveOps.push_back(op); });
+  funcOp.walk([&](triton::LoadOp op) { moveOps.push_back(op); });
   // Move local_stores early if dependence distance greater than one iteration.
   // Best perf on GEMM when these precede global loads.
-  m.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); });
+  funcOp.walk([&](ttg::LocalStoreOp op) { moveOps.push_back(op); });
 
   for (auto op : llvm::reverse(moveOps)) {
     // Gather use-def chain in block.
@@ -311,8 +311,8 @@ static void scheduleGlobalLoadLocalStore(ModuleOp m) {
  * are experimenting how to better control instruction scheduling and enable
  * such optimizations.
  */
-static void sinkSecondLoad(ModuleOp m) {
-  m.walk([&](scf::ForOp forOp) -> void {
+static void sinkSecondLoad(triton::FuncOp funcOp) {
+  funcOp.walk([&](scf::ForOp forOp) -> void {
     SetVector<triton::LoadOp> loadOps;
     triton::DotOp dotOp;
     for (Operation &op : forOp) {
@@ -358,18 +358,19 @@ struct TritonAMDGPUReorderInstructionsPass
           TritonAMDGPUReorderInstructionsPass> {
   void runOnOperation() override {
     ModuleOp m = getOperation();
+    for (auto funcOp : m.getOps<triton::FuncOp>()) {
+      hoistLocalLoad(funcOp);
 
-    hoistLocalLoad(m);
+      sinkDotConversion(funcOp);
+      moveDownCoversion(funcOp);
 
-    sinkDotConversion(m);
-    moveDownCoversion(m);
+      moveUpTranspose(funcOp);
 
-    moveUpTranspose(m);
-
-    if (isPureMatmulProblem(m))
-      sinkSecondLoad(m);
-    else
-      scheduleGlobalLoadLocalStore(m);
+      if (isPureMatmulProblem(funcOp))
+        sinkSecondLoad(funcOp);
+      else
+        scheduleGlobalLoadLocalStore(funcOp);
+    }
   }
 };
 } // namespace
