
Commit 037728b

oplavsic and Ognjen Plavsic authored
[AMD] Fix "keep Q tensor in VGPRS" optimization (#4901)
Adjust the placement of LDS writes and reads so that they immediately follow the definition of their operands in the case where the LDS write is in the loop but its operand is not. This is a heuristic for optimizing fused attention: it hoists the Q tensor's LDS read/write operations out of the main loop, since Q is loop-invariant and can be loaded once before entering the loop.

In the previous implementation, the heuristic incorrectly assumed that the operand of the LDS write had to be a load operation, which is unnecessary. Additionally, there was no explicit check that the LDS write was inside the loop while its defining operand was not. This PR addresses both issues.

---------

Co-authored-by: Ognjen Plavsic <[email protected]>
1 parent fb90385 commit 037728b
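For reference, the explicit cross-loop-boundary check described in the commit message comes down to one small MLIR helper; the minimal sketch below mirrors the isCrossLoopBoundary function added in ReorderInstructions.cpp in this commit (the include paths are illustrative assumptions and are not shown in the diff):

#include "mlir/Dialect/SCF/IR/SCF.h" // scf::ForOp (assumed include path)
#include "mlir/IR/Operation.h"       // mlir::Operation

using namespace mlir;

// True when opInsideLoop sits inside an scf.for that does not also contain
// opOutsideLoop, i.e. the LDS write is in the loop while its operand is not.
bool isCrossLoopBoundary(Operation *opInsideLoop, Operation *opOutsideLoop) {
  scf::ForOp parentForOp = opInsideLoop->getParentOfType<scf::ForOp>();
  return parentForOp && !parentForOp->isAncestor(opOutsideLoop);
}

When this condition holds, the pass moves the local_alloc (and the local_load that consumes it) right after the operand's defining op, which hoists the Q write/read out of the scf.for.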

File tree

2 files changed: +129 -31 lines changed


test/TritonGPU/amd/amd-reorder-instructions.mlir

Lines changed: 91 additions & 19 deletions
@@ -1,28 +1,100 @@
 // RUN: triton-opt %s -split-input-file -tritonamdgpu-reorder-instructions | FileCheck %s
 
-// Check that we order load, local_alloc, local_store (optional) and local_load one after another. This is useful
-// for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers
+// Check that we place local_alloc, local_store (optional) and local_load right after definition of their operands
+// in cases where local_alloc is in the loop but it's operand is not.
+// This is useful for making sure that Q tensor in FA is hoisted out of the main loop and kept in registers
 // throughout the computation.
-// CHECK-LABEL: order_load_alloc_local_load
-// CHECK: %[[LOAD:.+]] = tt.load
-// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[LOAD]]
-// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]]
-#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
-#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
-#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
-module attributes {"triton_gpu.num-warps" = 8 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
-  tt.func public @order_load_alloc_local_load(%arg0: tensor<32x32x!tt.ptr<f32>, #blocked>) attributes {noinline = false} {
-    %9 = tt.load %arg0 : tensor<32x32x!tt.ptr<f32>, #blocked>
-    %cst = arith.constant dense<0.000000e+00> : tensor<32x32xf32, #mma>
-    %10 = triton_gpu.local_alloc %9 : (tensor<32x32xf32, #blocked>) -> !tt.memdesc<32x32xf32, #shared>
-    %cst_0 = arith.constant dense<1.230000e+02> : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
-    %11 = triton_gpu.local_load %10 : !tt.memdesc<32x32xf32, #shared> -> tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
-    %12 = tt.dot %11, %cst_0, %cst : tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x32xf32, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<32x32xf32, #mma>
-    %13 = triton_gpu.convert_layout %12 : tensor<32x32xf32, #mma> -> tensor<32x32xf32, #blocked>
-    tt.store %arg0, %13 : tensor<32x32x!tt.ptr<f32>, #blocked>
+
+// CHECK-LABEL: hoist_q_out_of_the_loop
+// CHECK: %[[TRUNCF:.+]] = arith.truncf
+// CHECK-NEXT: %[[ALLOC:.+]] = triton_gpu.local_alloc %[[TRUNCF]]
+// CHECK-NEXT: triton_gpu.local_load %[[ALLOC]]
+// CHECK: scf.for
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @hoist_q_out_of_the_loop(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c0_i32 = arith.constant 0 : i32
+    %cst = arith.constant 1.44269502 : f32
+    %c128_i32 = arith.constant 128 : i32
+    %c128_i64 = arith.constant 128 : i64
+    %c0_i64 = arith.constant 0 : i64
+    %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma>
+    %1 = tt.get_program_id y : i32
+    %2 = arith.muli %1, %arg7 : i32
+    %3 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
+    %12 = tt.splat %3 : !tt.ptr<f16> -> tensor<256x128x!tt.ptr<f16>, #blocked1>
+    %41 = tt.load %12 : tensor<256x128x!tt.ptr<f16>, #blocked1>
+    %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1>
+    %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1>
+    %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1>
+    %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
+    %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 {
+      %73 = tt.splat %3 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked2>
+      %74 = tt.load %73 : tensor<128x128x!tt.ptr<f16>, #blocked2>
+      %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory>
+      %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory>
+      %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
+      %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
+      %107 = arith.addi %arg26, %c128_i64 : i64
+      scf.yield %107 : i64
+    } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>}
     tt.return
   }
 }
+
+
+// -----
+// Check that reordering described in hoist_q_out_of_the_loop is not done in the case where both
+// local_alloc and it's src tensor defining op are in the loop.
+// CHECK-LABEL: no_hoist_q_type_reordering
+// CHECK: scf.for
+// CHECK: %[[TRUNCF:.+]] = arith.truncf
+// CHECK-NEXT: arith.constant
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0]}>
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [1, 0], hasLeadingOffset = false}>
+#mfma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 8 : i32, triton_gpu.target = "hip:gfx90a", "triton_gpu.threads-per-warp" = 64 : i32} {
+  tt.func public @no_hoist_q_type_reordering(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg20: i32 {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %c0_i32 = arith.constant 0 : i32
+    %cst = arith.constant 1.44269502 : f32
+    %c128_i32 = arith.constant 128 : i32
+    %c128_i64 = arith.constant 128 : i64
+    %c0_i64 = arith.constant 0 : i64
+    %1 = tt.get_program_id y : i32
+    %2 = arith.muli %1, %arg7 : i32
+    %3 = tt.addptr %arg0, %2 : !tt.ptr<f16>, i32
+    %12 = tt.splat %3 : !tt.ptr<f16> -> tensor<256x128x!tt.ptr<f16>, #blocked1>
+    %41 = tt.load %12 : tensor<256x128x!tt.ptr<f16>, #blocked1>
+    %42 = arith.extf %41 : tensor<256x128xf16, #blocked1> to tensor<256x128xf32, #blocked1>
+    %43 = tt.splat %cst : f32 -> tensor<256x128xf32, #blocked1>
+    %44 = arith.mulf %42, %43 : tensor<256x128xf32, #blocked1>
+    %54:1 = scf.for %arg21 = %c0_i32 to %arg20 step %c128_i32 iter_args(%arg26 = %c0_i64) -> (i64) : i32 {
+      %45 = arith.truncf %44 : tensor<256x128xf32, #blocked1> to tensor<256x128xf16, #blocked1>
+      %cst_2 = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mfma>
+      %73 = tt.splat %3 : !tt.ptr<f16> -> tensor<128x128x!tt.ptr<f16>, #blocked2>
+      %74 = tt.load %73 : tensor<128x128x!tt.ptr<f16>, #blocked2>
+      %75 = triton_gpu.local_alloc %45 : (tensor<256x128xf16, #blocked1>) -> !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory>
+      %76 = triton_gpu.local_load %75 : !tt.memdesc<256x128xf16, #shared, #triton_gpu.shared_memory> -> tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>>
+      %77 = triton_gpu.local_alloc %74 : (tensor<128x128xf16, #blocked2>) -> !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory>
+      %78 = triton_gpu.local_load %77 : !tt.memdesc<128x128xf16, #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>, #triton_gpu.shared_memory> -> tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>>
+      %79 = tt.dot %76, %78, %cst_2 : tensor<256x128xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mfma, kWidth = 4}>> * tensor<128x128xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mfma, kWidth = 4}>> -> tensor<256x128xf32, #mfma>
+      %107 = arith.addi %arg26, %c128_i64 : i64
+      scf.yield %107 : i64
+    } {tt.divisibility_arg1 = dense<128> : tensor<1xi32>}
+    tt.return
+  }
+}
+
+// -----
+#blocked = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [16, 4], warpsPerCTA = [1, 8], order = [0, 1]}>
+#mma = #triton_gpu.amd_mfma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [8, 1], instrShape = [32, 32], isTransposed = true}>
+#shared = #triton_gpu.shared<{vec = 4, perPhase = 1, maxPhase = 16, order = [0, 1], hasLeadingOffset = false}>
+
 // CHECK-LABEL: order_load_alloc_local_load_local_store
 // CHECK: %[[LOAD:.+]] = tt.load
 // CHECK: %[[ALLOC:.+]] = triton_gpu.local_alloc

third_party/amd/lib/TritonAMDGPUTransforms/ReorderInstructions.cpp

Lines changed: 38 additions & 12 deletions
@@ -61,6 +61,14 @@ findEarlyInsertionPoint(Block *block, Operation *move) {
   return ipnt;
 }
 
+// Check if the operation opInsideLoop is inside any scf::ForOp and
+// opOutsideLoop is not inside the same loop.
+bool isCrossLoopBoundary(mlir::Operation *opInsideLoop,
+                         mlir::Operation *opOutsideLoop) {
+  scf::ForOp parentForOp = opInsideLoop->getParentOfType<scf::ForOp>();
+  return parentForOp && !parentForOp->isAncestor(opOutsideLoop);
+}
+
 class TritonAMDGPUReorderInstructionsPass
     : public TritonAMDGPUReorderInstructionsBase<
           TritonAMDGPUReorderInstructionsPass> {
@@ -101,19 +109,28 @@ class TritonAMDGPUReorderInstructionsPass
       kv.first->moveBefore(kv.second);
     opToMove.clear();
 
-    // Move writing to LDS and reading from LDS right after the loading of a
-    // tensor from global memory. There are 2 possible patterns depending on
-    // whether writing to LDS is done using an optional local_alloc argument or
-    // a local_store instruction:
+    // Adjust the placement of LDS writes and reads to immediately follow the
+    // definition of their operands in case where LDS write is in the
+    // loop but it's operand is not. This is a heuristic for optimizing fused
+    // attention by hoisting Q tensor LDS read/write operations outside of the
+    // loop, as Q is a loop invariant and can be loaded once before entering the
+    // loop.
+    // There are two possible patterns for this adjustment depending on
+    // whether the write to LDS is performed using an optional `local_alloc`
+    // argument or a `local_store` instruction.
+    //
+    // clang-format off
     //
-    // 1) %1 = load %ptr
+    // 1) %1 = some_op ... (typically a load or an operation that scales the tensor after loading)
     //    %2 = local_alloc %1
     //    %3 = local_load %2
     //
-    // 2) %1 = load %ptr
+    // 2) %1 = some_op ...
     //    %2 = local_alloc
     //    %3 = local_store %1, %2
     //    %4 = local_load %2
+    //
+    // clang-format on
     m.walk([&](ttg::LocalLoadOp localLoad) {
       auto localAlloc = localLoad.getSrc().getDefiningOp<ttg::LocalAllocOp>();
       if (!localAlloc)
@@ -123,10 +140,15 @@ class TritonAMDGPUReorderInstructionsPass
       if (localAlloc->getNumOperands() == 1) {
        if (!localAlloc->hasOneUse())
          return;
-        auto loadOp = localAlloc->getOperand(0).getDefiningOp<tt::LoadOp>();
-        if (!loadOp)
+
+        auto srcTensorOp = localAlloc->getOperand(0).getDefiningOp();
+        // Check if localAlloc is in the loop but it's src tensor defining op is
+        // outside of it.
+        if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp)) {
          return;
-        localAlloc->moveAfter(loadOp);
+        }
+
+        localAlloc->moveAfter(srcTensorOp);
        localLoad->moveAfter(localAlloc);
        return;
      }
@@ -145,10 +167,14 @@ class TritonAMDGPUReorderInstructionsPass
      if (!isa<ttg::LocalStoreOp>(localStore))
        return;
 
-      auto loadOp = localStore->getOperand(0).getDefiningOp<tt::LoadOp>();
-      if (!loadOp)
+      auto srcTensorOp = localStore->getOperand(0).getDefiningOp();
+      // Check if localStore is in the loop but it's src tensor defining op is
+      // outside of it.
+      if (!srcTensorOp || !isCrossLoopBoundary(localStore, srcTensorOp)) {
        return;
-      localAlloc->moveAfter(loadOp);
+      }
+
+      localAlloc->moveAfter(srcTensorOp);
      localStore->moveAfter(localAlloc);
      localLoad->moveAfter(localStore);
    });
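Read together, the hunks above rewrite the local_alloc path of the walk roughly as follows; this is a condensed sketch assembled from the diff for readability, not a verbatim copy of the file:

m.walk([&](ttg::LocalLoadOp localLoad) {
  auto localAlloc = localLoad.getSrc().getDefiningOp<ttg::LocalAllocOp>();
  if (!localAlloc)
    return;

  // Pattern 1: the tensor is written to LDS through the local_alloc operand.
  if (localAlloc->getNumOperands() == 1) {
    if (!localAlloc->hasOneUse())
      return;

    auto srcTensorOp = localAlloc->getOperand(0).getDefiningOp();
    // Only reorder when local_alloc is inside a loop but its source tensor
    // is defined outside of it (e.g. the loop-invariant Q tensor in FA).
    if (!srcTensorOp || !isCrossLoopBoundary(localAlloc, srcTensorOp))
      return;

    localAlloc->moveAfter(srcTensorOp);
    localLoad->moveAfter(localAlloc);
    return;
  }
  // Pattern 2 (an explicit local_store) is handled the same way in the pass,
  // moving local_alloc, local_store and local_load right after srcTensorOp.
});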

0 commit comments