Skip to content

Commit 056ad7f

Browse files
authored
[BACKEND] Error when using CGA>1 on memdesc_subview (#7288)
Before we were simply discarding the block during the lowering.
1 parent 16fa1a8 commit 056ad7f

File tree

3 files changed

+26
-5
lines changed

3 files changed

+26
-5
lines changed

lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -504,9 +504,13 @@ struct MemDescSubviewOpConversion
504504
// The order gives us the honest-to-goodness layout rank
505505
auto srcAllocShape =
506506
srcTy.getAllocShape().take_back(getOrder(srcTy).size());
507-
auto llInv = toLinearLayout(srcAllocShape, srcTy.getEncoding()).invert();
508-
offset =
509-
applyLinearLayout(loc, rewriter, llInv, logicalOffsets)[0].second;
507+
auto ll = toLinearLayout(srcAllocShape, srcTy.getEncoding());
508+
// Checked in the verifier.
509+
assert(ll.getInDimSize(str_attr("block")) == 1);
510+
auto kOffset = str_attr("offset");
511+
ll = ll.reshapeIns({{kOffset, ll.getTotalInDimSize()}});
512+
offset = applyLinearLayout(loc, rewriter, ll.invert(), logicalOffsets)[0]
513+
.second;
510514
}
511515

512516
auto base = smemObj.getBase();

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -721,8 +721,14 @@ LogicalResult MemDescSubviewOp::verify() {
721721
auto ctx = getContext();
722722
// The order gives us the honest-to-goodness layout rank
723723
auto srcAllocShape = srcTy.getAllocShape().take_back(getOrder(srcTy).size());
724-
auto llInv =
725-
triton::gpu::toLinearLayout(srcAllocShape, srcTy.getEncoding()).invert();
724+
auto ll = triton::gpu::toLinearLayout(srcAllocShape, srcTy.getEncoding());
725+
// NYI: We don't support non-trivial block dimension for now.
726+
auto kBlock = mlir::StringAttr::get(getContext(), "block");
727+
if (ll.getInDimSize(kBlock) != 1) {
728+
return emitError("non-trivial block dimension not supported");
729+
}
730+
731+
auto llInv = ll.invert();
726732
auto kDim = mlir::StringAttr::get(ctx, "dim" + llvm::Twine(dim));
727733
llvm::SmallVector<std::pair<mlir::StringAttr, int32_t>> namedOffsets;
728734
for (auto d : standardOutDimNames(ctx, srcTy.getRank())) {

test/TritonGPU/invalid.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
// RUN: triton-opt --split-input-file %s --verify-diagnostics
22

3+
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0], CTAsPerCGA = [2, 1], CTASplitNum = [1, 1], CTAOrder = [1, 0]}>
4+
#smem = #ttg.shared_memory
5+
tt.func public @non_trivial_block(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) {
6+
%zero = arith.constant 0 : i32
7+
// expected-error @+1 {{non-trivial block}}
8+
%a = ttg.memdesc_subview %arg0[%zero, %zero] : !ttg.memdesc<8x16xf32, #shared, #smem> -> !ttg.memdesc<8x8xf32, #shared, #smem>
9+
tt.return
10+
}
11+
12+
// -----
13+
314
#shared = #ttg.swizzled_shared<{vec = 8, perPhase = 1, maxPhase = 4, order = [0, 1]}>
415
#smem = #ttg.shared_memory
516
tt.func public @miss_encoding(%arg0: !ttg.memdesc<8x16xf32, #shared, #smem>) {

0 commit comments

Comments
 (0)