Skip to content

Commit 0341d75

Browse files
committed
[AMD] lifted upstream constraints in MemDescSubview::verify
1 parent 252d6bd commit 0341d75

File tree

1 file changed

+24
-24
lines changed

1 file changed

+24
-24
lines changed

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 24 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -660,6 +660,7 @@ LogicalResult MemDescSubviewOp::verify() {
660660
// There are two cases:
661661
// 1. The subview is rank-reducing
662662
// - We split along the first dimension. It can be with non-constant offsets
663+
// (TODO: adjust)
663664
if (srcTy.getRank() != dstTy.getRank()) {
664665
if (srcTy.getRank() - dstTy.getRank() != 1) {
665666
return emitError(
@@ -671,10 +672,6 @@ LogicalResult MemDescSubviewOp::verify() {
671672
return emitError("only constant values are allowed outside the front "
672673
"dimension in a rank-reducing subview");
673674
}
674-
if (!value.isZero()) {
675-
return emitError(
676-
"only first offset can be non-zero for a rank-reducing subview");
677-
}
678675
}
679676
return success();
680677
}
@@ -684,14 +681,13 @@ LogicalResult MemDescSubviewOp::verify() {
684681
// - The values where the split happens must not be within the swizzling
685682
// pattern
686683
// Check which dimension we are splitting along
687-
int dim = -1;
684+
// int dim = -1;
685+
686+
// TODO: discuss with upstream folks
687+
SetVector<int> slicedDims{};
688688
for (int i = 0; i < srcTy.getRank(); i++) {
689689
if (srcTy.getDimSize(i) != dstTy.getDimSize(i)) {
690-
if (dim != -1) {
691-
return emitError(
692-
"We don't allow subviews that split along multiple dimensions");
693-
}
694-
dim = i;
690+
slicedDims.insert(i);
695691
}
696692
}
697693
SmallVector<int64_t> offsets;
@@ -702,12 +698,12 @@ LogicalResult MemDescSubviewOp::verify() {
702698
offsets.push_back(value.getSExtValue());
703699
}
704700
// Identity subview
705-
if (dim == -1) {
701+
if (slicedDims.empty()) {
706702
return success();
707703
}
708704

709-
for (auto [i, offset] : llvm::enumerate(offsets)) {
710-
if (i != dim) {
705+
for (auto [dim, offset] : llvm::enumerate(offsets)) {
706+
if (!slicedDims.contains(dim)) {
711707
if (offset != 0) {
712708
return emitError("A non zero offset found in a dimension that is "
713709
"not being split");
@@ -718,22 +714,26 @@ LogicalResult MemDescSubviewOp::verify() {
718714
}
719715
}
720716
}
717+
721718
auto ctx = getContext();
722719
// The order gives us the honest-to-goodness layout rank
723720
auto srcAllocShape = srcTy.getAllocShape().take_back(getOrder(srcTy).size());
724721
auto llInv =
725722
triton::gpu::toLinearLayout(srcAllocShape, srcTy.getEncoding()).invert();
726-
auto kDim = mlir::StringAttr::get(ctx, "dim" + llvm::Twine(dim));
727-
llvm::SmallVector<std::pair<mlir::StringAttr, int32_t>> namedOffsets;
728-
for (auto d : standardOutDimNames(ctx, srcTy.getRank())) {
729-
namedOffsets.push_back({d, 0});
730-
}
731-
for (int dimSize = dstTy.getDimSize(dim); dimSize < srcTy.getDimSize(dim);
732-
dimSize *= 2) {
733-
namedOffsets[dim] = {kDim, dimSize};
734-
if (!llvm::isPowerOf2_32(llInv.apply(namedOffsets)[0].second)) {
735-
return emitError(
736-
"We don't support splitting along the swizzling pattern");
723+
724+
for (auto dim : slicedDims) {
725+
auto kDim = mlir::StringAttr::get(ctx, "dim" + llvm::Twine(dim));
726+
llvm::SmallVector<std::pair<mlir::StringAttr, int32_t>> namedOffsets;
727+
for (auto d : standardOutDimNames(ctx, srcTy.getRank())) {
728+
namedOffsets.push_back({d, 0});
729+
}
730+
for (int dimSize = dstTy.getDimSize(dim); dimSize < srcTy.getDimSize(dim);
731+
dimSize *= 2) {
732+
namedOffsets[dim] = {kDim, dimSize};
733+
if (!llvm::isPowerOf2_32(llInv.apply(namedOffsets)[0].second)) {
734+
return emitError(
735+
"We don't support splitting along the swizzling pattern");
736+
}
737737
}
738738
}
739739
return success();

0 commit comments

Comments
 (0)