Merged · Changes from 5 commits
test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir (54 additions, 0 deletions)
@@ -345,3 +345,57 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32}
tt.return
}
}

// -----

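// COM: Check that uses of a loop-carried tensor pointer outside the loop are
// COM: rewritten against the rematerialized loop result (which carries the DPAS
// COM: dot operand layout), with convert_layout operations bridging to the
// COM: original blocked layout at each use.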
// CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
// CHECK: #[[DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}>
#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} {
tt.func public @reduce_loop_carried_values(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f16>, %arg5: i32) {
%c1_i64 = arith.constant 1 : i64
%c0_i32 = arith.constant 0 : i32
%c0_i64 = arith.constant 0 : i64
%c32_i32 = arith.constant 32 : i32
%cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #dpas>
%18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
%22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x256xf16, #blocked1>>
%23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>) : i32 {
// COM: Ensure there are only 3 loop results and no layout conversion in the loop.
// CHECK: [[LOOP_RES:%.*]]:3 = scf.for
// CHECK-NOT: ttg.convert_layout
// CHECK: scf.yield
%28 = tt.load %arg11 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.load %arg12 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<32x256xf16, #blocked1>>
%30 = ttg.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0>
%31 = ttg.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1>
%32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas>
%33 = tt.advance %arg11, [%c0_i32, %c32_i32] : <tensor<64x32xf16, #blocked>>
%34 = tt.advance %arg12, [%c32_i32, %c0_i32] : <tensor<32x256xf16, #blocked1>>
scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr<tensor<64x32xf16, #blocked>>, !tt.ptr<tensor<32x256xf16, #blocked1>>
}
%24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas>
%27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x256xf16, #dpas>>
tt.store %27, %24 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x256xf16, #dpas>>

// CHECK: [[LOAD1:%.*]] = tt.load [[LOOP_RES]]#1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>>
// CHECK: [[CONV1:%.*]] = ttg.convert_layout [[LOAD1]] : tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>> -> tensor<64x32xf16, #[[BLOCKED]]>
// CHECK: [[PTR:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #[[BLOCKED]]>>
// CHECK: tt.store [[PTR]], [[CONV1]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #[[BLOCKED]]>>
%28 = tt.load %23#1 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
%29 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #blocked>>
tt.store %29, %28 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>

// CHECK: [[LOAD2:%.*]] = tt.load [[PTR]] {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #[[BLOCKED]]>>
// CHECK: [[CONV2:%.*]] = ttg.convert_layout [[LOAD2]] : tensor<64x32xf16, #[[BLOCKED]]> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>
// CHECK: tt.store [[LOOP_RES]]#1, [[CONV2]] {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>>>
%30 = tt.load %29 {boundaryCheck = array<i32: 0, 1>, ttig.block_io = "row_major" } : !tt.ptr<tensor<64x32xf16, #blocked>>
tt.store %23#1, %30 {boundaryCheck = array<i32: 0, 1>} : !tt.ptr<tensor<64x32xf16, #blocked>>

tt.return
}
}
@@ -25,6 +25,7 @@
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h"
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
#include "llvm/ADT/TypeSwitch.h"
#include <deque>

namespace mlir::triton::gpu::intel {
@@ -1267,6 +1268,78 @@ void LayoutRematerialization::backwardRematerialization() {
convertOp.getResult());
}
}

// Reduce the number of loop-carried values when a value can be removed by
// reusing another loop-yielded value plus a convert layout operation.
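// Concretely: when a loop block argument (a tensor pointer) has been
// rematerialized with a different layout, the loop result tied to the
// original argument becomes redundant. Each of its users is rewritten to
// use the result tied to the rematerialized argument instead, with a
// ConvertLayoutOp inserted so the user still sees its original layout.
// This is exercised by the reduce_loop_carried_values test in
// backward_combine_dpas_dot_layout.mlir.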
for (auto [pair, val] : rematMapping) {
if (!isa<BlockArgument>(pair.first))
continue;

auto arg = cast<BlockArgument>(pair.first);
if (!isTensorPointerType(arg.getType()))
continue;

if (auto loopOp =
dyn_cast<LoopLikeOpInterface>(arg.getOwner()->getParentOp())) {
// Loop arguments that correspond to an unused loop result are not
// interesting.
OpResult loopRes = loopOp.getTiedLoopResult(arg);
if (loopRes.getNumUses() == 0)
continue;

// Replace the loop result corresponding to the argument with an
// equivalent loop result.
auto rematArg = cast<BlockArgument>(val);
OpResult rematRes = loopOp.getTiedLoopResult(rematArg);

for (OpOperand &use : loopRes.getUses()) {
Operation *user = use.getOwner();
Location loc = user->getLoc();
OpBuilder rewriter(user);

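// Loads are redirected to read through the rematerialized pointer, and
// their result is converted back to the layout the original load produced;
// stores convert their data to the rematerialized layout and write through
// the rematerialized pointer. Any other user is unexpected and aborts.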
TypeSwitch<Operation *>(user)
.Case<LoadOp>([&](auto loadOp) {
auto newLoadOp =
rewriter.create<LoadOp>(loc, rematRes, loadOp->getAttrs());
auto convOp = rewriter.create<ConvertLayoutOp>(
loc, loadOp.getType(), newLoadOp.getResult());
loadOp->replaceAllUsesWith(convOp);
opToDelete.insert(loadOp);
LLVM_DEBUG({
DBGS() << "Replaced:\n\t" << *loadOp << "\n";
DBGS() << "with:\n\t" << *newLoadOp << "\n"
<< "\t" << *convOp << "\n";
});
})
.Case<StoreOp>([&](auto storeOp) {
Value data = storeOp.getOperand(1);
auto dataType = cast<RankedTensorType>(data.getType());
auto newPtrType = cast<PointerType>(rematRes.getType());
Attribute encoding =
cast<RankedTensorType>(newPtrType.getPointeeType())
.getEncoding();
RankedTensorType newDataType =
dataType.cloneWithEncoding(encoding);
auto convOp =
rewriter.create<ConvertLayoutOp>(loc, newDataType, data);
auto newStoreOp = rewriter.create<StoreOp>(
loc, rematRes, convOp, storeOp.getBoundaryCheck(),
storeOp.getCache(), storeOp.getEvict());
opToDelete.insert(storeOp);
LLVM_DEBUG({
DBGS() << "Replaced:\n\t" << *storeOp << "\n";
DBGS() << "with:\n\t" << *convOp << "\n"
<< "\t" << *newStoreOp << "\n";
});
})
.Default([](auto op) {
llvm::report_fatal_error(llvm::Twine(
"Unsupported operation in backward rematerialization: '" +
op->getName().getStringRef() + "'"));
});
}
}
}
}

void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {