63 changes: 63 additions & 0 deletions test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir
@@ -345,3 +345,66 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}
tt.return
}
}

// -----

// CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
// CHECK: #[[BLOCKED1:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
// CHECK: #[[DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
tt.func public @reduce_loop_carried_values(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg5: i32) {
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c0_i64 = arith.constant 0 : i64
%c32_i32 = arith.constant 32 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #dpas>
%18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
%22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
// COM: Ensure there are only 3 loop results and no layout conversion in the loop.
// CHECK: [[LOOP_RES:%.*]]:3 = scf.for
// CHECK-NOT: ttg.convert_layout
// CHECK: scf.yield
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
%30 = ttg.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
%31 = ttg.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
%32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
%33 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<64x32xf16, #blocked>>
%34 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked1>>
scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
}
%24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
%27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #dpas>>
tt.store %27, %24 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dpas>>

// CHECK: [[LOAD1:%.*]] = tt.load [[LOOP_RES]]#1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>>
// CHECK: [[CONV1:%.*]] = ttg.convert_layout [[LOAD1]] : tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>> -> tensor<64x32xf16, #[[BLOCKED]]>
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #[[BLOCKED]]>>
// CHECK: tt.store [[PTR]], [[CONV1]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #[[BLOCKED]]>>
%28 = tt.load %23#1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
tt.store %29, %28 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>

// CHECK: [[LOAD2:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #[[BLOCKED]]>>
// CHECK: [[CONV2:%.*]] = ttg.convert_layout [[LOAD2]] : tensor<64x32xf16, #[[BLOCKED]]> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>
// CHECK: tt.store [[LOOP_RES]]#1, [[CONV2]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>>
%30 = tt.load %29 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
tt.store %23#1, %30 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>

// CHECK: [[ADV:%.*]] = tt.advance [[LOOP_RES]]#2, {{.*}} : <tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: [[LOAD3:%.*]] = tt.load [[ADV]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>>>
// CHECK: [[CONV3:%.*]] = ttg.convert_layout [[LOAD3]] : tensor<32x256xf16, #ttg.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<32x256xf16, #[[BLOCKED1]]>
// CHECK: tt.store {{.*}}, [[CONV3]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #[[BLOCKED1]]>
%31 = tt.advance %23#2, [%c0_i32, %c32_i32] : <tensor<32x256xf16, #blocked1>>
%32 = tt.load %31 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<32x256xf16, #blocked1>>
%33 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
tt.store %33, %32 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<32x256xf16, #blocked1>>
tt.return
}
}
@@ -25,6 +25,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/TypeSwitch.h"
#include <deque>

namespace mlir::triton::gpu::intel {
@@ -166,6 +167,7 @@ class LayoutRematerialization {

private:
void updateRematMapping(SmallVector<std::tuple<Value, Value>> &values);
void reduceLoopCarriedValues();
// Existing tuples of (value, layout) that need to be updated when recreating
// scf ops. This prevents keeping track of Values that have been deleted when
// rewriting slices.
@@ -1008,6 +1010,93 @@ void LayoutRematerialization::updateRematMapping(
}
}

/// Reduce the number of loop-carried values: a value used after the loop can be
/// removed when it is replaceable by another loop-yielded value plus a convert
/// layout operation.
void LayoutRematerialization::reduceLoopCarriedValues() {
for (auto [pair, val] : rematMapping) {
if (!isa<BlockArgument>(pair.first))
continue;

auto arg = cast<BlockArgument>(pair.first);
if (!isTensorPointerType(arg.getType()))
continue;

auto loopOp = dyn_cast<LoopLikeOpInterface>(arg.getOwner()->getParentOp());
if (!loopOp)
continue;

// Loop arguments that correspond to a loop result which is not used are
// not interesting.
OpResult loopRes = loopOp.getTiedLoopResult(arg);
if (loopRes.getNumUses() == 0)
continue;

std::function<void(Operation *, Value)> processUser = [&](Operation *user,
Value rematRes) {
Location loc = user->getLoc();
OpBuilder rewriter(user);

TypeSwitch<Operation *>(user)
.Case<LoadOp>([&](auto loadOp) {
auto newLoadOp =
rewriter.create<LoadOp>(loc, rematRes, loadOp->getAttrs());
auto convOp = rewriter.create<ConvertLayoutOp>(
loc, loadOp.getType(), newLoadOp.getResult());
loadOp->replaceAllUsesWith(convOp);
opToDelete.insert(loadOp);
LLVM_DEBUG({
DBGS() << "Replaced:\n\t" << *loadOp << "\n"
<< "with:\n\t" << *newLoadOp << "\n"
<< "\t" << *convOp << "\n";
});
})
.Case<StoreOp>([&](auto storeOp) {
Value data = storeOp.getOperand(1);
auto dataType = cast<RankedTensorType>(data.getType());
auto newPtrType = cast<PointerType>(rematRes.getType());
Attribute encoding =
cast<RankedTensorType>(newPtrType.getPointeeType())
.getEncoding();
RankedTensorType newDataType = dataType.cloneWithEncoding(encoding);
auto convOp =
rewriter.create<ConvertLayoutOp>(loc, newDataType, data);
auto newStoreOp = rewriter.create<StoreOp>(
loc, rematRes, convOp, storeOp.getBoundaryCheck(),
storeOp.getCache(), storeOp.getEvict());
opToDelete.insert(storeOp);
LLVM_DEBUG({
DBGS() << "Replaced:\n\t" << *storeOp << "\n"
<< "with:\n\t" << *convOp << "\n"
<< "\t" << *newStoreOp << "\n";
});
})
.Case<AdvanceOp>([&](auto advanceOp) {
auto newAdvanceOp = rewriter.create<AdvanceOp>(
loc, rematRes.getType(), rematRes, advanceOp.getOffsets());
opToDelete.insert(advanceOp);
LLVM_DEBUG({
DBGS() << "Replaced:\n\t" << *advanceOp << "\n"
<< "with:\n\t" << *newAdvanceOp << "\n";
});

for (Operation *user : advanceOp->getUsers())
processUser(user, newAdvanceOp.getResult());
})
.Default([](auto op) {
llvm::report_fatal_error(llvm::Twine(
"Unsupported operation in backward rematerialization: '" +
op->getName().getStringRef() + "'"));
Copilot AI Aug 20, 2025

The error message in the Default case uses string concatenation with Twine which may not work as expected. Consider using llvm::formatv or constructing the error message differently: llvm::formatv("Unsupported operation in backward rematerialization: '{0}'", op->getName().getStringRef())

Suggested change
op->getName().getStringRef() + "'"));
llvm::report_fatal_error(
llvm::formatv("Unsupported operation in backward rematerialization: '{0}'", op->getName().getStringRef()));

});
};

// Replace the loop result corresponding to the argument with an
// equivalent loop result.
OpResult rematRes = loopOp.getTiedLoopResult(cast<BlockArgument>(val));
Copilot AI Aug 20, 2025

The cast&lt;BlockArgument&gt;(val) is redundant and potentially unsafe. The variable val is already confirmed to be of type Value from the rematMapping, but the code already has arg, which is the properly cast BlockArgument from line 1020. Use arg instead of casting val.

Suggested change
OpResult rematRes = loopOp.getTiedLoopResult(cast<BlockArgument>(val));
OpResult rematRes = loopOp.getTiedLoopResult(arg);

for (Operation *user : loopRes.getUsers())
processUser(user, rematRes);
}
}

void LayoutRematerialization::rewriteSlice(SetVector<Value> &slice,
DenseMap<Value, Attribute> &layout,
ConvertLayoutOp convertOp,
@@ -1267,6 +1356,8 @@ void LayoutRematerialization::backwardRematerialization() {
convertOp.getResult());
}
}

reduceLoopCarriedValues();
}

void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {