Skip to content

Commit d5736e4

Browse files
authored
Enhance the remove layout conversion pass for store operations (#4308)
This PR adds logic in LayoutPropagation::rewriteStoreOp to ensure layout conversions are removed only for store operations that use the value yielded by the layout conversion. Store operations that use the same base pointer (directly or indirectly) but do not use the converted value are left untouched.

Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent 2426ba7 commit d5736e4

File tree

2 files changed

+73
-12
lines changed

2 files changed

+73
-12
lines changed

test/TritonIntelGPU/combine.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2518,3 +2518,49 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr
25182518
tt.return
25192519
}
25202520
}
2521+
2522+
// -----
2523+
2524+
// CHECK-DAG: #[[$BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
2525+
// CHECK-DAG: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [4, 1], A = [32, 8], B = [8, 16], C = [32, 16]}>
2526+
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
2527+
#mma = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [4, 1], A = [32, 8], B = [8, 16], C = [32, 16]}>
2528+
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, ttig.support_sg_2d_block} {
2529+
// CHECK-LABEL: matmul_kernel_reshape
2530+
tt.func public @matmul_kernel_reshape(%arg2: !tt.ptr<f32>, %arg3: i32, %arg4: i32) {
2531+
%cst = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked>
2532+
%c32_i32 = arith.constant 32 : i32
2533+
%c0_i32 = arith.constant 0 : i32
2534+
%c1_i32 = arith.constant 1 : i32
2535+
%c1_i64 = arith.constant 1 : i64
2536+
%cst_0 = arith.constant dense<1.000000e+00> : tensor<64x64xf32, #mma>
2537+
%1 = arith.extsi %arg4 : i32 to i64
2538+
%2 = arith.extsi %arg3 : i32 to i64
2539+
2540+
// CHECK-DAG: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #[[$DPAS]]>>
2541+
// CHECK-DAG: [[PTR2:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #[[$DPAS]]>>
2542+
// CHECK-DAG: [[PTR3:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #[[$BLOCKED]]>>
2543+
2544+
// CHECK-NOT: separator of consecutive DAGs
2545+
// CHECK-DAG: [[ADV_PTR2:%.*]] = tt.advance [[PTR2]], {{.*}} : <tensor<64x64xf32, #[[$DPAS]]>>
2546+
// CHECK-DAG: [[ADV_PTR3:%.*]] = tt.advance [[PTR3]], {{.*}} : <tensor<64x64xf32, #[[$BLOCKED]]>>
2547+
%3 = tt.make_tensor_ptr %arg2, [%2, %1], [%1, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #blocked>>
2548+
%4 = tt.advance %3, [%c0_i32, %c32_i32] : !tt.ptr<tensor<64x64xf32, #blocked>>
2549+
2550+
// The following 2 stores should use blocked layout.
2551+
// CHECK-NOT: separator of consecutive DAGs
2552+
// CHECK-DAG: tt.store [[PTR3]], {{.*}} : !tt.ptr<tensor<64x64xf32, #[[$BLOCKED]]>>
2553+
// CHECK-DAG: tt.store [[ADV_PTR3]], {{.*}} : !tt.ptr<tensor<64x64xf32, #[[$BLOCKED]]>>
2554+
tt.store %3, %cst : !tt.ptr<tensor<64x64xf32, #blocked>>
2555+
tt.store %4, %cst : !tt.ptr<tensor<64x64xf32, #blocked>>
2556+
2557+
// The following 2 stores should use mma layout
2558+
// CHECK-NOT: ttg.convert_layout
2559+
// CHECK-DAG: tt.store [[PTR1]], {{.*}} : !tt.ptr<tensor<64x64xf32, #[[$DPAS]]>>
2560+
// CHECK-DAG: tt.store [[ADV_PTR2]], {{.*}} : !tt.ptr<tensor<64x64xf32, #[[$DPAS]]>>
2561+
%5 = ttg.convert_layout %cst_0 : tensor<64x64xf32, #mma> -> tensor<64x64xf32, #blocked>
2562+
tt.store %3, %5 : !tt.ptr<tensor<64x64xf32, #blocked>>
2563+
tt.store %4, %5 : !tt.ptr<tensor<64x64xf32, #blocked>>
2564+
tt.return
2565+
}
2566+
}

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 27 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -694,20 +694,22 @@ void LayoutPropagation::rewriteAssertOp(AssertOp assertOp) {
694694

695695
// Recursively update the operands in a chain of AdvanceOps, after setting the
696696
// pointer operand of the first one.
697-
static void updateAdvanceOpChain(AdvanceOp advanceOp, Value makeTensorPtrOp,
698-
Value dataToStore) {
697+
static void updateAdvanceOpChain(AdvanceOp advanceOp, StoreOp storeOp,
698+
Value makeTensorPtrOp, Value dataToStore) {
699699
OpBuilder rewriter(advanceOp);
700700
auto newAdvanceOp =
701701
rewriter.create<AdvanceOp>(advanceOp.getLoc(), makeTensorPtrOp.getType(),
702702
makeTensorPtrOp, advanceOp.getOffsets());
703703

704704
SmallVector<Operation *> advanceOpUsers(advanceOp->getUsers());
705705
for (Operation *user : advanceOpUsers) {
706-
if (auto storeOp = dyn_cast<StoreOp>(user)) {
707-
storeOp.setOperand(0, newAdvanceOp);
708-
storeOp.setOperand(1, dataToStore);
706+
if (auto storeUser = dyn_cast<StoreOp>(user)) {
707+
if (storeUser == storeOp) {
708+
storeOp.setOperand(0, newAdvanceOp);
709+
storeOp.setOperand(1, dataToStore);
710+
}
709711
} else if (auto advanceOp = dyn_cast<AdvanceOp>(user)) {
710-
updateAdvanceOpChain(advanceOp, makeTensorPtrOp, dataToStore);
712+
updateAdvanceOpChain(advanceOp, storeOp, makeTensorPtrOp, dataToStore);
711713
} else {
712714
llvm::errs() << "user: " << *user << "\n";
713715
llvm_unreachable("Unexpected user");
@@ -794,14 +796,27 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) {
794796
// Update the store operation with the new layout.
795797
SmallVector<Operation *> makeTensorPtrOpUsers(makeTensorPtrOp->getUsers());
796798
Value dataToStore = getValueAs(value, encoding);
797-
Block *storeBB = storeOp->getBlock();
798799
for (Operation *user : makeTensorPtrOpUsers) {
799-
Block *userBB = user->getBlock();
800-
if (auto storeOp = dyn_cast<StoreOp>(user)) {
801-
storeOp.setOperand(0, newMakeTensorPtrOp);
802-
storeOp.setOperand(1, dataToStore);
800+
if (auto storeUser = dyn_cast<StoreOp>(user)) {
801+
if (storeUser == storeOp) {
802+
storeOp.setOperand(0, newMakeTensorPtrOp);
803+
storeOp.setOperand(1, dataToStore);
804+
}
803805
} else if (auto advanceOp = dyn_cast<AdvanceOp>(user)) {
804-
updateAdvanceOpChain(advanceOp, newMakeTensorPtrOp, dataToStore);
806+
auto chainIsTerminatedByCurrentStore = [&](AdvanceOp advanceOp) {
807+
AdvanceOp currentAdvOp = advanceOp;
808+
for (Operation *user : currentAdvOp->getUsers()) {
809+
if (isa<StoreOp>(user) && cast<StoreOp>(user) == storeOp)
810+
return true;
811+
if (isa<AdvanceOp>(user))
812+
currentAdvOp = cast<AdvanceOp>(user);
813+
}
814+
return false;
815+
};
816+
817+
if (chainIsTerminatedByCurrentStore(advanceOp))
818+
updateAdvanceOpChain(advanceOp, storeOp, newMakeTensorPtrOp,
819+
dataToStore);
805820
}
806821
}
807822

0 commit comments

Comments
 (0)