diff --git a/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir b/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir
index d6f43af96d..cf4e3ad8de 100644
--- a/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir
+++ b/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir
@@ -47,16 +47,16 @@ module attributes {"triton_gpu.num-warps" = 4 : i32, "triton_gpu.threads-per-war
     // CHECK: %[[VAL_40:.*]] = tt.make_tensor_ptr %{{.*}}, {{\[}}%{{.*}}, %{{.*}}], {{\[}}%{{.*}}, %{{.*}}], {{\[}}%{{.*}}, %{{.*}}] {order = array} : >>
     %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : >
     // CHECK: %[[VAL_41:.*]]:3 = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %{{.*}}, %{{.*}} = %[[VAL_36]], %{{.*}} = %[[VAL_40]]) -> (tensor<64x256xf32, #[[DPAS]]>, !tt.ptr>>, !tt.ptr>>) : i32 {
-    // CHECK: %[[VAL_46:.*]] = tt.load %{{.*}} {boundaryCheck = array} : !tt.ptr>>
-    // CHECK: %[[VAL_47:.*]] = tt.load %{{.*}} {boundaryCheck = array} : !tt.ptr>>
+    // CHECK: %[[VAL_46:.*]] = tt.load %{{.*}} {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>>
+    // CHECK: %[[VAL_47:.*]] = tt.load %{{.*}} {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>>
     // CHECK-NOT: triton_gpu.convert_layout
     // CHECK-NEXT: %[[VAL_48:.*]] = tt.dot %[[VAL_46]], %[[VAL_47]], %{{.*}}, inputPrecision = tf32 : tensor<64x32xf16, #{{.*}}<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<64x256xf32, #[[DPAS]]>
     // CHECK: %[[VAL_49:.*]] = tt.advance %{{.*}}, {{\[}}%{{.*}}, %{{.*}}] : >>
     // CHECK: %[[VAL_50:.*]] = tt.advance %{{.*}}, {{\[}}%{{.*}}, %{{.*}}] : >>
     // CHECK: scf.yield %{{.*}}, %{{.*}}, %{{.*}} : tensor<64x256xf32, #[[DPAS]]>, !tt.ptr>>, !tt.ptr>>
     %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr>, !tt.ptr>) : i32 {
-      %28 = tt.load %arg11 {boundaryCheck = array} : !tt.ptr>
-      %29 = tt.load %arg12 {boundaryCheck = array} : !tt.ptr>
+      %28 = tt.load %arg11 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
+      %29 = tt.load %arg12 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
       %30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
       %31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
       %32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
@@ -130,7 +130,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
       scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr>, !tt.ptr>
     }
     %24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
-    // CHECK-NOT: triton_gpu.convert_layout
     %25 = triton_gpu.convert_layout %24 : tensor<64x256xf16, #dpas> -> tensor<64x256xf16, #blocked1>
     %26 = arith.extsi %arg8 : i32 to i64
     // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >
@@ -147,6 +146,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // COM: Checks that DPAS encoding has been forwarded to the store op
 // COM: The `tt.make_tensor_ptr` has multiple users (the storeOp + another OP)
 // COM: The initial `tt.make_tensor_ptr` with non-DPAS encoding must be kept.
+// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
 // CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
 #blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
@@ -188,8 +188,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     %21 = arith.extsi %arg7 : i32 to i64
     %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : >
     %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr>, !tt.ptr>) : i32 {
-      %28 = tt.load %arg11 {boundaryCheck = array} : !tt.ptr>
-      %29 = tt.load %arg12 {boundaryCheck = array} : !tt.ptr>
+      %28 = tt.load %arg11 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
+      %29 = tt.load %arg12 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
       %30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
       %31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
       %32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
@@ -198,11 +198,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
       scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr>, !tt.ptr>
     }
     %24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
-    // CHECK-NOT: triton_gpu.convert_layout
     %25 = triton_gpu.convert_layout %24 : tensor<64x256xf16, #dpas> -> tensor<64x256xf16, #blocked1>
     %26 = arith.extsi %arg8 : i32 to i64
     // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >
-    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >>
+    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >
     %27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array} : >
     // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array} : !tt.ptr>
     tt.store %27, %25 {boundaryCheck = array} : !tt.ptr>
@@ -243,8 +242,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     %18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >
     %22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : >
     %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #blocked1>, !tt.ptr>, !tt.ptr>) : i32 {
-      %28 = tt.load %arg11 {boundaryCheck = array} : !tt.ptr>
-      %29 = tt.load %arg12 {boundaryCheck = array} : !tt.ptr>
+      %28 = tt.load %arg11 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major" } : !tt.ptr>
+      %29 = tt.load %arg12 {boundaryCheck = array, triton_intel_gpu.block_io = "row_major"} : !tt.ptr>
       %36 = triton_gpu.convert_layout %arg10 : tensor<64x256xf32, #blocked1> -> tensor<64x256xf32, #dpas>
       %30 = triton_gpu.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
       %31 = triton_gpu.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
diff --git a/test/TritonIntelGPU/combine.mlir b/test/TritonIntelGPU/combine.mlir
index 64f3193653..318d957c63 100644
--- a/test/TritonIntelGPU/combine.mlir
+++ b/test/TritonIntelGPU/combine.mlir
@@ -2324,23 +2324,22 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
     %cst_1 = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #blocked2>
     %0 = tt.get_program_id x : i32
     %1 = tt.get_program_id y : i32
-    // CHECK: %[[VAL_0:.*]] = tt.make_tensor_ptr {{.*}} : >>
-    // CHECK: %[[VAL_1:.*]] = tt.make_tensor_ptr {{.*}} : >>
+    // CHECK: %[[VAL_0:.*]] = tt.make_tensor_ptr {{.*}} : >
+    // CHECK: %[[VAL_1:.*]] = tt.make_tensor_ptr {{.*}} : >
     %12 = tt.make_tensor_ptr %arg0, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array} : >
     %14 = tt.make_tensor_ptr %arg1, [%c4096_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%0, %1] {order = array} : >
-    // CHECK: %[[VAL_2:.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr>>, !tt.ptr>>) : i32 {
+    // CHECK: %[[VAL_2:.*]]:3 = scf.for {{.*}} -> (tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr>, !tt.ptr>) : i32 {
     %15:3 = scf.for %arg3 = %c0_i32 to %c4096_i32 step %c128_i32 iter_args(%arg4 = %cst_1, %arg5 = %12, %arg6 = %14) -> (tensor<256x256xf32, #blocked2>, !tt.ptr>, !tt.ptr>) : i32 {
       %47 = tt.load %arg5 : !tt.ptr>
       %48 = tt.load %arg6 : !tt.ptr>
-      // CHEKC-NOT: triton_gpu.convert_layout
       %49 = triton_gpu.convert_layout %arg4 : tensor<256x256xf32, #blocked2> -> tensor<256x256xf32, #mma>
       %50 = triton_gpu.convert_layout %47 : tensor<256x32xbf16, #blocked3> -> tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>>
       %51 = triton_gpu.convert_layout %48 : tensor<32x256xbf16, #blocked2> -> tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>>
       %52 = tt.dot %50, %51, %49, inputPrecision = tf32 : tensor<256x32xbf16, #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 2}>> * tensor<32x256xbf16, #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 2}>> -> tensor<256x256xf32, #mma>
       %53 = triton_gpu.convert_layout %52 : tensor<256x256xf32, #mma> -> tensor<256x256xf32, #blocked2>
-      // CHECK: %[[VAL_3:.*]] = tt.advance {{.*}} : >>
-      // CHECK: %[[VAL_4:.*]] = tt.advance {{.*}} : >>
-      // CHECK: scf.yield {{.*}} : tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr>>, !tt.ptr>>
+      // CHECK: %[[VAL_3:.*]] = tt.advance {{.*}} : >
+      // CHECK: %[[VAL_4:.*]] = tt.advance {{.*}} : >
+      // CHECK: scf.yield {{.*}} : tensor<256x256xf32, #[[$DPAS]]>, !tt.ptr>, !tt.ptr>
       %54 = tt.advance %arg5, [%c0_i32, %c128_i32] : >
       %55 = tt.advance %arg6, [%c128_i32, %c0_i32] : >
       scf.yield %53, %54, %55 : tensor<256x256xf32, #blocked2>, !tt.ptr>, !tt.ptr>
@@ -2348,7 +2347,6 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32
     %16 = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked>
     %32 = tt.splat %arg2 : !tt.ptr -> tensor<256x256x!tt.ptr, #blocked2>
     %38 = arith.cmpi slt, %16, %cst : tensor<256xi32, #blocked>
-    // CHEKC-NOT: triton_gpu.convert_layout
     %39 = triton_gpu.convert_layout %38 : tensor<256xi1, #blocked> -> tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>>
     %40 = tt.expand_dims %39 {axis = 0 : i32} : tensor<256xi1, #triton_gpu.slice<{dim = 0, parent = #blocked4}>> -> tensor<1x256xi1, #blocked4>
     %41 = triton_gpu.convert_layout %40 : tensor<1x256xi1, #blocked4> -> tensor<1x256xi1, #blocked2>
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp
index 601e3694e9..f8f554bb02 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp
@@ -4,7 +4,9 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/IR/Visitors.h"
 #include "triton/Analysis/Utility.h"
+#include "llvm/Support/Casting.h"
 #include "llvm/Support/Debug.h"
+#include <optional>
 
 #define DEBUG_TYPE "tritonintelgpu-materialize-block-pointer"
 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -12,6 +14,7 @@
 using namespace mlir;
 
 namespace tt = mlir::triton;
+namespace ttg = mlir::triton::gpu;
 namespace ttgi = mlir::triton::gpu::intel;
 
 namespace mlir::triton::gpu::intel {
@@ -37,7 +40,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass
       return;
 
     MLIRContext *context = &getContext();
-    mod.walk([context](tt::LoadOp loadOp) {
+    mod.walk([context, this](tt::LoadOp loadOp) {
       LDBG("Considering op: " << loadOp);
 
       Value ptr = loadOp.getPtr();
@@ -51,7 +54,6 @@ struct TritonIntelGPUMaterializeBlockPointerPass
       LDBG("Found make tensor ptr op: " << makeTensorPtrOp);
       auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
       auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
-      auto dotLayout = ttgi::getDotEncoding(tensorType);
 
       Operation::operand_range shape = makeTensorPtrOp.getShape();
       unsigned rank = shape.size();
@@ -100,11 +102,13 @@ struct TritonIntelGPUMaterializeBlockPointerPass
         return;
 
       const bool isRowMajor = fastChangeDim == rank - 1;
+      std::optional<ttg::DotOperandEncodingAttr> dotLayout =
+          getDotLayout(loadOp);
       if (dotLayout) {
-        // Check if the load is being used in a dot layout, and if so is this
-        // the first op and is it a transposed row major matrix. If so, skip
-        // the block ptr attribute as performance is worse than if we remove
-        // the tensor pointer
+        // Check if the load is being used by a tt.dot operation, and if so is
+        // this the first operand and is it a transposed row major matrix. If
+        // so, skip the block ptr attribute as performance is worse than if we
+        // remove the tensor pointer.
         LDBG("dotLayout: " << *dotLayout);
         const unsigned opIdx = dotLayout->getOpIdx();
         auto dotOrder = dotLayout->getThreadOrder();
@@ -122,6 +126,52 @@ struct TritonIntelGPUMaterializeBlockPointerPass
         }
       });
   }
+
+private:
+  // Return the load layout if it is a dot layout. If it is not, check if the
+  // load result is converted to a dot layout. If so, return the dot layout,
+  // otherwise return nullopt.
+  std::optional<ttg::DotOperandEncodingAttr>
+  getDotLayout(tt::LoadOp loadOp) const {
+    Value ptr = loadOp.getPtr();
+    if (!tt::isTensorPointerType(ptr.getType()))
+      return std::nullopt;
+
+    RankedTensorType tensorType = ttgi::getRankedTensorType(ptr.getType());
+    if (!tensorType)
+      return std::nullopt;
+
+    auto dotLayout = ttgi::getDotEncoding(tensorType);
+    if (dotLayout)
+      return dotLayout;
+
+    auto allUsersAreConvertOps = [](Operation::user_range users) {
+      return llvm::all_of(users, [](Operation *user) {
+        return isa<ttg::ConvertLayoutOp>(user);
+      });
+    };
+
+    auto allUserHaveIdenticalLayout = [](Operation::user_range users) {
+      Attribute firstUserLayout =
+          cast<ttg::ConvertLayoutOp>(*users.begin()).getType().getEncoding();
+      return llvm::all_of(users, [&firstUserLayout](Operation *user) {
+        return firstUserLayout ==
+               cast<ttg::ConvertLayoutOp>(user).getType().getEncoding();
+      });
+    };
+
+    Operation::user_range users = loadOp->getUsers();
+    if (!users.empty() && allUsersAreConvertOps(users) &&
+        allUserHaveIdenticalLayout(users)) {
+      Attribute firstUserLayout =
+          cast<ttg::ConvertLayoutOp>(*users.begin()).getType().getEncoding();
+      if (isa<ttg::DotOperandEncodingAttr>(firstUserLayout))
+        return dyn_cast<ttg::DotOperandEncodingAttr>(firstUserLayout);
+      return std::nullopt;
+    }
+
+    return std::nullopt;
+  }
 };
 
 } // anonymous namespace
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp
index 7f5789db3a..26190b98c3 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/Pipeliner/MatmulLoopPipeline.cpp
@@ -10,6 +10,8 @@
 #include "llvm/Support/Debug.h"
 
 #define DEBUG_TYPE "tritonintelgpu-pipeline"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 using namespace mlir;
 namespace tt = mlir::triton;
@@ -55,30 +57,25 @@ static ttg::DotOperandEncodingAttr getDotEncodingFromUser(Operation *user) {
   if (!tensorType)
     return nullptr;
 
-  if (isa(tensorType.getEncoding()))
-    return allTransitiveUsesHaveDotEncoding(res);
-
-  return llvm::dyn_cast_or_null<ttg::DotOperandEncodingAttr>(
-      tensorType.getEncoding());
+  Attribute layout = tensorType.getEncoding();
+  return isa(layout)
+             ? allTransitiveUsesHaveDotEncoding(res)
+             : llvm::dyn_cast_or_null<ttg::DotOperandEncodingAttr>(layout);
 }
 
 /// If all the transitive uses of the given value are used by a convert to the
 /// same dot operand encoding, return the encoding. Otherwise return nullptr.
 static ttg::DotOperandEncodingAttr allTransitiveUsesHaveDotEncoding(Value val) {
   ttg::DotOperandEncodingAttr attr{nullptr};
-  LLVM_DEBUG(llvm::dbgs() << "Checking users of " << val << "\n");
+  LDBG("Checking users of " << val);
   for (Operation *user : val.getUsers()) {
-    ttg::DotOperandEncodingAttr dotAttr;
-    if (isa(user)) {
-      auto tensorType = cast<RankedTensorType>(val.getType());
-      dotAttr = dyn_cast<ttg::DotOperandEncodingAttr>(tensorType.getEncoding());
-    } else {
-      dotAttr = getDotEncodingFromUser(user);
-    }
+    ttg::DotOperandEncodingAttr dotAttr =
+        isa(user)
+            ? dyn_cast<ttg::DotOperandEncodingAttr>(
+                  cast<RankedTensorType>(val.getType()).getEncoding())
+            : getDotEncodingFromUser(user);
     if (!dotAttr || (attr != nullptr && attr != dotAttr)) {
-      LLVM_DEBUG({
-        llvm::dbgs() << "no dot attribute found for user: " << user << "\n";
-      });
+      LDBG("no dot attribute found for user: " << *user);
       return nullptr;
     }
     attr = dotAttr;
@@ -292,14 +289,14 @@ bool ttgi::preProcessLoopAndGetSchedule(scf::ForOp &forOp, int numStages,
   SmallVector<LoadDotOperand> loads;
   collectOpsToPipeline(forOp, loads, supportRegularPtr);
   if (loads.empty()) {
-    LLVM_DEBUG(llvm::dbgs() << "No loads to pipeline\n");
+    LDBG("No loads to pipeline");
     return false;
   }
 
   LLVM_DEBUG({
-    llvm::dbgs() << "Loads to pipeline:\n";
+    DBGS() << "Loads to pipeline:\n";
     for (const LoadDotOperand &load : loads)
-      llvm::dbgs() << " " << *load.load << "\n";
+      DBGS() << " " << *load.load << "\n";
   });
 
   // 2. Create the prefetching operations for the loads collected.
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp
index 759fc1782d..7fe296695a 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp
@@ -11,6 +11,7 @@
 #include "mlir/Transforms/DialectConversion.h"
 
 #include "intel/include/Dialect/TritonIntelGPU/IR/Attributes.h"
+#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
 #include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -85,14 +86,22 @@ bool isExpensiveLoadOrStore(Operation *op) {
          "Expecting Triton LoadOp or StoreOp");
   Value base = op->getOperand(0);
 
-  // Case 1: A size 1 tensor is not expensive since all threads will load the
-  // same
+  // A size 1 tensor is not expensive since all threads will load the same
+  // value.
   if (isSingleValue(base))
     return false;
 
-  // Case 2: Tensor of pointers has more threads than elements
-  // we can presume a high hit-rate that makes it cheap to load
-  if (auto ptrType = dyn_cast<RankedTensorType>(base.getType())) {
+  // Loads that use a block pointer are expensive if they cannot be lowered to
+  // 2D block read operations. Temporarily leverage the
+  // "triton_intel_gpu.block_io" attribute to filter out inexpensive loads.
+  Attribute blockIOAttr =
+      op->getAttr(TritonIntelGPUDialect::getBlockIOAttrName());
+  if (blockIOAttr)
+    return false;
+
+  // Loads that use more threads than elements can be presumed to have a high
+  // hit-rate that makes them cheap to load.
+  if (auto ptrType = getRankedTensorType(base.getType())) {
     auto mod = op->getParentOfType<ModuleOp>();
     int numWarps = ttg::TritonGPUDialect::getNumWarps(mod);
     int threadsPerWarp = ttg::TritonGPUDialect::getThreadsPerWarp(mod);