Skip to content

Commit 3070646

Browse files
committed
[AMD] lifted upstream constraints in MemDescSubview::verify
1 parent 252d6bd commit 3070646

File tree

1 file changed

+36
-22
lines changed

1 file changed

+36
-22
lines changed

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,7 @@ LogicalResult MemDescSubviewOp::verify() {
660660
// There are two cases:
661661
// 1. The subview is rank-reducing
662662
// - We split along the first dimension. It can be with non-constant offsets
663+
// (TODO: ajust)
663664
if (srcTy.getRank() != dstTy.getRank()) {
664665
if (srcTy.getRank() - dstTy.getRank() != 1) {
665666
return emitError(
@@ -671,10 +672,6 @@ LogicalResult MemDescSubviewOp::verify() {
671672
return emitError("only constant values are allowed outside the front "
672673
"dimension in a rank-reducing subview");
673674
}
674-
if (!value.isZero()) {
675-
return emitError(
676-
"only first offset can be non-zero for a rank-reducing subview");
677-
}
678675
}
679676
return success();
680677
}
@@ -684,14 +681,13 @@ LogicalResult MemDescSubviewOp::verify() {
684681
// - The values where the split happens must not be within the swizzling
685682
// pattern
686683
// Check which dimension we are splitting along
687-
int dim = -1;
684+
// int dim = -1;
685+
686+
// TODO: discuss with upstream folks
687+
SetVector<int> slicedDims{};
688688
for (int i = 0; i < srcTy.getRank(); i++) {
689689
if (srcTy.getDimSize(i) != dstTy.getDimSize(i)) {
690-
if (dim != -1) {
691-
return emitError(
692-
"We don't allow subviews that split along multiple dimensions");
693-
}
694-
dim = i;
690+
slicedDims.insert(i);
695691
}
696692
}
697693
SmallVector<int64_t> offsets;
@@ -702,10 +698,24 @@ LogicalResult MemDescSubviewOp::verify() {
702698
offsets.push_back(value.getSExtValue());
703699
}
704700
// Identity subview
705-
if (dim == -1) {
701+
if (slicedDims.empty()) {
706702
return success();
707703
}
708704

705+
for (auto [dim, offset] : llvm::enumerate(offsets)) {
706+
if (!slicedDims.contains(dim)) {
707+
if (offset != 0) {
708+
return emitError("A non zero offset found in a dimension that is "
709+
"not being split");
710+
}
711+
} else {
712+
if (offset & (dstTy.getDimSize(dim) - 1)) {
713+
return emitError("The split offset may not touch the tile");
714+
}
715+
}
716+
}
717+
718+
/*
709719
for (auto [i, offset] : llvm::enumerate(offsets)) {
710720
if (i != dim) {
711721
if (offset != 0) {
@@ -718,22 +728,26 @@ LogicalResult MemDescSubviewOp::verify() {
718728
}
719729
}
720730
}
731+
*/
721732
auto ctx = getContext();
722733
// The order gives us the honest-to-goodness layout rank
723734
auto srcAllocShape = srcTy.getAllocShape().take_back(getOrder(srcTy).size());
724735
auto llInv =
725736
triton::gpu::toLinearLayout(srcAllocShape, srcTy.getEncoding()).invert();
726-
auto kDim = mlir::StringAttr::get(ctx, "dim" + llvm::Twine(dim));
727-
llvm::SmallVector<std::pair<mlir::StringAttr, int32_t>> namedOffsets;
728-
for (auto d : standardOutDimNames(ctx, srcTy.getRank())) {
729-
namedOffsets.push_back({d, 0});
730-
}
731-
for (int dimSize = dstTy.getDimSize(dim); dimSize < srcTy.getDimSize(dim);
732-
dimSize *= 2) {
733-
namedOffsets[dim] = {kDim, dimSize};
734-
if (!llvm::isPowerOf2_32(llInv.apply(namedOffsets)[0].second)) {
735-
return emitError(
736-
"We don't support splitting along the swizzling pattern");
737+
738+
for (auto dim : slicedDims) {
739+
auto kDim = mlir::StringAttr::get(ctx, "dim" + llvm::Twine(dim));
740+
llvm::SmallVector<std::pair<mlir::StringAttr, int32_t>> namedOffsets;
741+
for (auto d : standardOutDimNames(ctx, srcTy.getRank())) {
742+
namedOffsets.push_back({d, 0});
743+
}
744+
for (int dimSize = dstTy.getDimSize(dim); dimSize < srcTy.getDimSize(dim);
745+
dimSize *= 2) {
746+
namedOffsets[dim] = {kDim, dimSize};
747+
if (!llvm::isPowerOf2_32(llInv.apply(namedOffsets)[0].second)) {
748+
return emitError(
749+
"We don't support splitting along the swizzling pattern");
750+
}
737751
}
738752
}
739753
return success();

0 commit comments

Comments
 (0)