diff --git a/test/TritonIntelGPU/materialize-block-pointer.mlir b/test/TritonIntelGPU/materialize-block-pointer.mlir
index f2fc81d71e..71596c3ccc 100644
--- a/test/TritonIntelGPU/materialize-block-pointer.mlir
+++ b/test/TritonIntelGPU/materialize-block-pointer.mlir
@@ -17,7 +17,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, triton_gpu.target = "xpu", t
     %5 = tt.load %3 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr>
     %6 = tt.load %4 {boundaryCheck = array, cache = 1 : i32, evict = 1 : i32, isVolatile = false, padding = 1 : i32} : !tt.ptr>

-    // CHECK: tt.load {{.*}} {boundaryCheck = array, padding = 1 : i32, triton_intel_gpu.block_io = "column_major"}
+    // CHECK: tt.load {{.*}} {boundaryCheck = array, padding = 1 : i32}
     // CHECK: tt.load {{.*}} {boundaryCheck = array, padding = 1 : i32, triton_intel_gpu.block_io = "column_major"}
     %7 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c1_i64, %pitch], [%c0_i32, %c0_i32] {order = array} : >
     %8 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c1_i64, %pitch], [%c0_i32, %c0_i32] {order = array} : >
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp
index 9a0b5e4f9d..601e3694e9 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp
@@ -51,6 +51,8 @@ struct TritonIntelGPUMaterializeBlockPointerPass
       LDBG("Found make tensor ptr op: " << makeTensorPtrOp);
       auto ptrType = cast<tt::PointerType>(makeTensorPtrOp.getType());
       auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+      auto dotLayout = ttgi::getDotEncoding(tensorType);
+
       Operation::operand_range shape = makeTensorPtrOp.getShape();
       unsigned rank = shape.size();
       LDBG("Rank: " << rank);
@@ -97,10 +99,26 @@ struct TritonIntelGPUMaterializeBlockPointerPass
               128 / tensorType.getElementTypeBitWidth()))
         return;

+      const bool isRowMajor = fastChangeDim == rank - 1;
+      if (dotLayout) {
+        // If the load feeds operand A (opIdx == 0) of a dot operation and its
+        // in-memory order disagrees with the dot layout's value order (i.e.,
+        // the operand is transposed), skip the block_io attribute: removing
+        // the tensor pointer performs better in that case.
+        LDBG("dotLayout: " << *dotLayout);
+        const unsigned opIdx = dotLayout->getOpIdx();
+        auto dotOrder = dotLayout->getThreadOrder();
+        const bool valueRowMajor = (dotOrder[0] == 1 && dotOrder[1] == 0);
+        if (opIdx == 0 && valueRowMajor ^ isRowMajor) {
+          LDBG("Skipping block pointer attribute for transposed A matrix in "
+               "dot operation");
+          return;
+        }
+      }
+
       loadOp->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
-                      StringAttr::get(context, fastChangeDim == rank - 1
-                                                   ? "row_major"
-                                                   : "column_major"));
+                      StringAttr::get(context, isRowMajor ? "row_major"
+                                                          : "column_major"));
       }
     });
   }
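
The decision rule the C++ hunk adds can be exercised outside MLIR. Below is a minimal standalone sketch, not part of the patch: shouldTagBlockIO and its parameters are hypothetical stand-ins for the pass's inputs (opIdx and getThreadOrder() from the dot layout, fastChangeDim derived from the strides). It shows which loads would keep the block_io attribute under the new rule.

// Standalone sketch of the skip rule added in this patch (assumes 2-D loads).
// shouldTagBlockIO is a hypothetical helper, not a Triton API; it mirrors the
// pass logic without the MLIR dependencies.
#include <array>
#include <cstdio>

bool shouldTagBlockIO(unsigned opIdx, std::array<unsigned, 2> dotOrder,
                      unsigned fastChangeDim, unsigned rank) {
  const bool isRowMajor = fastChangeDim == rank - 1;
  const bool valueRowMajor = dotOrder[0] == 1 && dotOrder[1] == 0;
  // Transposed A operand: the value order and the in-memory order disagree.
  if (opIdx == 0 && (valueRowMajor ^ isRowMajor))
    return false; // skip: removing the tensor pointer is faster here
  return true;
}

int main() {
  // Row-major A operand of a dot: keeps the attribute (prints 1).
  std::printf("%d\n", shouldTagBlockIO(0, {1, 0}, /*fastChangeDim=*/1, 2));
  // Transposed A operand: attribute is skipped (prints 0).
  std::printf("%d\n", shouldTagBlockIO(0, {1, 0}, /*fastChangeDim=*/0, 2));
}

The second case mirrors the updated CHECK lines in the test: the first load no longer carries triton_intel_gpu.block_io, while the non-transposed load keeps its "column_major" annotation.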