Skip to content

Commit 8f8b91b

Browse files
authored
[RemoveLayoutConversions]: Add support for tt.store operation using block ptr updated by tt.advance operation (#4277)
This PR extends the layout-conversion pass so that "tt.store" operations are supported when their block pointer is produced by a chain of "tt.advance" operations rather than directly by "tt.make_tensor_ptr". --------- Signed-off-by: Tiotto, Ettore <[email protected]>
1 parent d9b6052 commit 8f8b91b

File tree

2 files changed

+99
-18
lines changed

2 files changed

+99
-18
lines changed

test/TritonIntelGPU/combine.mlir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2472,3 +2472,49 @@ module attributes {ttig.min_sg_size = 16 : i32, ttig.support_bf16_conversion, tt
24722472
tt.return
24732473
}
24742474
}
2475+
2476+
// -----

// COM: Test that the DPAS layout is propagated to the store operation in the presence of an advance operation updating its base pointer.
// CHECK-NOT: #ttg.blocked<{.*}>
// CHECK: #[[$DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
  // CHECK-LABEL: matmul_kernel_with_block_pointers
  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32, %arg8: i32) {
    %c1_i64 = arith.constant 1 : i64
    %c0_i32 = arith.constant 0 : i32
    %c0_i64 = arith.constant 0 : i64
    %c32_i32 = arith.constant 32 : i32
    %cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #blocked1>
    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[$DPAS]], kWidth = 1}>>>
    // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[$DPAS]], kWidth = 2}>>>
    %18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
    %22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
    %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #blocked1>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
      // CHECK-NOT: ttg.convert_layout
      %28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
      %29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
      %36 = ttg.convert_layout %arg10 : tensor<64x256xf32, #blocked1> -> tensor<64x256xf32, #dpas>
      %30 = ttg.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
      %31 = ttg.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
      %32 = tt.dot %30, %31, %36, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
      %33 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<64x32xf16, #blocked>>
      %34 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked1>>
      %35 = ttg.convert_layout %32 : tensor<64x256xf32, #dpas> -> tensor<64x256xf32, #blocked1>
      scf.yield %35, %33, %34 : tensor<64x256xf32, #blocked1>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
    }
    %24 = arith.truncf %23#0 : tensor<64x256xf32, #blocked1> to tensor<64x256xf16, #blocked1>
    // CHECK: [[PTR1:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #[[$DPAS]]>>
    // CHECK: [[PTR2:%.*]] = tt.advance [[PTR1]], {{.*}} : <tensor<64x256xf16, #[[$DPAS]]>>
    // CHECK: tt.store [[PTR2]], {{.*}} {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #[[$DPAS]]>>
    %27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #blocked1>>
    %newptr = tt.advance %27, [%c32_i32, %c32_i32] : <tensor<64x256xf16, #blocked1>>
    tt.store %newptr, %24 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #blocked1>>
    tt.return
  }
}

third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp

Lines changed: 53 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,29 @@ void LayoutPropagation::rewriteAssertOp(AssertOp assertOp) {
692692
assertOp->setOperand(0, newOperand);
693693
}
694694

695+
// Recursively update the operands in a chain of AdvanceOps, after setting the
696+
// pointer operand of the first one.
697+
static void updateAdvanceOpChain(AdvanceOp advanceOp, Value makeTensorPtrOp,
698+
Value dataToStore) {
699+
OpBuilder rewriter(advanceOp);
700+
auto newAdvanceOp =
701+
rewriter.create<AdvanceOp>(advanceOp.getLoc(), makeTensorPtrOp.getType(),
702+
makeTensorPtrOp, advanceOp.getOffsets());
703+
704+
SmallVector<Operation *> advanceOpUsers(advanceOp->getUsers());
705+
for (Operation *user : advanceOpUsers) {
706+
if (auto storeOp = dyn_cast<StoreOp>(user)) {
707+
storeOp.setOperand(0, newAdvanceOp);
708+
storeOp.setOperand(1, dataToStore);
709+
} else if (auto advanceOp = dyn_cast<AdvanceOp>(user)) {
710+
updateAdvanceOpChain(advanceOp, makeTensorPtrOp, dataToStore);
711+
} else {
712+
llvm::errs() << "user: " << *user << "\n";
713+
llvm_unreachable("Unexpected user");
714+
}
715+
}
716+
}
717+
695718
bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) {
696719
// Disable 2D block store on LTS.
697720
if (!storeOp->getParentOfType<ModuleOp>()->hasAttr(
@@ -705,13 +728,16 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) {
705728
if (!isTensorPointerType(ptr.getType()))
706729
return false;
707730

708-
// 2D block store are preceeded by a MakeTensorPtrOp
709-
auto makeTensorPtrOp = ptr.getDefiningOp<MakeTensorPtrOp>();
710-
if (!makeTensorPtrOp)
711-
return false;
731+
// Locate the operation that created the block pointer.
732+
Operation *defOp = ptr.getDefiningOp();
733+
while (auto advanceOp = dyn_cast<AdvanceOp>(defOp))
734+
defOp = advanceOp.getPtr().getDefiningOp();
735+
assert(isa<MakeTensorPtrOp>(defOp) &&
736+
"MakeTensorPtrOp should be the only op that creates a tensor pointer");
737+
auto makeTensorPtrOp = cast<MakeTensorPtrOp>(defOp);
712738

713-
// DPAS encoding have to be propagate if conversion from DPAS to
714-
// other has been done before.
739+
// DPAS encoding have to be propagated if conversion from a DPAS layout to
740+
// another layout has been done before.
715741
auto convertOp = storeOp.getValue().getDefiningOp<ConvertLayoutOp>();
716742
PointerType newPtrType;
717743
Attribute encoding;
@@ -758,21 +784,26 @@ bool LayoutPropagation::rewriteStoreOp(StoreOp storeOp) {
758784
encoding = convertOpSrcType.getEncoding();
759785
}
760786

761-
// We create a new MakeTensorPtrOp with the new data type.
787+
// Create a new MakeTensorPtrOp with the new layout.
762788
OpBuilder rewriter(makeTensorPtrOp);
763-
Value newStorePtr = rewriter.create<MakeTensorPtrOp>(
789+
Value newMakeTensorPtrOp = rewriter.create<MakeTensorPtrOp>(
764790
makeTensorPtrOp.getLoc(), newPtrType, makeTensorPtrOp.getBase(),
765791
makeTensorPtrOp.getShape(), makeTensorPtrOp.getStrides(),
766-
makeTensorPtrOp.getOffsets(), rewriter.getDenseI32ArrayAttr({1, 0}));
767-
768-
// The encoding of the StoreOp is updated with the new
769-
// operands:
770-
// - the Ptr created by the MakeTensorPtrOp with the new data
771-
// type
772-
// - the forwarded DPAS encoding.
773-
Value newOperand = getValueAs(value, encoding);
774-
storeOp.setOperand(0, newStorePtr);
775-
storeOp.setOperand(1, newOperand);
792+
makeTensorPtrOp.getOffsets(), makeTensorPtrOp.getOrderAttr());
793+
794+
// Update the store operation with the new layout.
795+
SmallVector<Operation *> makeTensorPtrOpUsers(makeTensorPtrOp->getUsers());
796+
Value dataToStore = getValueAs(value, encoding);
797+
Block *storeBB = storeOp->getBlock();
798+
for (Operation *user : makeTensorPtrOpUsers) {
799+
Block *userBB = user->getBlock();
800+
if (auto storeOp = dyn_cast<StoreOp>(user)) {
801+
storeOp.setOperand(0, newMakeTensorPtrOp);
802+
storeOp.setOperand(1, dataToStore);
803+
} else if (auto advanceOp = dyn_cast<AdvanceOp>(user)) {
804+
updateAdvanceOpChain(advanceOp, newMakeTensorPtrOp, dataToStore);
805+
}
806+
}
776807

777808
// If the DPAS encoding is forwarded, we do not need the
778809
// convertOp anymore if the convertOp was only used by the
@@ -1607,6 +1638,7 @@ class TritonIntelGPURemoveLayoutConversionsPass
16071638
LLVM_DEBUG({
16081639
DBGS() << "Module after propagating layouts forward:\n";
16091640
m.dump();
1641+
assert(succeeded(verify(m)) && "Module verification failed");
16101642
});
16111643

16121644
cleanupConvertOps();
@@ -1617,6 +1649,7 @@ class TritonIntelGPURemoveLayoutConversionsPass
16171649
LLVM_DEBUG({
16181650
DBGS() << "Module after backward remat:\n";
16191651
m.dump();
1652+
assert(succeeded(verify(m)) && "Module verification failed");
16201653
});
16211654

16221655
// Cleanup dummy converts created during backward remat.
@@ -1628,6 +1661,7 @@ class TritonIntelGPURemoveLayoutConversionsPass
16281661
LLVM_DEBUG({
16291662
DBGS() << "Module after hoisting converts:\n";
16301663
m.dump();
1664+
assert(succeeded(verify(m)) && "Module verification failed");
16311665
});
16321666

16331667
// 4. Apply clean up patterns to remove remove dead convert and dead code
@@ -1643,6 +1677,7 @@ class TritonIntelGPURemoveLayoutConversionsPass
16431677
LLVM_DEBUG({
16441678
DBGS() << "Module after final cleanups:\n";
16451679
m.dump();
1680+
assert(succeeded(verify(m)) && "Module verification failed");
16461681
});
16471682
}
16481683
};

0 commit comments

Comments
 (0)