Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -345,3 +345,57 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}
tt.return
}
}

// -----

// CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
// CHECK: #[[DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
tt.func public @reduce_loop_carried_values(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg5: i32) {
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c0_i64 = arith.constant 0 : i64
%c32_i32 = arith.constant 32 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #dpas>
%18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
%22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
// COM: Ensure there are only 3 loop results and no layout conversion in the loop.
// CHECK: [[LOOP_RES:%.*]]:3 = scf.for
// CHECK-NOT: ttg.convert_layout
// CHECK: scf.yield
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
%30 = ttg.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
%31 = ttg.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
%32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
%33 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<64x32xf16, #blocked>>
%34 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked1>>
scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
}
%24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
%27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #dpas>>
tt.store %27, %24 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dpas>>
Comment on lines +381 to +383
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like these 3 lines can be removed?


// CHECK: [[LOAD1:%.*]] = tt.load [[LOOP_RES]]#1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>>
// CHECK: [[CONV1:%.*]] = ttg.convert_layout [[LOAD1]] : tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>> -> tensor<64x32xf16, #[[BLOCKED]]>
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #[[BLOCKED]]>>
// CHECK: tt.store [[PTR]], [[CONV1]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #[[BLOCKED]]>>
%28 = tt.load %23#1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
tt.store %29, %28 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>

// CHECK: [[LOAD2:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #[[BLOCKED]]>>
// CHECK: [[CONV2:%.*]] = ttg.convert_layout [[LOAD2]] : tensor<64x32xf16, #[[BLOCKED]]> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>
// CHECK: tt.store [[LOOP_RES]]#1, [[CONV2]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>>
%30 = tt.load %29 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
tt.store %23#1, %30 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>

tt.return
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/TypeSwitch.h"
#include <deque>

namespace mlir::triton::gpu::intel {
Expand Down Expand Up @@ -1267,6 +1268,77 @@ void LayoutRematerialization::backwardRematerialization() {
convertOp.getResult());
}
}

// Reduce loop-carried values if the value can be removed by using another
// loop-yielded value plus a convert layout operation.
for (auto [pair, val] : rematMapping) {
if (!isa<BlockArgument>(pair.first))
continue;

auto arg = cast<BlockArgument>(pair.first);
Comment on lines +1275 to +1278
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (!isa<BlockArgument>(pair.first))
continue;
auto arg = cast<BlockArgument>(pair.first);
auto arg = dyn_cast<BlockArgument>(pair.first);
if (!arg)
continue;

if (!isTensorPointerType(arg.getType()))
continue;

if (auto loopOp =
dyn_cast<LoopLikeOpInterface>(arg.getOwner()->getParentOp())) {
// Loop arguments that correspond to a loop result which is not used are
// not interesting.
OpResult loopRes = loopOp.getTiedLoopResult(arg);
if (loopRes.getNumUses() == 0)
continue;

// Replace the loop result corresponding to the argument with an
// equivalent loop result.
auto rematArg = cast<BlockArgument>(val);
OpResult rematRes = loopOp.getTiedLoopResult(rematArg);

for (OpOperand &use : loopRes.getUses()) {
Operation *user = use.getOwner();
Comment on lines +1295 to +1296
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Suggested change
for (OpOperand &use : loopRes.getUses()) {
Operation *user = use.getOwner();
for (Operation *user : loopRes.getUsers()) {

Location loc = user->getLoc();
OpBuilder rewriter(user);

TypeSwitch<Operation *>(user)
.Case<LoadOp>([&](auto loadOp) {
auto newLoadOp =
rewriter.create<LoadOp>(loc, rematRes, loadOp->getAttrs());
auto convOp = rewriter.create<ConvertLayoutOp>(
loc, loadOp.getType(), newLoadOp.getResult());
loadOp->replaceAllUsesWith(convOp);
LLVM_DEBUG({
DBGS() << "Replaced:\n\t" << *loadOp << "\n";
DBGS() << "with:\n\t" << *newLoadOp << "\n"
<< "\t" << *convOp << "\n";
});
})
.Case<StoreOp>([&](auto storeOp) {
Value data = storeOp.getOperand(1);
auto dataType = cast<RankedTensorType>(data.getType());
auto newPtrType = cast<PointerType>(rematRes.getType());
Attribute encoding =
cast<RankedTensorType>(newPtrType.getPointeeType())
.getEncoding();
RankedTensorType newDataType =
dataType.cloneWithEncoding(encoding);
auto convOp =
rewriter.create<ConvertLayoutOp>(loc, newDataType, data);
auto newStoreOp = rewriter.create<StoreOp>(
loc, rematRes, convOp, storeOp.getBoundaryCheck(),
storeOp.getCache(), storeOp.getEvict());
opToDelete.insert(storeOp);
LLVM_DEBUG({
DBGS() << "Replaced:\n\t" << *storeOp << "\n";
DBGS() << "with:\n\t" << *convOp << "\n"
<< "\t" << *newStoreOp << "\n";
});
})
.Default([](auto op) {
llvm::report_fatal_error(
llvm::Twine("Unsupported operation: '" +
op->getName().getStringRef() + "'"));
});
}
}
}
}

void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
Expand Down
Loading