Commit 1dbef57
Make isExpensiveLoadOrStore consider block pointer loads and stores (#2570)
The `isExpensiveLoadOrStore` function (third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp) fails to consider block pointers and consequently always returns `false` for load (and store) operations that use a block pointer. In turn, this causes the `RemoveLayoutConversion` pass to never consider loads using block pointers as `anchor` operations. This PR changes `isExpensiveLoadOrStore` so that block pointer loads are properly recognized; the `RemoveLayoutConversion` pass is then able to consider those loads as anchor operations and preserve their layout.

Because `RemoveLayoutConversion` is invoked at several points in the optimization pipeline, the change in third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp alone causes performance degradation in a couple of GEMM-like benchmarks, specifically when operand A of `tl.dot` is transposed and when the input of `tl.dot` is first fed into an exponential. These two performance degradations have been fixed by enhancing the `MaterializeBlockPointer` and `MatmulLoopPipeline` optimizations so that they can retrieve the dot layout of block pointer loads transitively from their users (in those benchmarks the blocked layout of block pointer loads is transitively converted to a dot layout).

---------

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 55702d9 commit 1dbef57
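
For context, a minimal sketch of the kind of check the Utility.cpp fix needs (illustrative only, not the verbatim change; the size threshold and exact structure are assumptions):

// Sketch only -- not the committed Utility.cpp code. A block pointer is a
// scalar !tt.ptr<tensor<...>> produced by tt.make_tensor_ptr, so an
// "is the pointer a large tensor?" test never fires for it and the
// function falls through to returning false.
bool isExpensiveLoadOrStore(mlir::Operation *op) {
  mlir::Value ptr = op->getOperand(0);
  // Block-pointer case: judge the cost by the pointee tensor.
  if (mlir::triton::isTensorPointerType(ptr.getType())) {
    auto tensorTy = mlir::cast<mlir::RankedTensorType>(
        mlir::cast<mlir::triton::PointerType>(ptr.getType()).getPointeeType());
    return tensorTy.getNumElements() > 1; // illustrative threshold
  }
  // Pre-existing case: a tensor of pointers is expensive when it is large.
  if (auto tensorTy = mlir::dyn_cast<mlir::RankedTensorType>(ptr.getType()))
    return tensorTy.getNumElements() > 1; // illustrative threshold
  return false;
}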

File tree

5 files changed (+102 -49 lines)

test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir

Lines changed: 10 additions & 11 deletions
@@ -47,16 +47,16 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
  // CHECK: %[[VAL_40:.*]] = tt.make_tensor_ptr %{{.*}}, {{\[}}%{{.*}}, %{{.*}}], {{\[}}%{{.*}}, %{{.*}}], {{\[}}%{{.*}}, %{{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
  %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
  // CHECK: %[[VAL_41:.*]]:3 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %[[VAL_36]], %{{.*}} = %[[VAL_40]]) -> (tensor<64x256xf32, #[[DPAS]]>, !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>) : i32 {
- // CHECK: %[[VAL_46:.*]] = tt.load %{{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
- // CHECK: %[[VAL_47:.*]] = tt.load %{{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
+ // CHECK: %[[VAL_46:.*]] = tt.load %{{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
+ // CHECK: %[[VAL_47:.*]] = tt.load %{{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
  // CHECK-NOT: triton_gpu.convert_layout
  // CHECK-NEXT: %[[VAL_48:.*]] = tt.dot %[[VAL_46]], %[[VAL_47]], %{{.*}}, inputPrecision = tf32 : tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<64x256xf32, #[[DPAS]]>
  // CHECK: %[[VAL_49:.*]] = tt.advance %{{.*}}, {{\[}}%{{.*}}, %{{.*}}] : <tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
  // CHECK: %[[VAL_50:.*]] = tt.advance %{{.*}}, {{\[}}%{{.*}}, %{{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
  // CHECK: scf.yield %{{.*}}, %{{.*}}, %{{.*}} : tensor<64x256xf32, #[[DPAS]]>, !tt.ptr<tensor<64x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
  %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
- %28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
- %29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #blocked1>>
+ %28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #blocked>>
+ %29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
  %30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
  %31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
  %32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
@@ -130,7 +130,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
  scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
  }
  %24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
- // CHECK-NOT: triton_gpu.convert_layout
  %25 = triton_gpu.convert_layout %24 : tensor<64x256xf16, #dpas> -> tensor<64x256xf16, #blocked1>
  %26 = arith.extsi %arg8 : i32 to i64
  // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
@@ -147,6 +146,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
  // COM: Checks that DPAS encoding has been forwarded to the store op
  // COM: The `tt.make_tensor_ptr` has multiple users (the storeOp + another OP)
  // COM: The initial `tt.make_tensor_ptr` with non-DPAS encoding must be kept.
+ // CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
  // CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
  #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
  #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
@@ -188,8 +188,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
  %21 = arith.extsi %arg7 : i32 to i64
  %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
  %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
- %28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
- %29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #blocked1>>
+ %28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #blocked>>
+ %29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
  %30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
  %31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
  %32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
@@ -198,11 +198,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
  scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
  }
  %24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
- // CHECK-NOT: triton_gpu.convert_layout
  %25 = triton_gpu.convert_layout %24 : tensor<64x256xf16, #dpas> -> tensor<64x256xf16, #blocked1>
  %26 = arith.extsi %arg8 : i32 to i64
  // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[DPAS]]>>
- // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
+ // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[BLOCKED]]>>
  %27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #blocked1>>
  // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[DPAS]]>>
  tt.store %27, %25 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #blocked1>>
@@ -243,8 +242,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
  %18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
  %22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
  %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #blocked1>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
- %28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
- %29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #blocked1>>
+ %28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
+ %29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
  %36 = triton_gpu.convert_layout %arg10 : tensor<64x256xf32, #blocked1> -> tensor<64x256xf32, #dpas>
  %30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
  %31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>

test/TritonIntelGPU/combine.mlir

Lines changed: 6 additions & 8 deletions
@@ -2324,31 +2324,29 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
  %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked2>
  %0 = tt.get_program_id x : i32
  %1 = tt.get_program_id y : i32
- // CHECK: %[[VAL_0:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>
- // CHECK: %[[VAL_1:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+ // CHECK: %[[VAL_0:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<256x32xbf16, {{.*}}>>
+ // CHECK: %[[VAL_1:.*]] = tt.make_tensor_ptr {{.*}} : <tensor<32x256xbf16, {{.*}}>>
  %12 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<256x32xbf16, #blocked3>>
  %14 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array<i32: 1, 0>} : <tensor<32x256xbf16, #blocked2>>
- // CHECK: %[[VAL_2:.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>) : i32 {
+ // CHECK: %[[VAL_2:.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, {{.*}}>>, !tt.ptr<tensor<32x256xbf16, {{.*}}>>) : i32 {
  %15:3 = scf.for %arg3 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg4 = %cst_1, %arg5 = %12, %arg6 = %14) -> (tensor<256x256xf32, #blocked2>, !tt.ptr<tensor<256x32xbf16, #blocked3>>, !tt.ptr<tensor<32x256xbf16, #blocked2>>) : i32 {
  %47 = tt.load %arg5 : !tt.ptr<tensor<256x32xbf16, #blocked3>>
  %48 = tt.load %arg6 : !tt.ptr<tensor<32x256xbf16, #blocked2>>
- // CHEKC-NOT: triton_gpu.convert_layout
  %49 = triton_gpu.convert_layout %arg4 : tensor<256x256xf32, #blocked2> -> tensor<256x256xf32, #mma>
  %50 = triton_gpu.convert_layout %47 : tensor<256x32xbf16, #blocked3> -> tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
  %51 = triton_gpu.convert_layout %48 : tensor<32x256xbf16, #blocked2> -> tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
  %52 = tt.dot %50, %51, %49, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
  %53 = triton_gpu.convert_layout %52 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked2>
- // CHECK: %[[VAL_3:.*]] = tt.advance {{.*}} : <tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>
- // CHECK: %[[VAL_4:.*]] = tt.advance {{.*}} : <tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
- // CHECK: scf.yield {{.*}} : tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 2}>>>, !tt.ptr<tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
+ // CHECK: %[[VAL_3:.*]] = tt.advance {{.*}} : <tensor<256x32xbf16, {{.*}}>>
+ // CHECK: %[[VAL_4:.*]] = tt.advance {{.*}} : <tensor<32x256xbf16, {{.*}}>>
+ // CHECK: scf.yield {{.*}} : tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr<tensor<256x32xbf16, {{.*}}>>, !tt.ptr<tensor<32x256xbf16, {{.*}}>>
  %54 = tt.advance %arg5, [%c0_i32, %c128_i32] : <tensor<256x32xbf16, #blocked3>>
  %55 = tt.advance %arg6, [%c128_i32, %c0_i32] : <tensor<32x256xbf16, #blocked2>>
  scf.yield %53, %54, %55 : tensor<256x256xf32, #blocked2>, !tt.ptr<tensor<256x32xbf16, #blocked3>>, !tt.ptr<tensor<32x256xbf16, #blocked2>>
  }
  %16 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
  %32 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x256x!tt.ptr<f32>, #blocked2>
  %38 = arith.cmpi slt, %16, %cst : tensor<256xi32, #blocked>
- // CHEKC-NOT: triton_gpu.convert_layout
  %39 = triton_gpu.convert_layout %38 : tensor<256xi1, #blocked> -> tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>
  %40 = tt.expand_dims %39 {axis = 0 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> -> tensor<1x256xi1, #blocked4>
  %41 = triton_gpu.convert_layout %40 : tensor<1x256xi1, #blocked4> -> tensor<1x256xi1, #blocked2>

third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp

Lines changed: 56 additions & 6 deletions
@@ -4,14 +4,17 @@
  #include "mlir/Dialect/Arith/IR/Arith.h"
  #include "mlir/IR/Visitors.h"
  #include "triton/Analysis/Utility.h"
+ #include "llvm/Support/Casting.h"
  #include "llvm/Support/Debug.h"
+ #include <optional>

  #define DEBUG_TYPE "tritonintelgpu-materialize-block-pointer"
  #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
  #define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

  using namespace mlir;
  namespace tt = mlir::triton;
+ namespace ttg = mlir::triton::gpu;
  namespace ttgi = mlir::triton::gpu::intel;

  namespace mlir::triton::gpu::intel {
@@ -37,7 +40,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass
  return;

  MLIRContext *context = &getContext();
- mod.walk([context](tt::LoadOp loadOp) {
+ mod.walk([context, this](tt::LoadOp loadOp) {
  LDBG("Considering op: " << loadOp);

  Value ptr = loadOp.getPtr();
@@ -51,7 +54,6 @@ struct TritonIntelGPUMaterializeBlockPointerPass
  LDBG("Found make tensor ptr op: " << makeTensorPtrOp);
  auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
  auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
- auto dotLayout = ttgi::getDotEncoding(tensorType);

  Operation::operand_range shape = makeTensorPtrOp.getShape();
  unsigned rank = shape.size();
@@ -100,11 +102,13 @@ struct TritonIntelGPUMaterializeBlockPointerPass
  return;

  const bool isRowMajor = fastChangeDim == rank - 1;
+ std::optional<ttg::DotOperandEncodingAttr> dotLayout =
+     getDotLayout(loadOp);
  if (dotLayout) {
- // Check if the load is being used in a dot layout, and if so is this
- // the first op and is it a transposed row major matrix. If so, skip
- // the block ptr attribute as performance is worse than if we remove
- // the tensor pointer
+ // Check if the load is being used by a tt.dot operation, and if so is
+ // this the first operand and is it a transposed row major matrix. If
+ // so, skip the block ptr attribute as performance is worse than if we
+ // remove the tensor pointer.
  LDBG("dotLayout: " << *dotLayout);
  const unsigned opIdx = dotLayout->getOpIdx();
  auto dotOrder = dotLayout->getThreadOrder();
@@ -122,6 +126,52 @@ struct TritonIntelGPUMaterializeBlockPointerPass
  }
  });
  }
+
+private:
+  // Return the load layout if it is a dot layout. If it is not, check if the
+  // load result is converted to a dot layout. If so, return the dot layout,
+  // otherwise return nullopt.
+  std::optional<ttg::DotOperandEncodingAttr>
+  getDotLayout(tt::LoadOp loadOp) const {
+    Value ptr = loadOp.getPtr();
+    if (!tt::isTensorPointerType(ptr.getType()))
+      return std::nullopt;
+
+    RankedTensorType tensorType = ttgi::getRankedTensorType(ptr.getType());
+    if (!tensorType)
+      return std::nullopt;
+
+    auto dotLayout = ttgi::getDotEncoding(tensorType);
+    if (dotLayout)
+      return dotLayout;
+
+    auto allUsersAreConvertOps = [](Operation::user_range users) {
+      return llvm::all_of(users, [](Operation *user) {
+        return isa<ttg::ConvertLayoutOp>(user);
+      });
+    };
+
+    auto allUserHaveIdenticalLayout = [](Operation::user_range users) {
+      Attribute firstUserLayout =
+          cast<ttg::ConvertLayoutOp>(*users.begin()).getType().getEncoding();
+      return llvm::all_of(users, [&firstUserLayout](Operation *user) {
+        return firstUserLayout ==
+               cast<ttg::ConvertLayoutOp>(user).getType().getEncoding();
+      });
+    };
+
+    Operation::user_range users = loadOp->getUsers();
+    if (!users.empty() && allUsersAreConvertOps(users) &&
+        allUserHaveIdenticalLayout(users)) {
+      Attribute firstUserLayout =
+          cast<ttg::ConvertLayoutOp>(*users.begin()).getType().getEncoding();
+      if (isa<ttg::DotOperandEncodingAttr>(firstUserLayout))
+        return dyn_cast<ttg::DotOperandEncodingAttr>(firstUserLayout);
+      return std::nullopt;
+    }
+
+    return std::nullopt;
+  }
  };

  } // anonymous namespace
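
For orientation, this is the IR shape `getDotLayout` now handles, condensed from the first test file above: the block-pointer load itself still carries a `#blocked` encoding, and the dot-operand layout is only reachable through the load's `convert_layout` user (value names reuse the test's; other operands are elided):

%18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
// The load's own encoding is #blocked, not a dot layout ...
%28 = tt.load %18 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>
// ... but its sole user converts the result to a dot-operand encoding, which
// getDotLayout returns by walking the users.
%30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
%32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>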
