
Commit c428109

Improve GEMM perf when one matrix is transposed (#2347)
The 2D block load/store does not work when one of the input matrices of a `tt.dot` is transposed inside the Triton kernel via the `stride` parameter. In the user example, the block pointer is transposed by swapping strides, but the `order` parameter is left unchanged. As a result, `materialize-block-pointer` cannot detect that a `block_io = "column_major"` attribute should be added to the matrix. Even if the attribute were added, `rewrite-tensor-pointer` would remove the block pointer, because column-major layouts were not supported.

This PR adds support for detecting `column_major` based on `stride` instead of `order`, and brings the same logic to `rewrite-tensor-pointer` so that the column-major load is preserved and eventually lowered to a 2D block load. With this, transposed-matrix performance is more in line with the non-transposed version:

```
Compute A x B
(I): Detected 7680 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
✅ Triton and Torch match
Time for torch: 0.31821921467781067 ms
Time for triton: 0.4404735863208771 ms

Compute A x B.T
(I): Detected 7680 spills, recompiling the kernel using large GRF mode
(I): Kernel has now 0 spills
✅ Triton and Torch match
Time for torch: 0.33270877599716187 ms
Time for triton: 0.6352895498275757 ms
```

Close #1795
1 parent 659470b commit c428109
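To make the failure mode concrete, here is a minimal, hypothetical Triton sketch of the pattern this PR targets (it is not the exact kernel from #1795, and the names are ours): `B` is transposed purely through the strides passed to `tl.make_block_ptr`, while `order` is left at the row-major `(1, 0)`, so only the strides reveal the column-major access.

```python
import triton
import triton.language as tl


@triton.jit
def matmul_kernel(a_ptr, b_ptr, c_ptr, M, N, K,
                  stride_am, stride_ak,  # A row-major: (K, 1)
                  stride_bk, stride_bn,  # for A x B.T the host passes (1, K)
                  stride_cm, stride_cn,
                  BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
                  BLOCK_K: tl.constexpr):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    a_bptr = tl.make_block_ptr(a_ptr, (M, K), (stride_am, stride_ak),
                               (pid_m * BLOCK_M, 0), (BLOCK_M, BLOCK_K),
                               order=(1, 0))
    # The transpose lives only in the strides; `order` stays (1, 0),
    # which is the case the stride-based detection now handles.
    b_bptr = tl.make_block_ptr(b_ptr, (K, N), (stride_bk, stride_bn),
                               (0, pid_n * BLOCK_N), (BLOCK_K, BLOCK_N),
                               order=(1, 0))
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for _ in range(0, K, BLOCK_K):
        a = tl.load(a_bptr, boundary_check=(0, 1))
        b = tl.load(b_bptr, boundary_check=(0, 1))
        acc += tl.dot(a, b)
        a_bptr = tl.advance(a_bptr, (0, BLOCK_K))
        b_bptr = tl.advance(b_bptr, (BLOCK_K, 0))
    c_bptr = tl.make_block_ptr(c_ptr, (M, N), (stride_cm, stride_cn),
                               (pid_m * BLOCK_M, pid_n * BLOCK_N),
                               (BLOCK_M, BLOCK_N), order=(1, 0))
    tl.store(c_bptr, acc.to(tl.float16), boundary_check=(0, 1))
```

Before this change the `b` load would never receive the `block_io = "column_major"` hint, and even with it, `rewrite-tensor-pointer` would have stripped the block pointer; with the stride-based detection both loads can stay on the 2D block path.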

File tree: 2 files changed, +88 / -11 lines

test/TritonIntelGPU/rewrite-tensor-pointer.mlir

Lines changed: 56 additions & 0 deletions
```
@@ -274,3 +274,59 @@ module attributes {"triton_intel_gpu.support_sg_2d_block"} {
     tt.return
   }
 }
+
+// -----
+
+// COM: Case 4:
+// COM: Check that a matrix multiplication of two tensor pointers with block_io attributes is not rewritten
+// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [4, 2], A = [32, 16], B = [16, 32], C = [32, 32]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 32 : i32, triton_gpu.target = "xpu", "triton_gpu.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_bf16_conversion, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
+  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    // CHECK: @matmul_kernel_with_block_pointers
+    %c4_i32 = arith.constant 4 : i32
+    %c256_i32 = arith.constant 256 : i32
+    %c1024_i64 = arith.constant 1024 : i64
+    %c5120_i64 = arith.constant 5120 : i64
+    %c1_i64 = arith.constant 1 : i64
+    %c0_i32 = arith.constant 0 : i32
+    %c4096_i64 = arith.constant 4096 : i64
+    %c32_i32 = arith.constant 32 : i32
+    %c64_i32 = arith.constant 64 : i32
+    %c5120_i32 = arith.constant 5120 : i32
+    %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas>
+    %0 = tt.get_program_id x : i32
+    %1 = arith.divsi %0, %c64_i32 : i32
+    %2 = arith.muli %1, %c4_i32 : i32
+    %3 = arith.subi %c4_i32, %2 : i32
+    %4 = arith.minsi %3, %c4_i32 : i32
+    %5 = arith.remsi %0, %4 : i32
+    %6 = arith.addi %2, %5 : i32
+    %7 = arith.remsi %0, %c64_i32 : i32
+    %8 = arith.divsi %7, %4 : i32
+    %9 = arith.muli %6, %c256_i32 : i32
+    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
+    %10 = tt.make_tensor_ptr %arg0, [%c1024_i64, %c5120_i64], [%c5120_i64, %c1_i64], [%9, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>
+    %11 = arith.muli %8, %c256_i32 : i32
+    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 0, 1>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
+    %12 = tt.make_tensor_ptr %arg1, [%c5120_i64, %c4096_i64], [%c1_i64, %c5120_i64], [%c0_i32, %11] {order = array<i32: 0, 1>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
+    %13:3 = scf.for %arg3 = %c0_i32 to %c5120_i32 step %c32_i32 iter_args(%arg4 = %cst, %arg5 = %10, %arg6 = %12) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>) : i32 {
+      // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
+      // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "column_major"} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
+      %16 = tt.load %arg5 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>
+      %17 = tt.load %arg6 {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "column_major"} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
+      // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]>
+      // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
+      // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
+      %18 = tt.dot %16, %17, %arg4, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>> -> tensor<256x256xf32, #dpas>
+      %19 = tt.advance %arg5, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>
+      %20 = tt.advance %arg6, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
+      scf.yield %18, %19, %20 : tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth = 2}>>>, !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth = 2}>>>
+    }
+    %14 = tt.make_tensor_ptr %arg2, [%c1024_i64, %c4096_i64], [%c4096_i64, %c1_i64], [%9, %11] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #dpas>>
+    %15 = arith.truncf %13#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>
+    // CHECK: tt.store {{.*}}, {{.*}}, {{.*}} : !tt.ptr<tensor<256x256xf16, #[[DPAS]]>
+    tt.store %14, %15 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #dpas>>
+    tt.return
+  }
+}
```

third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp

Lines changed: 32 additions & 11 deletions
```
@@ -23,6 +23,8 @@ namespace mlir::triton::gpu::intel {
 } // namespace mlir::triton::gpu::intel
 
 #define DEBUG_TYPE "tritonintelgpu-rewrite-tensor-pointer"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
 
 namespace {
 
@@ -33,41 +35,60 @@ namespace {
 /// - the tensor pointer pitch is not divisible by Qword bitwidth
 /// - the tensor pointer is not contiguous on memory
 bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) {
+  LDBG("Considering removal of: " << op);
   if (!op->getParentOfType<ModuleOp>()->hasAttr(
           ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName()))
     return true;
 
   auto ptrType = cast<tt::PointerType>(op.getType());
+  LDBG("Op ptr type: " << ptrType);
   auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
+  LDBG("Op tensor type: " << tensorType);
 
   if (!ttgi::hasDotDpasEncoding(tensorType) &&
       !(isUsedByStoreOp && ttgi::hasDpasEncoding(tensorType)))
     return true;
 
   TypedValue<triton::PointerType> base = op.getBase();
   Operation::operand_range shape = op.getShape();
+  unsigned rank = shape.size();
   Operation::operand_range strides = op.getStrides();
   Operation::operand_range offsets = op.getOffsets();
   ArrayRef<int32_t> order = op.getOrder();
   ArrayRef<int64_t> tensorShape = tensorType.getShape();
 
-  // TODO: support column-major tensor
+  int fastChangeDim = -1;
+  for (size_t i = 0; i < strides.size(); ++i) {
+    if (mlir::triton::gpu::intel::isConstant(strides[i], 1)) {
+      fastChangeDim = i;
+      break;
+    }
+  }
+
+  LDBG("fastChangeDim: " << fastChangeDim);
+  if (fastChangeDim < 0) {
+    return true;
+  }
+
+  LDBG("Tensor type element type bit width: "
+       << tensorType.getElementTypeBitWidth());
+  if (fastChangeDim == rank - 2 && tensorType.getElementTypeBitWidth() == 8) {
+    // TODO: column major layout w/ fp8 has performance regression
+    return true;
+  }
+
   // HW 2D block read instruction has restriction on pitch divisibility
-  if (strides.size() == 2) {
-    auto pitch = strides[0];
+  if (fastChangeDim >= (rank - 2)) {
+    auto pitch = strides[(fastChangeDim == rank - 1) ? rank - 2 : rank - 1];
+    LDBG("Pitch: " << pitch);
     // Across Intel platforms, the strictest pitch restriction is to be a
     // multiple of OWord(128 bits).
-    if (!ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth()))
+    if (!ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth())) {
       return true;
-  }
+    }
 
-  // HW 2D block read instruction only supports contiguous accessing.
-  auto fastChangeStride = strides[1];
-  if (auto stride = fastChangeStride.getDefiningOp<arith::ConstantOp>()) {
-    if (auto strideInt = dyn_cast<IntegerAttr>(stride.getValue()))
-      return strideInt.getInt() != 1;
+    return false;
   }
-
   return true;
 }
```
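As a worked illustration of the pitch restriction in the hunk above, here is a small, hypothetical Python model (the helper `pitch_ok` is ours, not a codebase API) of what the `ttgi::isDivisible` call enforces for constant strides:

```python
# Hypothetical model of the pitch rule in shouldRemove(): the pitch
# (the stride of the non-unit dimension, counted in elements) must be
# a multiple of one OWord (128 bits) for the HW 2D block load/store.
def pitch_ok(pitch_in_elems: int, elem_bitwidth: int) -> bool:
    elems_per_oword = 128 // elem_bitwidth  # e.g. 8 elements for f16
    return pitch_in_elems % elems_per_oword == 0


# The f16 tensors in the test above use a pitch of 5120 elements:
# 5120 % 8 == 0, so the block pointer is kept (shouldRemove -> false).
assert pitch_ok(5120, 16)
assert not pitch_ok(5121, 16)
```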
