Merged

Commits (24)
0ca51af  [WIP]: Fuse load and trans operations (etiotto, Jun 4, 2025)
cd406fb  Merge remote-tracking branch 'origin/main' into etiotto.merge_load_wi… (etiotto, Jun 5, 2025)
07599e3  Limit candidates to operations with no associated region. (etiotto, Jun 5, 2025)
b1a2c1f  Allow candidates in for loop (etiotto, Jun 6, 2025)
5eafc6b  Fix precommit (etiotto, Jun 6, 2025)
0422ad6  Merge branch 'main' into etiotto.merge_load_with_trans.2 (etiotto, Jun 6, 2025)
5181bb3  Better traces (etiotto, Jun 6, 2025)
2329dd7  Allow fusing load+trans when load ptr is loop carried (etiotto, Jun 9, 2025)
475eef7  Fix failing tutorial 09 (etiotto, Jun 10, 2025)
a2fa44c  Merge remote-tracking branch 'origin/main' into etiotto.merge_load_wi… (etiotto, Jun 11, 2025)
617dc0d  Allow trans user to be any operation as long as def-use chain end is … (etiotto, Jun 12, 2025)
dd8979d  Address code review comments (etiotto, Jun 17, 2025)
d3cb92b  Address code review comments (etiotto, Jun 17, 2025)
c1a6949  Address code review comments (etiotto, Jun 18, 2025)
e344c13  Allow trans user to be any operation as long as def-use chain end is … (etiotto, Jun 19, 2025)
e7d0d74  Merge remote-tracking branch 'origin/main' into etiotto.merge_load_wi… (etiotto, Jun 20, 2025)
8e6ee3e  Merge remote-tracking branch 'origin/main' into etiotto.merge_load_wi… (etiotto, Jun 20, 2025)
53b26ec  Fix precommit (etiotto, Jun 20, 2025)
23a1ef5  Enable tutorial 06 with tt.trans when data type is not fp8 (etiotto, Jun 20, 2025)
ff7a5a8  Merge remote-tracking branch 'origin/main' into etiotto.merge_load_wi… (etiotto, Jun 23, 2025)
79ee3e6  Merge remote-tracking branch 'origin/main' into etiotto.merge_load_wi… (etiotto, Jun 23, 2025)
22db300  Address code review comments (etiotto, Jun 24, 2025)
0745f28  Address code review comments (etiotto, Jun 25, 2025)
0d68f06  Address code review comments (etiotto, Jun 25, 2025)
18 changes: 14 additions & 4 deletions python/tutorials/06-fused-attention.py
@@ -40,7 +40,10 @@ def is_blackwell():
return is_cuda() and torch.cuda.get_device_capability()[0] == 10


# FIXME: Revert temporary source code modification done in last commit of PR #4399.
# FIXME: Revert temporary source code modification (only for fp8) done in last commit of PR #4399.
# Note: Triton will fuse load+trans operations; however, when the data type is fp8, 2D block reads aren't
# generated yet because DPAS doesn't natively support fp8. We have to enhance that part of the code
# generation in order to remove the remaining source code changes.


@triton.jit
@@ -68,7 +71,10 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
for start_n in tl.range(lo, hi, BLOCK_N, warp_specialize=warp_specialize):
start_n = tl.multiple_of(start_n, BLOCK_N)
# -- compute qk ----
k = desc_k.load([0, offsetk_y])
if dtype == tl.float8e5:
Review comment (etiotto, PR author): For fp16 we undo the source code changes we made, and the code is now back to the original. For fp8 we keep the source code changes until we can issue DPAS instructions for them (after packing two fp8 elements into one fp16).
k = desc_k.load([0, offsetk_y])
else:
k = desc_k.load([offsetk_y, 0]).T
qk = tl.dot(q, k)
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
Expand Down Expand Up @@ -192,8 +198,12 @@ def _attn_fwd(sm_scale, M, #
else:
desc_v = _maybe_make_tensor_desc(desc_v, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_N, HEAD_DIM])
desc_k = _maybe_make_tensor_desc(desc_k, shape=[HEAD_DIM, y_dim], strides=[1, HEAD_DIM],
block_shape=[HEAD_DIM, BLOCK_N])
if FP8_OUTPUT:
desc_k = _maybe_make_tensor_desc(desc_k, shape=[HEAD_DIM, y_dim], strides=[1, HEAD_DIM],
block_shape=[HEAD_DIM, BLOCK_N])
else:
desc_k = _maybe_make_tensor_desc(desc_k, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_N, HEAD_DIM])
desc_o = _maybe_make_tensor_desc(desc_o, shape=[y_dim, HEAD_DIM], strides=[HEAD_DIM, 1],
block_shape=[BLOCK_M, HEAD_DIM])

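To make the tutorial change above concrete, here is a minimal sketch (not code from this PR) of the pattern the pass targets. Loading K through a row-major descriptor and transposing in-kernel produces a tt.load -> tt.trans chain that the Intel backend can now fuse into a single column-major 2D block read; for fp8 the transposed-descriptor workaround is kept because DPAS doesn't natively support fp8 yet. The names desc_k and offsetk_y follow the tutorial; the helper itself is illustrative only.

import triton
import triton.language as tl

@triton.jit
def _qk_sketch(q, desc_k, offsetk_y, dtype: tl.constexpr):
    if dtype == tl.float8e5:
        # fp8 workaround kept by this PR: the descriptor is built transposed
        # ([HEAD_DIM, y_dim], strides [1, HEAD_DIM]), so no transpose is needed.
        k = desc_k.load([0, offsetk_y])
    else:
        # fp16 path restored to the original tutorial: a row-major load followed
        # by .T; the fuse-load-with-trans pass rewrites the resulting
        # tt.load -> tt.trans chain into one column-major 2D block read.
        k = desc_k.load([offsetk_y, 0]).T
    return tl.dot(q, k)

The review comment on the fp8 branch refers to packing two fp8 elements into one fp16 lane so that the existing fp16 2D block reads can move fp8 data. A rough illustration of that byte reinterpretation (an assumption about the planned code generation, not part of this PR):

import torch

x = torch.randn(4, 8).to(torch.float8_e5m2)  # 1 byte per element
as_fp16 = x.view(torch.float16)              # same bytes, 2 per element
assert as_fp16.shape == (4, 4)               # two fp8 columns per fp16 column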
79 changes: 54 additions & 25 deletions test/TritonIntelGPU/dot-operands.mlir
@@ -70,8 +70,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
// COM: tt.load -> tt.trans -> tt.dot chain, in a loop.
// COM: where the 'make_tensor_ptr' result is loop carried.
tt.func public @fuseLoadWithTrans3(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) {
%c4_i32 = arith.constant 4 : i32
tt.func public @fuseLoadWithTrans3(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
%c1024_i32 = arith.constant 1024 : i32
%c0_i32 = arith.constant 0 : i32
%c32_i32 = arith.constant 32 : i32
@@ -80,15 +79,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
%c1_i64 = arith.constant 1 : i64
%c1024_i64 = arith.constant 1024 : i64
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
%0 = tt.get_program_id x : i32
Review comment (etiotto, PR author): Just making the test simpler here.
%1 = arith.divsi %0, %c16_i32 : i32
%2 = arith.muli %1, %c4_i32 : i32
%3 = arith.subi %c4_i32, %2 : i32
%4 = arith.minsi %3, %c4_i32 : i32
%5 = arith.remsi %0, %c16_i32 : i32
%6 = arith.remsi %5, %4 : i32
%7 = arith.addi %2, %6 : i32
%8 = arith.divsi %5, %4 : i32
%9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
%10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #linear>>
%13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>) : i32 {
@@ -116,13 +106,61 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {

// -----

#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [0, 32]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [0, 0]], block = []}>
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} {
// COM: tt.load -> tt.trans -> tt.dot chain, in 2 loops.
// COM: Where the block ptr used by the loads in the 2 loops is created by the same make_tensor_ptr operation.
tt.func public @fuseLoadWithTrans4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
%c32_i32 = arith.constant 32 : i32
%c0_i32 = arith.constant 0 : i32
%c64_i64 = arith.constant 64 : i64
%c1_i64 = arith.constant 1 : i64
%cst_3 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma>
%7 = tt.make_tensor_ptr %arg1, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
%9 = tt.make_tensor_ptr %arg2, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #linear>>
%24 = tt.advance %7, [%arg0, %c0_i32] : <tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
%25 = tt.load %24 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
%29:1 = scf.for %arg9 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg13 = %arg0) -> (i32) : i32 {
%adv1 = tt.advance %9, [%arg13, %c0_i32] : <tensor<32x64xf16, #linear>>
%load1 = tt.load %adv1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #linear>>
%trans1 = tt.trans %load1 {order = array<i32: 1, 0>} : tensor<32x64xf16, #linear> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%dot1 = tt.dot %25, %trans1, %cst_3, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
%76 = arith.addi %arg13, %c32_i32 : i32
scf.yield %76 : i32
}
%38:1 = scf.for %arg9 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg13 = %arg0) -> (i32) : i32 {
%adv2 = tt.advance %9, [%arg13, %c0_i32] : <tensor<32x64xf16, #linear>>
%load2 = tt.load %adv2 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #linear>>
%trans2 = tt.trans %load2 {order = array<i32: 1, 0>} : tensor<32x64xf16, #linear> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%dot2 = tt.dot %25, %trans2, %cst_3, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
%81 = arith.addi %arg13, %c32_i32 : i32
scf.yield %81 : i32
}
tt.return
}
// CHECK-LABEL: fuseLoadWithTrans4
// CHECK-NOT: tt.trans
// CHECK-COUNT-2: tt.make_tensor_ptr %arg2, [%c64_i64, %c1_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
// CHECK: scf.for {{.*}}
// CHECK: [[ADV1:%.*]] = tt.advance {{.*}}, {{.*}} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
// CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV1]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
// CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
// CHECK: scf.yield {{.*}}
// CHECK: scf.for {{.*}}
// CHECK: [[ADV2:%.*]] = tt.advance {{.*}}, {{.*}} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
// CHECK: [[LOAD_B2:%.*]] = tt.load [[ADV2]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
// CHECK: tt.dot {{.*}}, [[LOAD_B2]], {{.*}}, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
// CHECK: scf.yield {{.*}}
}

// -----

#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}>
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
// COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load
// COM: that 'feeds' the transpose operation is used.
tt.func public @doNotFuseLoadWithTrans1(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) {
%c4_i32 = arith.constant 4 : i32
// COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load that 'feeds' the transpose operation is used.
tt.func public @doNotFuseLoadWithTrans1(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
%c1024_i32 = arith.constant 1024 : i32
%c0_i32 = arith.constant 0 : i32
%c32_i32 = arith.constant 32 : i32
@@ -131,15 +169,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
%c1_i64 = arith.constant 1 : i64
%c1024_i64 = arith.constant 1024 : i64
%cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
%0 = tt.get_program_id x : i32
%1 = arith.divsi %0, %c16_i32 : i32
%2 = arith.muli %1, %c4_i32 : i32
%3 = arith.subi %c4_i32, %2 : i32
%4 = arith.minsi %3, %c4_i32 : i32
%5 = arith.remsi %0, %c16_i32 : i32
%6 = arith.remsi %5, %4 : i32
%7 = arith.addi %2, %6 : i32
%8 = arith.divsi %5, %4 : i32
%9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
%10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #linear>>
%13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>) : i32 {
Expand All @@ -166,7 +195,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
// COM: Ensure load is not fused with transpose if there are multiple users of an operation in the def-use chain containing the load + transpose.
// COM: In this case `%19` is used by the load that feeds the transpose and by a store operation.
tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) {
tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
%c4_i32 = arith.constant 4 : i32
%c1024_i32 = arith.constant 1024 : i32
%c0_i32 = arith.constant 0 : i32
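Taken together, these tests pin down when the fusion is legal. Below is a rough Triton-level sketch of those conditions (a hypothetical kernel written for illustration; the MLIR tests above are the authoritative statement):

import triton
import triton.language as tl

@triton.jit
def _fusion_conditions_sketch(a_ptr, b_ptr, M: tl.constexpr, K: tl.constexpr, BLOCK: tl.constexpr):
    a_bp = tl.make_block_ptr(a_ptr, (M, K), (K, 1), (0, 0), (BLOCK, BLOCK), (1, 0))
    b_bp = tl.make_block_ptr(b_ptr, (K, M), (M, 1), (0, 0), (BLOCK, BLOCK), (1, 0))
    acc = tl.zeros((BLOCK, BLOCK), dtype=tl.float32)
    for k in range(0, K, BLOCK):
        a = tl.load(a_bp, boundary_check=(0, 1))
        b = tl.load(b_bp, boundary_check=(0, 1))
        # A tt.load -> tt.trans chain that only reaches tt.dot is fusable, even
        # when the block pointer is loop carried (fuseLoadWithTrans3) or feeds
        # loads in two loops (fuseLoadWithTrans4).
        acc += tl.dot(a, tl.trans(b))
        a_bp = tl.advance(a_bp, (0, BLOCK))
        b_bp = tl.advance(b_bp, (0, BLOCK))
    # Not fusable: if the loop yielded the final value of b_bp and that result
    # were used after the loop (doNotFuseLoadWithTrans1), or if an op on the
    # def-use chain had a second user such as a store (doNotFuseLoadWithTrans2),
    # the pass must leave the tt.trans in place.
    return acc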