Commit e38a482

[BACKEND] Fix lowering of split op with linear layout (#6031)
1 parent c6fa27b commit e38a482

File tree: 2 files changed, +35 -7 lines changed

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
Lines changed: 14 additions & 7 deletions

@@ -173,22 +173,29 @@ struct SplitOpConversion : public ConvertOpToLLVMPattern<SplitOp> {
     // We rely on the following invariants of this op (which are checked by its
     // verifier):
     //
-    // - The op has a blocked encoding.
+    // - The layout distributes the last dimension along registers.
     // - The last dimension (the one we're splitting) has sizePerThread=2,
     //   threadPerWarp=1 and warpPerBlock=1.
     //
     // With these invariants, split is trivial: we can count how many contiguous
     // registers belong to the same chunk, then separate the registers between
     // the two chunks.
+    auto srcTy = cast<RankedTensorType>(op.getSrc().getType());
+    auto ll = toLinearLayout(srcTy.getShape(), srcTy.getEncoding());
+    int splitDim = srcTy.getRank() - 1;
+    auto kReg = mlir::StringAttr::get(srcTy.getContext(), "register");
+    const auto &bases = ll.getBases();
+    const auto &regs = bases.find(kReg)->second;
     int numContiguousValues = 1;
-    auto encoding = cast<BlockedEncodingAttr>(
-        cast<RankedTensorType>(op.getSrc().getType()).getEncoding());
-    int splitDim = encoding.getOrder().size() - 1;
-    for (int i = 0; i < encoding.getOrder().size(); i++) {
-      if (encoding.getOrder()[i] == splitDim)
+    bool found = false;
+    for (const auto &reg : regs) {
+      if (reg[splitDim] != 0) {
+        found = true;
         break;
-      numContiguousValues *= encoding.getSizePerThread()[i];
+      }
+      numContiguousValues *= 2;
     }
+    assert(found && "Split dimension is not distributed along registers.");
     Location loc = op->getLoc();
     auto typeConverter = getTypeConverter();
     SmallVector<Value> srcVals =
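
The new logic walks the layout's "register" bases: each basis vector is the tensor coordinate reached by flipping one register bit, so every basis that does not move along the split dimension doubles the length of a contiguous run, and the first basis that does move along it ends the run. Below is a minimal standalone C++ sketch of that walk (illustrative only, not Triton code; the basis vectors are copied from the #linear1 layout in the test further down, everything else is made up for the example):

#include <array>
#include <cassert>
#include <cstdio>
#include <vector>

int main() {
  // "register" bases of #linear1: one entry per register bit, each a
  // coordinate over the 4 dims of tensor<128x2x2x2xf32>.
  std::vector<std::array<int, 4>> regs = {
      {0, 0, 0, 1}, {0, 0, 1, 0}, {0, 1, 0, 0},
      {16, 0, 0, 0}, {32, 0, 0, 0}, {64, 0, 0, 0}};
  int splitDim = 3; // the last dimension, the one tt.split removes

  int numContiguousValues = 1;
  bool found = false;
  for (const auto &reg : regs) {
    if (reg[splitDim] != 0) { // this register bit moves along the split dim
      found = true;
      break;
    }
    numContiguousValues *= 2; // each earlier bit doubles the run length
  }
  assert(found && "split dim must be distributed along registers");

  // For #linear1 the very first basis, [0, 0, 0, 1], moves along the split
  // dim, so numContiguousValues == 1: registers alternate LHS, RHS, LHS, ...
  std::printf("numContiguousValues = %d\n", numContiguousValues);
  return 0;
}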

test/Conversion/tritongpu_to_llvm.mlir
Lines changed: 21 additions & 0 deletions

@@ -2238,3 +2238,24 @@ tt.func private @reshape_linear_layout_broadcasting(%arg0: tensor<32x4xbf16, #li
 }
 
 }
+
+
+// -----
+
+#linear1 = #ttg.linear<{register = [[0, 0, 0, 1], [0, 0, 1, 0], [0, 1, 0, 0], [16, 0, 0, 0], [32, 0, 0, 0], [64, 0, 0, 0]], lane = [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [1, 0, 0, 0], [2, 0, 0, 0]], warp = [[4, 0, 0, 0], [8, 0, 0, 0]], block = []}>
+#linear2 = #ttg.linear<{register = [[0, 0, 1], [0, 1, 0], [16, 0, 0], [32, 0, 0], [64, 0, 0]], lane = [[0, 0, 0], [0, 0, 0], [0, 0, 0], [1, 0, 0], [2, 0, 0]], warp = [[4, 0, 0], [8, 0, 0]], block = []}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} {
+  // CHECK-LABEL: split_linear
+  tt.func @split_linear(%arg : tensor<128x2x2x2xf32, #linear1>) {
+    // CHECK: %[[E0:.+]] = llvm.extractvalue %{{.*}}[0]
+    // CHECK: %[[E1:.+]] = llvm.extractvalue %{{.*}}[1]
+    // CHECK: %[[E2:.+]] = llvm.extractvalue %{{.*}}[2]
+    // CHECK: %[[E3:.+]] = llvm.extractvalue %{{.*}}[3]
+    // CHECK: llvm.insertvalue %[[E0]], %{{.*}}[0]
+    // CHECK: llvm.insertvalue %[[E2]], %{{.*}}[1]
+    // CHECK: llvm.insertvalue %[[E1]], %{{.*}}[0]
+    // CHECK: llvm.insertvalue %[[E3]], %{{.*}}[1]
+    %outLHS, %outRHS = tt.split %arg : tensor<128x2x2x2xf32, #linear1> -> tensor<128x2x2xf32, #linear2>
+    tt.return
+  }
+}
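
The CHECK lines encode the expected interleaving: with numContiguousValues == 1, source registers alternate between the two results, so %[[E0]] and %[[E2]] land in one output struct while %[[E1]] and %[[E3]] land in the other. The register partitioning itself happens after the diff shown above; the following sketch is an assumption about that surrounding lowering, not code from the commit, showing how runs of numContiguousValues registers would be dealt alternately to the LHS and RHS outputs:

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> srcVals = {0, 1, 2, 3}; // stand-ins for %[[E0]]..%[[E3]]
  int numContiguousValues = 1;             // as computed for #linear1

  std::vector<int> lhs, rhs;
  for (size_t i = 0; i < srcVals.size(); i += 2 * numContiguousValues) {
    for (int j = 0; j < numContiguousValues; ++j)
      lhs.push_back(srcVals[i + j]);                       // first run -> LHS
    for (int j = 0; j < numContiguousValues; ++j)
      rhs.push_back(srcVals[i + numContiguousValues + j]); // second run -> RHS
  }
  // lhs = {E0, E2}, rhs = {E1, E3}, matching the llvm.insertvalue order
  // checked in @split_linear above.
  std::printf("lhs = {%d, %d}, rhs = {%d, %d}\n", lhs[0], lhs[1], rhs[0], rhs[1]);
  return 0;
}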
