
Commit 62c07a7

Preserve blocked pointers used by tt.load operation with DPAS layout (#2400)
Fixes #2378

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 6f89dbe commit 62c07a7
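For illustration only (not part of the commit), here is a minimal sketch of the IR pattern this change targets: a block pointer whose pointee tensor carries the DPAS layout and that is consumed by a tt.load. Previously only tt.store users kept such a pointer in blocked form; with this change tt.load users do as well, so the rewrite pass no longer removes the pointer. In the sketch, #dpas stands for a DPAS layout attribute such as the one declared in the test below (definition elided), and %base, %rows, %cols, and %pitch are hypothetical placeholders.

    // Sketch only; #dpas and the SSA operands are assumptions, not taken from the commit.
    %c0_i32 = arith.constant 0 : i32
    %c1_i64 = arith.constant 1 : i64
    // A blocked pointer whose pointee tensor has the DPAS layout ...
    %ptr = tt.make_tensor_ptr %base, [%rows, %cols], [%pitch, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #dpas>>
    // ... used by tt.load: after this fix the pointer is preserved instead of being rewritten away.
    %val = tt.load %ptr {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #dpas>>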

2 files changed: +59 −44 lines changed

test/TritonIntelGPU/rewrite-tensor-pointer.mlir

Lines changed: 24 additions & 19 deletions
@@ -10,8 +10,7 @@
 #dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
 #dot1 = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
 module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 16 : i32, "triton_intel_gpu.support_sg_2d_block"} {
-  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
-  // CHECK: @matmul_kernel_with_block_pointers
+  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}, %arg9: i32 {tt.divisibility = 16 : i32}) {
     %c4_i32 = arith.constant 4 : i32
     %c256_i32 = arith.constant 256 : i32
     %c1_i64 = arith.constant 1 : i64
@@ -20,9 +19,9 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
     %c255_i32 = arith.constant 255 : i32
     %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas>
     %0 = tt.get_program_id x : i32
-    %1 = arith.addi %arg3, %c255_i32 : i32
+    %1 = arith.addi %arg4, %c255_i32 : i32
     %2 = arith.divsi %1, %c256_i32 : i32
-    %3 = arith.addi %arg4, %c255_i32 : i32
+    %3 = arith.addi %arg5, %c255_i32 : i32
     %4 = arith.divsi %3, %c256_i32 : i32
     %5 = arith.muli %4, %c4_i32 : i32
     %6 = arith.divsi %0, %5 : i32
@@ -34,35 +33,41 @@ module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-wa
     %12 = arith.remsi %0, %5 : i32
     %13 = arith.divsi %12, %9 : i32
     %14 = arith.muli %11, %c256_i32 : i32
-    %15 = arith.extsi %arg3 : i32 to i64
-    %16 = arith.extsi %arg5 : i32 to i64
-    %17 = arith.extsi %arg6 : i32 to i64
+    %15 = arith.extsi %arg4 : i32 to i64
+    %16 = arith.extsi %arg6 : i32 to i64
+    %17 = arith.extsi %arg7 : i32 to i64
     // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
     %18 = tt.make_tensor_ptr %arg0, [%15, %16], [%17, %c1_i64], [%14, %c0_i32] {order = array<i32: 1, 0>} : <tensor<256x32xf16, #dot0>>
     %19 = arith.muli %13, %c256_i32 : i32
-    %20 = arith.extsi %arg4 : i32 to i64
-    %21 = arith.extsi %arg7 : i32 to i64
+    %20 = arith.extsi %arg5 : i32 to i64
+    %21 = arith.extsi %arg8 : i32 to i64
     // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
     %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #dot1>>
-    %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>) : i32 {
+    %23:3 = scf.for %arg10 = %c0_i32 to %arg6 step %c32_i32 iter_args(%arg11 = %cst, %arg12 = %18, %arg13 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>) : i32 {
       // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
       // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
-      %28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #dot0>>
-      %29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #dot1>>
+      %28 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x32xf16, #dot0>>
+      %29 = tt.load %arg13 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #dot1>>
      // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]>
      // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>>>
      // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : <tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
-      %30 = tt.dot %28, %29, %arg10, inputPrecision = tf32 : tensor<256x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<256x256xf32, #dpas>
-      %31 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #dot0>>
-      %32 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #dot1>>
+      %30 = tt.dot %28, %29, %arg11, inputPrecision = tf32 : tensor<256x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<256x256xf32, #dpas>
+      %31 = tt.advance %arg12, [%c0_i32, %c32_i32] : <tensor<256x32xf16, #dot0>>
+      %32 = tt.advance %arg13, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #dot1>>
      scf.yield %30, %31, %32 : tensor<256x256xf32, #dpas>, !tt.ptr<tensor<256x32xf16, #dot0>>, !tt.ptr<tensor<32x256xf16, #dot1>>
    }
-    %24 = arith.truncf %23#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>
-    %26 = arith.extsi %arg8 : i32 to i64
+    %25 = arith.extsi %arg9 : i32 to i64
+    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #[[DPAS]]>>
+    %26 = tt.make_tensor_ptr %arg3, [%15, %20], [%25, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x256xf32, #dpas>>
+    // CHECK: tt.load {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #[[DPAS]]>>
+    %27 = tt.load %26 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf32, #dpas>>
+    %28 = arith.addf %23#0, %27 : tensor<256x256xf32, #dpas>
+    %29 = arith.truncf %28 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas>
+
     // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #[[DPAS]]>>
-    %27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #dpas>>
+    %30 = tt.make_tensor_ptr %arg2, [%15, %20], [%25, %c1_i64], [%14, %19] {order = array<i32: 1, 0>} : <tensor<256x256xf16, #dpas>>
     // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #[[DPAS]]>>
-    tt.store %27, %24 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #dpas>>
+    tt.store %30, %29 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<256x256xf16, #dpas>>
     tt.return
   }
 }

third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp

Lines changed: 35 additions & 25 deletions
@@ -30,24 +30,28 @@ namespace {

 /// Check if the tensor pointer should be removed. The tensor pointer should be
 /// removed if:
-///   - the tensor pointer does not have DotEncoding with DpasEncoding parent
-///     and does not have DpasEncoding
-///   - the tensor pointer pitch is not divisible by Qword bitwidth
-///   - the tensor pointer is not contiguous on memory
-bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) {
+///   - it does not have Dpas layout or Dot layout (with Dpas layout as parent)
+///   - its pitch is not divisible by Qword bitwidth
+///   - it is not contiguous in memory
+bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByLoadOrStoreOp) {
   LDBG("Considering removal of: " << op);
   if (!op->getParentOfType<ModuleOp>()->hasAttr(
-          ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName()))
+          ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName())) {
+    LDBG("Marked for removal: 2D block operation not supported");
     return true;
+  }

   auto ptrType = cast<tt::PointerType>(op.getType());
   LDBG("Op ptr type: " << ptrType);
   auto tensorType = cast<RankedTensorType>(ptrType.getPointeeType());
   LDBG("Op tensor type: " << tensorType);

   if (!ttgi::hasDotDpasEncoding(tensorType) &&
-      !(isUsedByStoreOp && ttgi::hasDpasEncoding(tensorType)))
+      !(isUsedByLoadOrStoreOp && ttgi::hasDpasEncoding(tensorType))) {
+    LDBG("Marked for removal: tensor doesn't have DPAS layout and is not used "
+         "by load or store op with DPAS layout");
     return true;
+  }

   TypedValue<triton::PointerType> base = op.getBase();
   Operation::operand_range shape = op.getShape();
@@ -60,21 +64,23 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) {

   int fastChangeDim = -1;
   for (size_t i = 0; i < strides.size(); ++i) {
-    if (mlir::triton::gpu::intel::isConstant(strides[i], 1)) {
+    if (ttgi::isConstant(strides[i], 1)) {
       fastChangeDim = i;
       break;
     }
   }

   LDBG("fastChangeDim: " << fastChangeDim);
   if (fastChangeDim < 0) {
+    LDBG("Marked for removal: fast changing dimension not found");
     return true;
   }

   LDBG("Tensor type element type bit width: "
        << tensorType.getElementTypeBitWidth());
   if (fastChangeDim == rank - 2 && tensorType.getElementTypeBitWidth() == 8) {
     // TODO: column major layout w/ fp8 has performance regression
+    LDBG("Marked for removal: column major layout with fp8 element type");
     return true;
   }

@@ -85,11 +91,15 @@ bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) {
     // Across Intel platforms, the strictest pitch restriction is to be a
     // multiple of OWord(128 bits).
     if (!ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth())) {
+      LDBG("Marked for removal: cannot use block read/write instructions");
       return true;
     }

     return false;
   }
+
+  LDBG("Marked for removal: fall-trough");
+
   return true;
 }

@@ -705,28 +715,28 @@ class TritonIntelGPURewriteTensorPointerPass
   void runOnOperation() override {
     ModuleOp mod = getOperation();

-    auto usedByStoreOp = [](Value val) {
+    auto usedByLoadOrStoreOp = [](Value val) {
       return llvm::any_of(val.getUsers(), [](Operation *user) {
-        return llvm::isa<tt::StoreOp>(user);
+        return isa<tt::LoadOp, tt::StoreOp>(user);
       });
     };

-    auto markTensorPointerForRemoval = [this](Value val,
-                                              bool isUsedByStoreOp = false) {
-      if (tt::isTensorPointerType(val.getType())) {
-        tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
-        if (shouldRemove(makeTensorPtrOp, isUsedByStoreOp))
-          valueToRemove.insert(val);
-      }
-    };
+    auto markTensorPointerForRemoval =
+        [this](Value val, bool isUsedByLoadOrStoreOp = false) {
+          if (tt::isTensorPointerType(val.getType())) {
+            tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val);
+            if (shouldRemove(makeTensorPtrOp, isUsedByLoadOrStoreOp))
+              valueToRemove.insert(val);
+          }
+        };

     mod.walk([&](Operation *op) {
-      if (llvm::isa<tt::MakeTensorPtrOp>(op)) {
+      if (isa<tt::MakeTensorPtrOp>(op)) {
         Value result = op->getResult(0);
-        markTensorPointerForRemoval(result, usedByStoreOp(result));
-      } else if (llvm::isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
+        markTensorPointerForRemoval(result, usedByLoadOrStoreOp(result));
+      } else if (isa<tt::AdvanceOp, tt::LoadOp, tt::StoreOp>(op)) {
         markTensorPointerForRemoval(op->getOperand(0),
-                                    llvm::isa<tt::StoreOp>(op));
+                                    isa<tt::LoadOp, tt::StoreOp>(op));
       } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
         for (auto arg : forOp.getInitArgs())
           markTensorPointerForRemoval(arg);
@@ -738,11 +748,11 @@ class TritonIntelGPURewriteTensorPointerPass

     LLVM_DEBUG({
       if (valueToRemove.empty())
-        llvm::dbgs() << "No tensor pointer to remove\n";
+        DBGS() << "No tensor pointer to remove";
       else {
-        llvm::dbgs() << "Values to remove: \n";
+        DBGS() << "Values to remove: ";
         for (auto val : valueToRemove)
-          llvm::dbgs() << val << "\n";
+          DBGS() << val;
       }
     });
