Skip to content

Commit 81946f9

Browse files
authored
[optimize-dot-operands]: Fuse load and trans operations - part 3 (#4537)
Enhance the transformation to allow multiple `load+transpose` fusion opportunities in separate for loops when the def-use chains corresponding to the opportunities originate at the same `make_tensor_ptr` operation. --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent ef44941 commit 81946f9

File tree

3 files changed

+364
-74
lines changed

3 files changed

+364
-74
lines changed

python/tutorials/06-fused-attention.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ def is_blackwell():
4040
return is_cuda() and torch.cuda.get_device_capability()[0] == 10
4141

4242

43-
# FIXME: Revert temporary source code modification done in last commit of PR #4399.
43+
# FIXME: Revert temporary source code modification (only for fp8) done in last commit of PR #4399.
44+
# Note: Triton will fuse load+trans operations; when the data type is fp8, 2D block reads aren't generated
45+
# yet because DPAS doesn't natively support fp8. We have to enhance that part of the code generation
46+
# in order to remove the remaining source code changes.
4447

4548

4649
@triton.jit
@@ -68,7 +71,10 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
6871
for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
6972
start_n = tl.multiple_of(start_n, BLOCK_N)
7073
# -- compute qk ----
71-
k = desc_k.load([0, offsetk_y])
74+
if dtype == tl.float8e5:
75+
k = desc_k.load([0, offsetk_y])
76+
else:
77+
k = desc_k.load([offsetk_y, 0]).T
7278
qk = tl.dot(q, k)
7379
if STAGE == 2:
7480
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
@@ -192,8 +198,12 @@ def _attn_fwd(sm_scale, M, #
192198
else:
193199
desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
194200
block_shape=[BLOCK_N, HEAD_DIM])
195-
desc_k = _maybe_make_tensor_desc(desc_k, shape=[HEAD_DIM, y_dim], strides=[1, HEAD_DIM],
196-
block_shape=[HEAD_DIM, BLOCK_N])
201+
if FP8_OUTPUT:
202+
desc_k = _maybe_make_tensor_desc(desc_k, shape=[HEAD_DIM, y_dim], strides=[1, HEAD_DIM],
203+
block_shape=[HEAD_DIM, BLOCK_N])
204+
else:
205+
desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
206+
block_shape=[BLOCK_N, HEAD_DIM])
197207
desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
198208
block_shape=[BLOCK_M, HEAD_DIM])
199209

test/TritonIntelGPU/dot-operands.mlir

Lines changed: 54 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
7070
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
7171
// COM: tt.load -> tt.trans -> tt.dot chain, in a loop.
7272
// COM: where the 'make_tensor_ptr' result is loop carried.
73-
tt.func public @fuseLoadWithTrans3(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) {
74-
%c4_i32 = arith.constant 4 : i32
73+
tt.func public @fuseLoadWithTrans3(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
7574
%c1024_i32 = arith.constant 1024 : i32
7675
%c0_i32 = arith.constant 0 : i32
7776
%c32_i32 = arith.constant 32 : i32
@@ -80,15 +79,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
8079
%c1_i64 = arith.constant 1 : i64
8180
%c1024_i64 = arith.constant 1024 : i64
8281
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
83-
%0 = tt.get_program_id x : i32
84-
%1 = arith.divsi %0, %c16_i32 : i32
85-
%2 = arith.muli %1, %c4_i32 : i32
86-
%3 = arith.subi %c4_i32, %2 : i32
87-
%4 = arith.minsi %3, %c4_i32 : i32
88-
%5 = arith.remsi %0, %c16_i32 : i32
89-
%6 = arith.remsi %5, %4 : i32
90-
%7 = arith.addi %2, %6 : i32
91-
%8 = arith.divsi %5, %4 : i32
9282
%9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
9383
%10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #linear>>
9484
%13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>) : i32 {
@@ -116,13 +106,61 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
116106

117107
// -----
118108

109+
#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [0, 32]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [0, 0]], block = []}>
110+
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
111+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} {
112+
// COM: tt.load -> tt.trans -> tt.dot chain, in 2 loops.
113+
// COM: Where the block ptr used by the loads in the 2 loops is created by the same make_tensor_ptr operation.
114+
tt.func public @fuseLoadWithTrans4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
115+
%c32_i32 = arith.constant 32 : i32
116+
%c0_i32 = arith.constant 0 : i32
117+
%c64_i64 = arith.constant 64 : i64
118+
%c1_i64 = arith.constant 1 : i64
119+
%cst_3 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma>
120+
%7 = tt.make_tensor_ptr %arg1, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
121+
%9 = tt.make_tensor_ptr %arg2, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #linear>>
122+
%24 = tt.advance %7, [%arg0, %c0_i32] : <tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
123+
%25 = tt.load %24 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
124+
%29:1 = scf.for %arg9 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg13 = %arg0) -> (i32) : i32 {
125+
%adv1 = tt.advance %9, [%arg13, %c0_i32] : <tensor<32x64xf16, #linear>>
126+
%load1 = tt.load %adv1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #linear>>
127+
%trans1 = tt.trans %load1 {order = array<i32: 1, 0>} : tensor<32x64xf16, #linear> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
128+
%dot1 = tt.dot %25, %trans1, %cst_3, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
129+
%76 = arith.addi %arg13, %c32_i32 : i32
130+
scf.yield %76 : i32
131+
}
132+
%38:1 = scf.for %arg9 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg13 = %arg0) -> (i32) : i32 {
133+
%adv2 = tt.advance %9, [%arg13, %c0_i32] : <tensor<32x64xf16, #linear>>
134+
%load2 = tt.load %adv2 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #linear>>
135+
%trans2 = tt.trans %load2 {order = array<i32: 1, 0>} : tensor<32x64xf16, #linear> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
136+
%dot2 = tt.dot %25, %trans2, %cst_3, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
137+
%81 = arith.addi %arg13, %c32_i32 : i32
138+
scf.yield %81 : i32
139+
}
140+
tt.return
141+
}
142+
// CHECK-LABEL: fuseLoadWithTrans4
143+
// CHECK-NOT: tt.trans
144+
// CHECK-COUNT-2: tt.make_tensor_ptr %arg2, [%c64_i64, %c1_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
145+
// CHECK: scf.for {{.*}}
146+
// CHECK: [[ADV1:%.*]] = tt.advance {{.*}}, {{.*}} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
147+
// CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV1]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
148+
// CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
149+
// CHECK: scf.yield {{.*}}
150+
// CHECK: scf.for {{.*}}
151+
// CHECK: [[ADV2:%.*]] = tt.advance {{.*}}, {{.*}} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
152+
// CHECK: [[LOAD_B2:%.*]] = tt.load [[ADV2]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
153+
// CHECK: tt.dot {{.*}}, [[LOAD_B2]], {{.*}}, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
154+
// CHECK: scf.yield {{.*}}
155+
}
156+
157+
// -----
158+
119159
#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}>
120160
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
121161
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
122-
// COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load
123-
// COM: that 'feeds' the transpose operation is used.
124-
tt.func public @doNotFuseLoadWithTrans1(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) {
125-
%c4_i32 = arith.constant 4 : i32
162+
// COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load that 'feeds' the transpose operation is used.
163+
tt.func public @doNotFuseLoadWithTrans1(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
126164
%c1024_i32 = arith.constant 1024 : i32
127165
%c0_i32 = arith.constant 0 : i32
128166
%c32_i32 = arith.constant 32 : i32
@@ -131,15 +169,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
131169
%c1_i64 = arith.constant 1 : i64
132170
%c1024_i64 = arith.constant 1024 : i64
133171
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
134-
%0 = tt.get_program_id x : i32
135-
%1 = arith.divsi %0, %c16_i32 : i32
136-
%2 = arith.muli %1, %c4_i32 : i32
137-
%3 = arith.subi %c4_i32, %2 : i32
138-
%4 = arith.minsi %3, %c4_i32 : i32
139-
%5 = arith.remsi %0, %c16_i32 : i32
140-
%6 = arith.remsi %5, %4 : i32
141-
%7 = arith.addi %2, %6 : i32
142-
%8 = arith.divsi %5, %4 : i32
143172
%9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
144173
%10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #linear>>
145174
%13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>) : i32 {
@@ -166,7 +195,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
166195
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
167196
// COM: Ensure load is not fused with transpose if there are multiple users of an operation in the def-use chain containing the load + transpose.
168197
// COM: In this case `%19` is used by the load that feeds the transpose and by a store operation.
169-
tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) {
198+
tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
170199
%c4_i32 = arith.constant 4 : i32
171200
%c1024_i32 = arith.constant 1024 : i32
172201
%c0_i32 = arith.constant 0 : i32

0 commit comments

Comments
 (0)