Merged

36 commits
c7fe682
Improve axis analysis to handle tt.make_tensor_ptr
etiotto Oct 9, 2024
ad3888f
Merge branch 'main' into etiotto/axis_analysis_make_tensor_ptr
etiotto Oct 9, 2024
a7a9b06
Merge branch 'main' into etiotto/axis_analysis_make_tensor_ptr
etiotto Oct 10, 2024
6bddd5f
Merge branch 'main' into etiotto/axis_analysis_make_tensor_ptr
etiotto Oct 10, 2024
4ad4f1a
Merge branch 'main' into etiotto/axis_analysis_make_tensor_ptr
etiotto Oct 10, 2024
4dc1cf1
WIP: Coalescing for block ptrs
etiotto Oct 16, 2024
fa53ced
Fix pre_commit
etiotto Oct 16, 2024
049ddb8
Merge branch 'main' into etiotto/coalesce_for_block_ptr
etiotto Oct 17, 2024
041e2da
Merge branch 'main' into etiotto/coalesce_for_block_ptr
etiotto Oct 17, 2024
5a6cf81
Fix functional problem and add lit test
etiotto Oct 17, 2024
2546665
Fix pre_commit
etiotto Oct 17, 2024
4d5dc49
Reenable rewrite tensor ptr
etiotto Oct 17, 2024
c3fdbba
Fix test_core regression
etiotto Oct 18, 2024
d9de8e7
Fix tutorial assertion
etiotto Oct 18, 2024
949256e
Refactor
etiotto Oct 18, 2024
754ec70
Cleanup
etiotto Oct 18, 2024
469407b
Cleanup
etiotto Oct 18, 2024
9f4f98d
Extend axis info analysis to more block ptrs
etiotto Oct 21, 2024
a40844b
Merge branch 'main' into etiotto/coalesce_for_block_ptr
etiotto Oct 21, 2024
bb9b4c3
Address code review comments
etiotto Oct 22, 2024
8d9a158
Remove unrelated change
etiotto Oct 22, 2024
6529f04
Remove unrelated change
etiotto Oct 22, 2024
0aa334b
Remove unrelated change
etiotto Oct 22, 2024
547d6fa
Fix pre_commit
etiotto Oct 22, 2024
6566f6c
Merge branch 'main' into etiotto/coalesce_for_block_ptr
etiotto Oct 23, 2024
2f97c1a
Address code review comments
etiotto Oct 23, 2024
95f5832
Fix pre_commit
etiotto Oct 23, 2024
0887245
Merge branch 'main' into etiottoremove_layout_conv
etiotto Oct 24, 2024
3636bef
Make isExpensiveLoadOrStore consider blocked pointers load and stores
etiotto Oct 24, 2024
db2193e
Make isExpensiveLoadOrStore consider blocked pointers load and stores
etiotto Oct 25, 2024
eeda8e9
Merge branch 'main' into etiottoremove_layout_conv
etiotto Oct 25, 2024
7c9a0f9
MaterializeBlockPointer fix for GEMM with 1st operand transposed
etiotto Oct 25, 2024
cbc630b
MaterializeBlockPointer fix for GEMM with 1st operand transposed
etiotto Oct 25, 2024
0215a16
Fix unit tests
etiotto Oct 28, 2024
ae3d625
Fix performance regression for gemm-preop-exp
etiotto Oct 28, 2024
22b7ec9
Reduce PR footprint
etiotto Oct 28, 2024
21 changes: 10 additions & 11 deletions test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir
@@ -47,16 +47,16 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
// CHECK: %[[VAL_40:.*]] = tt.make_tensor_ptr %{{.*}}, {{\[}}%{{.*}}, %{{.*}}], {{\[}}%{{.*}}, %{{.*}}], {{\[}}%{{.*}}, %{{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
// CHECK: %[[VAL_41:.*]]:3 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %[[VAL_36]], %{{.*}} = %[[VAL_40]]) -> (tensor<64x256xf32, #[[DPAS]]>, !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>) : i32 {
// CHECK: %[[VAL_46:.*]] = tt.load %{{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: %[[VAL_47:.*]] = tt.load %{{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: %[[VAL_46:.*]] = tt.load %{{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
Contributor Author:
Note: adding triton_intel_gpu.block_io is consistent with our optimization pipeline (in our pipeline this is done before the 2nd invocation of RemoveLayoutConversion)
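For context, a rough C++ sketch (not code from this PR; the helper name and includes are assumptions) of how a pass can attach this attribute to a load that goes through a tensor pointer:

#include "mlir/IR/BuiltinAttributes.h"
#include "triton/Dialect/Triton/IR/Dialect.h"

// Sketch only: tag a tt.load that reads through a row-major tensor pointer
// with the triton_intel_gpu.block_io hint that the CHECK lines above expect.
static void tagRowMajorBlockIO(mlir::triton::LoadOp loadOp) {
  mlir::MLIRContext *ctx = loadOp.getContext();
  loadOp->setAttr("triton_intel_gpu.block_io",
                  mlir::StringAttr::get(ctx, "row_major"));
}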

// CHECK: %[[VAL_47:.*]] = tt.load %{{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK-NOT: triton_gpu.convert_layout
// CHECK-NEXT: %[[VAL_48:.*]] = tt.dot %[[VAL_46]], %[[VAL_47]], %{{.*}}, inputPrecision = tf32 : tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<64x256xf32, #[[DPAS]]>
// CHECK: %[[VAL_49:.*]] = tt.advance %{{.*}}, {{\[}}%{{.*}}, %{{.*}}] : <tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: %[[VAL_50:.*]] = tt.advance %{{.*}}, {{\[}}%{{.*}}, %{{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: scf.yield %{{.*}}, %{{.*}}, %{{.*}} : tensor<64x256xf32, #[[DPAS]]>, !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #blocked1>>
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
%30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
%31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
%32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
@@ -130,7 +130,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
}
%24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
// CHECK-NOT: triton_gpu.convert_layout
%25 = triton_gpu.convert_layout %24 : tensor<64x256xf16, #dpas> -> tensor<64x256xf16, #blocked1>
%26 = arith.extsi %arg8 : i32 to i64
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
@@ -147,6 +146,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
// COM: Checks that DPAS encoding has been forwarded to the store op
// COM: The `tt.make_tensor_ptr` has multiple users (the storeOp + another OP)
// COM: The initial `tt.make_tensor_ptr` with non-DPAS encoding must be kept.
// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
@@ -188,8 +188,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%21 = arith.extsi %arg7 : i32 to i64
%22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #blocked1>>
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
%30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
%31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
%32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
@@ -198,11 +198,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
}
%24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
// CHECK-NOT: triton_gpu.convert_layout
%25 = triton_gpu.convert_layout %24 : tensor<64x256xf16, #dpas> -> tensor<64x256xf16, #blocked1>
%26 = arith.extsi %arg8 : i32 to i64
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[BLOCKED]]>>
%27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #blocked1>>
// CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[DPAS]]>>
tt.store %27, %25 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #blocked1>>
@@ -243,8 +242,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
%18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
%22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #blocked1>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #blocked1>>
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
%36 = triton_gpu.convert_layout %arg10 : tensor<64x256xf32, #blocked1> -> tensor<64x256xf32, #dpas>
%30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
%31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
14 changes: 6 additions & 8 deletions test/TritonIntelGPU/combine.mlir
@@ -2324,31 +2324,29 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
%cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked2>
%0 = tt.get_program_id x : i32
%1 = tt.get_program_id y : i32
// CHECK: %[[VAL_0:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>
// CHECK: %[[VAL_1:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
// CHECK: %[[VAL_0:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<256x32xbf16, {{.*}}>>
Contributor Author:
The actual layout is not important in these tests.

// CHECK: %[[VAL_1:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<32x256xbf16, {{.*}}>>
%12 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #blocked3>>
%14 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<32x256xbf16, #blocked2>>
// CHECK: %[[VAL_2:.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>) : i32 {
// CHECK: %[[VAL_2:.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, {{.*}}>>, !tt.ptr<tensor<32x256xbf16, {{.*}}>>) : i32 {
%15:3 = scf.for %arg3 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg4 = %cst_1, %arg5 = %12, %arg6 = %14) -> (tensor<256x256xf32, #blocked2>, !tt.ptr<tensor<256x32xbf16, #blocked3>>, !tt.ptr<tensor<32x256xbf16, #blocked2>>) : i32 {
%47 = tt.load %arg5 : !tt.ptr<tensor<256x32xbf16, #blocked3>>
%48 = tt.load %arg6 : !tt.ptr<tensor<32x256xbf16, #blocked2>>
// CHEKC-NOT: triton_gpu.convert_layout
%49 = triton_gpu.convert_layout %arg4 : tensor<256x256xf32, #blocked2> -> tensor<256x256xf32, #mma>
%50 = triton_gpu.convert_layout %47 : tensor<256x32xbf16, #blocked3> -> tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
%51 = triton_gpu.convert_layout %48 : tensor<32x256xbf16, #blocked2> -> tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
%52 = tt.dot %50, %51, %49, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
%53 = triton_gpu.convert_layout %52 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked2>
// CHECK: %[[VAL_3:.*]] = tt.advance {{.*}} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>
// CHECK: %[[VAL_4:.*]] = tt.advance {{.*}} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
// CHECK: scf.yield {{.*}} : tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
// CHECK: %[[VAL_3:.*]] = tt.advance {{.*}} : <tensor<256x32xbf16, {{.*}}>>
// CHECK: %[[VAL_4:.*]] = tt.advance {{.*}} : <tensor<32x256xbf16, {{.*}}>>
// CHECK: scf.yield {{.*}} : tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, {{.*}}>>, !tt.ptr<tensor<32x256xbf16, {{.*}}>>
%54 = tt.advance %arg5, [%c0_i32, %c128_i32] : <tensor<256x32xbf16, #blocked3>>
%55 = tt.advance %arg6, [%c128_i32, %c0_i32] : <tensor<32x256xbf16, #blocked2>>
scf.yield %53, %54, %55 : tensor<256x256xf32, #blocked2>, !tt.ptr<tensor<256x32xbf16, #blocked3>>, !tt.ptr<tensor<32x256xbf16, #blocked2>>
}
%16 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
%32 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x256x!tt.ptr<f32>, #blocked2>
%38 = arith.cmpi slt, %16, %cst : tensor<256xi32, #blocked>
// CHEKC-NOT: triton_gpu.convert_layout
%39 = triton_gpu.convert_layout %38 : tensor<256xi1, #blocked> -> tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>
%40 = tt.expand_dims %39 {axis = 0 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> -> tensor<1x256xi1, #blocked4>
%41 = triton_gpu.convert_layout %40 : tensor<1x256xi1, #blocked4> -> tensor<1x256xi1, #blocked2>
@@ -4,14 +4,17 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Visitors.h"
#include "triton/Analysis/Utility.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/Debug.h"
#include <optional>

#define DEBUG_TYPE "tritonintelgpu-materialize-block-pointer"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

using namespace mlir;
namespace tt = mlir::triton;
namespace ttg = mlir::triton::gpu;
namespace ttgi = mlir::triton::gpu::intel;

namespace mlir::triton::gpu::intel {
@@ -37,7 +40,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass
return;

MLIRContext *context = &getContext();
mod.walk([context](tt::LoadOp loadOp) {
mod.walk([context, this](tt::LoadOp loadOp) {
LDBG("Considering op: " << loadOp);

Value ptr = loadOp.getPtr();
@@ -51,7 +54,6 @@ struct TritonIntelGPUMaterializeBlockPointerPass
LDBG("Found make tensor ptr op: " << makeTensorPtrOp);
auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
auto dotLayout = ttgi::getDotEncoding(tensorType);

Operation::operand_range shape = makeTensorPtrOp.getShape();
unsigned rank = shape.size();
@@ -100,11 +102,13 @@
return;

const bool isRowMajor = fastChangeDim == rank - 1;
std::optional<ttg::DotOperandEncodingAttr> dotLayout =
getDotLayout(loadOp);
if (dotLayout) {
// Check if the load is being used in a dot layout, and if so is this
// the first op and is it a transposed row major matrix. If so, skip
// the block ptr attribute as performance is worse than if we remove
// the tensor pointer
// Check if the load is being used by a tt.dot operation, and if so is
// this the first operand and is it a transposed row major matrix. If
// so, skip the block ptr attribute as performance is worse than if we
// remove the tensor pointer.
LDBG("dotLayout: " << *dotLayout);
const unsigned opIdx = dotLayout->getOpIdx();
auto dotOrder = dotLayout->getThreadOrder();
@@ -122,6 +126,52 @@
}
});
}

private:
// Return the load layout if it is a dot layout. If it is not, check if the
// load result is converted to a dot layout. If so, return the dot layout,
// otherwise return nullopt.
std::optional<ttg::DotOperandEncodingAttr>
getDotLayout(tt::LoadOp loadOp) const {
Value ptr = loadOp.getPtr();
if (!tt::isTensorPointerType(ptr.getType()))
return std::nullopt;

RankedTensorType tensorType = ttgi::getRankedTensorType(ptr.getType());
if (!tensorType)
return std::nullopt;

auto dotLayout = ttgi::getDotEncoding(tensorType);
if (dotLayout)
return dotLayout;

auto allUsersAreConvertOps = [](Operation::user_range users) {
return llvm::all_of(users, [](Operation *user) {
return isa<ttg::ConvertLayoutOp>(user);
});
};

auto allUserHaveIdenticalLayout = [](Operation::user_range users) {
Attribute firstUserLayout =
cast<ttg::ConvertLayoutOp>(*users.begin()).getType().getEncoding();
return llvm::all_of(users, [&firstUserLayout](Operation *user) {
return firstUserLayout ==
cast<ttg::ConvertLayoutOp>(user).getType().getEncoding();
});
};

Operation::user_range users = loadOp->getUsers();
if (!users.empty() && allUsersAreConvertOps(users) &&
allUserHaveIdenticalLayout(users)) {
Attribute firstUserLayout =
cast<ttg::ConvertLayoutOp>(*users.begin()).getType().getEncoding();
if (isa<ttg::DotOperandEncodingAttr>(firstUserLayout))
return dyn_cast<ttg::DotOperandEncodingAttr>(firstUserLayout);
return std::nullopt;
}

return std::nullopt;
}
};

} // anonymous namespace
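As a standalone restatement of the skip-heuristic described in the comments above, a minimal sketch (the function and parameter names are illustrative, not part of this PR):

// Illustrative only: a load feeding the first operand (opIdx == 0) of a
// tt.dot that accesses a transposed row-major matrix is not tagged with
// triton_intel_gpu.block_io, because performance is worse than simply
// removing the tensor pointer; every other candidate load keeps the tag.
static bool shouldTagBlockIO(bool isRowMajor, bool usedAsDotOperand,
                             unsigned opIdx, bool isTransposed) {
  return !(usedAsDotOperand && opIdx == 0 && isRowMajor && isTransposed);
}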