[optimize-dot-operands]: Fuse load and trans operations - part 3 #4537
Merged
Changes shown from 20 of the 24 commits.
Commits:
0ca51af  [WIP]: Fuse load and trans operations (etiotto)
cd406fb  Merge remote-tracking branch 'origin/main' into etiotto.merge_load_wi… (etiotto)
07599e3  Limit candidates to operations with no associated region. (etiotto)
b1a2c1f  Allow candidates in for loop (etiotto)
5eafc6b  Fix precommit (etiotto)
0422ad6  Merge branch 'main' into etiotto.merge_load_with_trans.2 (etiotto)
5181bb3  Better traces (etiotto)
2329dd7  Allow fusing load+trans when load ptr is loop carried (etiotto)
475eef7  Fix failing tutorial 09 (etiotto)
a2fa44c  Merge remote-tracking branch 'origin/main' into etiotto.merge_load_wi… (etiotto)
617dc0d  Allow trans user to be any operation as long as def-use chain end is … (etiotto)
dd8979d  Address code review comments (etiotto)
d3cb92b  Address code review comments (etiotto)
c1a6949  Address code review comments (etiotto)
e344c13  Allow trans user to be any operation as long as def-use chain end is … (etiotto)
e7d0d74  Merge remote-tracking branch 'origin/main' into etiotto.merge_load_wi… (etiotto)
8e6ee3e  Merge remote-tracking branch 'origin/main' into etiotto.merge_load_wi… (etiotto)
53b26ec  Fix precommit (etiotto)
23a1ef5  Enable tutorial 06 with tt.trans when data type is not fp8 (etiotto)
ff7a5a8  Merge remote-tracking branch 'origin/main' into etiotto.merge_load_wi… (etiotto)
79ee3e6  Merge remote-tracking branch 'origin/main' into etiotto.merge_load_wi… (etiotto)
22db300  Address code review comments (etiotto)
0745f28  Address code review comments (etiotto)
0d68f06  Address code review comments (etiotto)
@@ -70,8 +70,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
  // COM: tt.load -> tt.trans -> tt.dot chain, in a loop.
  // COM: where the 'make_tensor_ptr' result is loop carried.
  tt.func public @fuseLoadWithTrans3(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) {
    %c4_i32 = arith.constant 4 : i32
  tt.func public @fuseLoadWithTrans3(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
    %c1024_i32 = arith.constant 1024 : i32
    %c0_i32 = arith.constant 0 : i32
    %c32_i32 = arith.constant 32 : i32
@@ -80,15 +79,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
    %c1_i64 = arith.constant 1 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
    %0 = tt.get_program_id x : i32
    [Review comment: Just making the test simpler here]
    %1 = arith.divsi %0, %c16_i32 : i32
    %2 = arith.muli %1, %c4_i32 : i32
    %3 = arith.subi %c4_i32, %2 : i32
    %4 = arith.minsi %3, %c4_i32 : i32
    %5 = arith.remsi %0, %c16_i32 : i32
    %6 = arith.remsi %5, %4 : i32
    %7 = arith.addi %2, %6 : i32
    %8 = arith.divsi %5, %4 : i32
    %9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
    %10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #linear>>
    %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>) : i32 {
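Taken together, the commit messages and the COM comments above describe the legality walk this pass performs: starting from a tt.trans that consumes a tt.load, the def-use chain of the transpose result may pass through other operations as long as it ends at a tt.dot. Below is a minimal, hedged sketch of such a walk in C++, using only generic MLIR APIs and op-name string checks; the helper name is hypothetical and this is not the actual pass code in this repository.

#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Hypothetical helper (illustrative only): starting from a tt.trans whose
// operand is produced by a tt.load, follow the single-use def-use chain of
// the transpose result and require that it terminates at a tt.dot.
static bool chainEndsAtDot(Operation *transOp) {
  if (transOp->getName().getStringRef() != "tt.trans")
    return false;

  Operation *def = transOp->getOperand(0).getDefiningOp();
  if (!def || def->getName().getStringRef() != "tt.load")
    return false; // the transpose must consume a load result

  Operation *cur = transOp;
  while (cur) {
    if (cur->getName().getStringRef() == "tt.dot")
      return true;                 // reached the dot: candidate chain
    if (cur->getNumResults() != 1)
      return false;                // only simple single-result ops in between
    Value res = cur->getResult(0);
    if (!res.hasOneUse())
      return false;                // value has other users: conservatively bail
    cur = *res.getUsers().begin(); // step to the single user
  }
  return false;
}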
@@ -116,13 +106,61 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {

// -----

#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [0, 32]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[0, 0], [0, 0]], block = []}>
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 1], repCluster = [2, 2], A = [16, 16], B = [16, 32], C = [16, 32]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} {
  // COM: tt.load -> tt.trans -> tt.dot chain, in 2 loops,
  // COM: where the block ptr used by the loads in the 2 loops is created by the same make_tensor_ptr operation.
  tt.func public @fuseLoadWithTrans4(%arg0: i32, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>) {
    %c32_i32 = arith.constant 32 : i32
    %c0_i32 = arith.constant 0 : i32
    %c64_i64 = arith.constant 64 : i64
    %c1_i64 = arith.constant 1 : i64
    %cst_3 = arith.constant dense<0.000000e+00> : tensor<64x32xf32, #mma>
    %7 = tt.make_tensor_ptr %arg1, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
    %9 = tt.make_tensor_ptr %arg2, [%c1_i64, %c64_i64], [%c64_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #linear>>
    %24 = tt.advance %7, [%arg0, %c0_i32] : <tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
    %25 = tt.load %24 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
    %29:1 = scf.for %arg9 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg13 = %arg0) -> (i32) : i32 {
      %adv1 = tt.advance %9, [%arg13, %c0_i32] : <tensor<32x64xf16, #linear>>
      %load1 = tt.load %adv1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #linear>>
      %trans1 = tt.trans %load1 {order = array<i32: 1, 0>} : tensor<32x64xf16, #linear> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %dot1 = tt.dot %25, %trans1, %cst_3, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
      %76 = arith.addi %arg13, %c32_i32 : i32
      scf.yield %76 : i32
    }
    %38:1 = scf.for %arg9 = %c0_i32 to %arg0 step %c32_i32 iter_args(%arg13 = %arg0) -> (i32) : i32 {
      %adv2 = tt.advance %9, [%arg13, %c0_i32] : <tensor<32x64xf16, #linear>>
      %load2 = tt.load %adv2 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #linear>>
      %trans2 = tt.trans %load2 {order = array<i32: 1, 0>} : tensor<32x64xf16, #linear> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
      %dot2 = tt.dot %25, %trans2, %cst_3, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
      %81 = arith.addi %arg13, %c32_i32 : i32
      scf.yield %81 : i32
    }
    tt.return
  }
  // CHECK-LABEL: fuseLoadWithTrans4
  // CHECK-NOT: tt.trans
  // CHECK-COUNT-2: tt.make_tensor_ptr %arg2, [%c64_i64, %c1_i64], [%c1_i64, %c64_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
  // CHECK: scf.for {{.*}}
  // CHECK: [[ADV1:%.*]] = tt.advance {{.*}}, {{.*}} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
  // CHECK: [[LOAD_B1:%.*]] = tt.load [[ADV1]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
  // CHECK: tt.dot {{.*}}, [[LOAD_B1]], {{.*}}, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
  // CHECK: scf.yield {{.*}}
  // CHECK: scf.for {{.*}}
  // CHECK: [[ADV2:%.*]] = tt.advance {{.*}}, {{.*}} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
  // CHECK: [[LOAD_B2:%.*]] = tt.load [[ADV2]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "column_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>>
  // CHECK: tt.dot {{.*}}, [[LOAD_B2]], {{.*}}, inputPrecision = tf32 : tensor<64x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>> * tensor<64x32xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<64x32xf32, #mma>
  // CHECK: scf.yield {{.*}}
}
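The CHECK-COUNT-2 line above also spells out what the fused pointer looks like: %9 = tt.make_tensor_ptr %arg2 with shape [%c1_i64, %c64_i64] and strides [%c64_i64, %c1_i64] reappears with its two dimensions swapped, and the fused tt.load is tagged ttig.block_io = "column_major" instead of "row_major". A small hedged sketch of that dimension swap in C++; the struct and helper are purely illustrative and not part of the pass:

#include <array>
#include <cstdint>

// Illustrative description of a 2-D block pointer, mirroring the operand
// lists of tt.make_tensor_ptr in the test above (names are hypothetical).
struct BlockPtrDesc {
  std::array<int64_t, 2> shape;   // base tensor shape
  std::array<int64_t, 2> strides; // strides of the two dimensions
  std::array<int32_t, 2> offsets; // starting offsets
};

// Build the transposed block pointer the fused load reads through: swap the
// two dimensions, as the CHECK lines expect ([1, 64]/[64, 1] becomes
// [64, 1]/[1, 64] for %arg2 in fuseLoadWithTrans4).
BlockPtrDesc transposeBlockPtr(const BlockPtrDesc &p) {
  return {{p.shape[1], p.shape[0]},
          {p.strides[1], p.strides[0]},
          {p.offsets[1], p.offsets[0]}};
}

In the two-loop test this swap is materialized once per loop, which is why CHECK-COUNT-2 expects two transposed tt.make_tensor_ptr operations even though both loops originally shared the single %9.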
// -----

#linear = #ttg.linear<{register = [[0, 1], [0, 2], [0, 4], [0, 8], [16, 0], [0, 16], [128, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0]], warp = [[32, 0], [64, 0], [0, 0], [0, 0], [0, 0]], block = []}>
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
  // COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load
  // COM: that 'feeds' the transpose operation is used.
  tt.func public @doNotFuseLoadWithTrans1(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) {
    %c4_i32 = arith.constant 4 : i32
  // COM: Ensure load is not fused with transpose if the loop result corresponding to the pointer used by the load that 'feeds' the transpose operation is used.
  tt.func public @doNotFuseLoadWithTrans1(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
    %c1024_i32 = arith.constant 1024 : i32
    %c0_i32 = arith.constant 0 : i32
    %c32_i32 = arith.constant 32 : i32
@@ -131,15 +169,6 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
    %c1_i64 = arith.constant 1 : i64
    %c1024_i64 = arith.constant 1024 : i64
    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #mma>
    %0 = tt.get_program_id x : i32
    %1 = arith.divsi %0, %c16_i32 : i32
    %2 = arith.muli %1, %c4_i32 : i32
    %3 = arith.subi %c4_i32, %2 : i32
    %4 = arith.minsi %3, %c4_i32 : i32
    %5 = arith.remsi %0, %c16_i32 : i32
    %6 = arith.remsi %5, %4 : i32
    %7 = arith.addi %2, %6 : i32
    %8 = arith.divsi %5, %4 : i32
    %9 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c1024_i64], [%c1024_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 1}>>>
    %10 = tt.make_tensor_ptr %arg1, [%c1024_i64, %c1_i64], [%c1_i64, %c1024_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #linear>>
    %13:3 = scf.for %arg3 = %c0_i32 to %c1024_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %c0_i32, %arg6 = %10) -> (tensor<256x256xf32, #mma>, i32, !tt.ptr<tensor<256x32xbf16, #linear>>) : i32 {
@@ -166,7 +195,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 32 : i32} {
  // COM: Ensure load is not fused with transpose if there are multiple users of an operation in the def-use chain containing the load + transpose.
  // COM: In this case `%19` is used by the load that feeds the transpose and by a store operation.
  tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>, %arg2: !tt.ptr<f32>) {
  tt.func public @doNotFuseLoadWithTrans2(%arg0: !tt.ptr<bf16>, %arg1: !tt.ptr<bf16>) {
    %c4_i32 = arith.constant 4 : i32
    %c1024_i32 = arith.constant 1024 : i32
    %c0_i32 = arith.constant 0 : i32
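doNotFuseLoadWithTrans2 captures the complementary restriction: every value between the tt.make_tensor_ptr and the load feeding the transpose must have exactly one user; in the test, `%19` also feeds a store, so fusion must be skipped. A hedged backward-walk sketch in C++ (illustrative helper name, generic op-name checks, and it deliberately ignores the loop-carried case shown earlier):

#include "mlir/IR/Operation.h"
#include "mlir/IR/Value.h"

using namespace mlir;

// Hypothetical check: walk from the load feeding the transpose back to the
// tt.make_tensor_ptr that defines its block pointer, requiring every value
// on the way to have exactly one user. A second user (e.g. a store, as in
// the test above) makes the chain non-fusible.
static bool ptrChainHasSingleUsers(Operation *loadOp) {
  Value ptr = loadOp->getOperand(0);
  while (Operation *def = ptr.getDefiningOp()) {
    if (!def->getResult(0).hasOneUse())
      return false; // value is shared with another op: do not fuse
    if (def->getName().getStringRef() == "tt.make_tensor_ptr")
      return true;  // reached the pointer definition
    ptr = def->getOperand(0); // e.g. step back through tt.advance
  }
  return false; // hit a block argument (loop-carried case handled separately)
}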
Review comment: For fp16 we undo the source code changes we made, and the code is now back to the original. For fp8 we keep the source code changes until we can issue DPAS instructions for them (after making 2 fp8 elements into an fp16).
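One reading of the parenthetical, hedged and purely illustrative: DPAS here consumes two 16-bit elements per 32-bit channel (opsPerChan = 2 in the #mma attributes), so fp8 operands would first have to be widened to fp16. Assuming the fp8 format is E5M2, which is an fp16 truncated to its top byte, widening a packed pair could look like the sketch below; none of these names come from the repository.

#include <cstdint>

// Illustrative only: widen one E5M2 fp8 value to fp16 bits. E5M2 keeps fp16's
// sign, 5 exponent bits and the top 2 mantissa bits, so widening is a shift
// into the high byte. (E4M3 would need a real conversion instead.)
static inline uint16_t e5m2_to_fp16_bits(uint8_t e5m2) {
  return static_cast<uint16_t>(e5m2) << 8;
}

// Widen a packed pair of E5M2 values (low byte first) into two fp16 values,
// e.g. to form the two-elements-per-channel operand a DPAS instruction wants.
static inline void widen_e5m2_pair(uint16_t packed, uint16_t out[2]) {
  out[0] = e5m2_to_fp16_bits(static_cast<uint8_t>(packed & 0xFF));
  out[1] = e5m2_to_fp16_bits(static_cast<uint8_t>(packed >> 8));
}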