[RemoveLayoutConversions]: Reduce loop carried values - part 2 #4921


Draft: wants to merge 5 commits into main
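
This PR teaches the RemoveLayoutConversions pass to reduce the number of loop carried values: when a loop yields both a tensor pointer and a rematerialized counterpart in another layout, uses of the original result after the loop are rewritten against the rematerialized result, with ttg.convert_layout operations inserted at the uses. A rough before/after sketch (illustrative MLIR, abbreviated; not the literal pass output):

// Before: the #blocked pointer is carried through the loop alongside its
// rematerialized #dot0 twin, only because the #blocked result is used later.
%res:2 = scf.for ... iter_args(%p = %init, %p_remat = %init_remat)
    -> (!tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<64x32xf16, #dot0>>) {...}
%x = tt.load %res#0 : !tt.ptr<tensor<64x32xf16, #blocked>>

// After: the post-loop load uses the rematerialized result, and the value
// loaded in the #dot0 layout is converted back to #blocked; the original
// loop carried value becomes dead and is dropped when the loop is rewritten.
%res = scf.for ... iter_args(%p_remat = %init_remat)
    -> (!tt.ptr<tensor<64x32xf16, #dot0>>) {...}
%x0 = tt.load %res : !tt.ptr<tensor<64x32xf16, #dot0>>
%x = ttg.convert_layout %x0 : tensor<64x32xf16, #dot0> -> tensor<64x32xf16, #blocked>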
test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir (63 additions, 0 deletions)
@@ -345,3 +345,66 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}
tt.return
}
}

// -----

// CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
// CHECK: #[[BLOCKED1:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: #[[DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
tt.func public @reduce_loop_carried_values(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg5: i32) {
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c0_i64 = arith.constant 0 : i64
%c32_i32 = arith.constant 32 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #dpas>
%18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
%22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
// COM: Ensure there are only 3 loop results and no layout conversion in the loop.
// CHECK: [[LOOP_RES:%.*]]:3 = scf.for
// CHECK-NOT: ttg.convert_layout
// CHECK: scf.yield
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
%30 = ttg.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
%31 = ttg.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
%32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
%33 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<64x32xf16, #blocked>>
%34 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked1>>
scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
}
%24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
%27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #dpas>>
tt.store %27, %24 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dpas>>

// CHECK: [[LOAD1:%.*]] = tt.load [[LOOP_RES]]#1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>>
// CHECK: [[CONV1:%.*]] = ttg.convert_layout [[LOAD1]] : tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>> -> tensor<64x32xf16, #[[BLOCKED]]>
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #[[BLOCKED]]>>
// CHECK: tt.store [[PTR]], [[CONV1]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #[[BLOCKED]]>>
%28 = tt.load %23#1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
tt.store %29, %28 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>

// CHECK: [[LOAD2:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #[[BLOCKED]]>>
// CHECK: [[CONV2:%.*]] = ttg.convert_layout [[LOAD2]] : tensor<64x32xf16, #[[BLOCKED]]> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>
// CHECK: tt.store [[LOOP_RES]]#1, [[CONV2]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>>
%30 = tt.load %29 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
tt.store %23#1, %30 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>

// CHECK: [[ADV:%.*]] = tt.advance [[LOOP_RES]]#2, {{.*}} : <tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: [[LOAD3:%.*]] = tt.load [[ADV]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: [[CONV3:%.*]] = ttg.convert_layout [[LOAD3]] : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<32x256xf16, #[[BLOCKED1]]>
// CHECK: tt.store {{.*}}, [[CONV3]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #[[BLOCKED1]]>
%31 = tt.advance %23#2, [%c0_i32, %c32_i32] : <tensor<32x256xf16, #blocked1>>
%32 = tt.load %31 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<32x256xf16, #blocked1>>
%33 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
tt.store %33, %32 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #blocked1>>
tt.return
}
}
@@ -25,6 +25,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/TypeSwitch.h"
#include <deque>

namespace mlir::triton::gpu::intel {
@@ -166,6 +167,7 @@ class LayoutRematerialization {

private:
void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
void reduceLoopCarriedValues();
// Existing tuples of (value, layout) that need to be updated when recreating
// scf ops. This prevents keeping track of Values that have been deleted when
// rewriting slices.
@@ -1008,6 +1010,93 @@ void LayoutRematerialization::updateRematMapping(
}
}

/// Reduce the number of loop carried values: if a loop result is used after
/// the loop and the loop also yields a rematerialized counterpart, the uses
/// can be rewritten against that counterpart plus a layout conversion,
/// making the original loop carried value dead.
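///
/// Schematic example (result indices are illustrative): a post-loop use
///   %x = tt.load %res#1 : !tt.ptr<tensor<..., #blocked>>
/// becomes
///   %x0 = tt.load %res#3 : !tt.ptr<tensor<..., #dot>>
///   %x = ttg.convert_layout %x0 : tensor<..., #dot> -> tensor<..., #blocked>
/// leaving %res#1 unused.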
void LayoutRematerialization::reduceLoopCarriedValues() {
for (auto [pair, val] : rematMapping) {
if (!isa<BlockArgument>(pair.first))
continue;

auto arg = cast<BlockArgument>(pair.first);
if (!isTensorPointerType(arg.getType()))
continue;

auto loopOp = dyn_cast<LoopLikeOpInterface>(arg.getOwner()->getParentOp());
if (!loopOp)
continue;

// Loop arguments whose corresponding loop result is unused are not
// interesting.
OpResult loopRes = loopOp.getTiedLoopResult(arg);
if (loopRes.getNumUses() == 0)
continue;

std::function<void(Operation *, Value)> processUser = [&](Operation *user,
Value rematRes) {
Location loc = user->getLoc();
OpBuilder rewriter(user);

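// Dispatch on the kind of user: loads are recreated on the rematerialized
// pointer and their result converted back to the expected layout, stores
// convert the stored data into the rematerialized layout first, and
// advances are recreated and their users rewritten recursively.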
TypeSwitch<Operation *>(user)
.Case<LoadOp>([&](auto loadOp) {
auto newLoadOp =
rewriter.create<LoadOp>(loc, rematRes, loadOp->getAttrs());
auto convOp = rewriter.create<ConvertLayoutOp>(
loc, loadOp.getType(), newLoadOp.getResult());
loadOp->replaceAllUsesWith(convOp);
opToDelete.insert(loadOp);
LLVM_DEBUG({
DBGS() << "Replaced:\n\t" << *loadOp << "\n"
<< "with:\n\t" << *newLoadOp << "\n"
<< "\t" << *convOp << "\n";
});
})
.Case<StoreOp>([&](auto storeOp) {
Value data = storeOp.getOperand(1);
auto dataType = cast<RankedTensorType>(data.getType());
auto newPtrType = cast<PointerType>(rematRes.getType());
Attribute encoding =
cast<RankedTensorType>(newPtrType.getPointeeType())
.getEncoding();
RankedTensorType newDataType = dataType.cloneWithEncoding(encoding);
auto convOp =
rewriter.create<ConvertLayoutOp>(loc, newDataType, data);
auto newStoreOp = rewriter.create<StoreOp>(
loc, rematRes, convOp, storeOp.getBoundaryCheck(),
storeOp.getCache(), storeOp.getEvict());
opToDelete.insert(storeOp);
LLVM_DEBUG({
DBGS() << "Replaced:\n\t" << *storeOp << "\n"
<< "with:\n\t" << *convOp << "\n"
<< "\t" << *newStoreOp << "\n";
});
})
.Case<AdvanceOp>([&](auto advanceOp) {
auto newAdvanceOp = rewriter.create<AdvanceOp>(
loc, rematRes.getType(), rematRes, advanceOp.getOffsets());
opToDelete.insert(advanceOp);
LLVM_DEBUG({
DBGS() << "Replaced:\n\t" << *advanceOp << "\n"
<< "with:\n\t" << *newAdvanceOp << "\n";
});

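// The recreated advance yields a pointer in the rematerialized layout,
// so its users must be rewritten recursively as well.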
for (Operation *user : advanceOp->getUsers())
processUser(user, newAdvanceOp.getResult());
})
.Default([](auto op) {
llvm::report_fatal_error(llvm::Twine(
"Unsupported operation in backward rematerialization: '" +
op->getName().getStringRef() + "'"));
});
};

// Replace uses of the loop result tied to the original argument with the
// loop result tied to its rematerialized counterpart.
OpResult rematRes = loopOp.getTiedLoopResult(cast<BlockArgument>(val));
for (Operation *user : loopRes.getUsers())
processUser(user, rematRes);
}
}

void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp,
@@ -1267,6 +1356,8 @@ void LayoutRematerialization::backwardRematerialization() {
convertOp.getResult());
}
}

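// Rematerialization may leave loop results alive only because of post-loop
// users; rewrite those users so the redundant loop carried values die.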
reduceLoopCarriedValues();
}

void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {