@@ -660,6 +660,7 @@ LogicalResult MemDescSubviewOp::verify() {
660660 // There are two cases:
661661 // 1. The subview is rank-reducing
662662 // - We split along the first dimension. It can be with non-constant offsets
663+ // (TODO: ajust)
663664 if (srcTy.getRank () != dstTy.getRank ()) {
664665 if (srcTy.getRank () - dstTy.getRank () != 1 ) {
665666 return emitError (
@@ -671,10 +672,6 @@ LogicalResult MemDescSubviewOp::verify() {
671672 return emitError (" only constant values are allowed outside the front "
672673 " dimension in a rank-reducing subview" );
673674 }
674- if (!value.isZero ()) {
675- return emitError (
676- " only first offset can be non-zero for a rank-reducing subview" );
677- }
678675 }
679676 return success ();
680677 }
@@ -684,14 +681,13 @@ LogicalResult MemDescSubviewOp::verify() {
684681 // - The values where the split happens must not be within the swizzling
685682 // pattern
686683 // Check which dimension we are splitting along
687- int dim = -1 ;
684+ // int dim = -1;
685+
686+ // TODO: discuss with upstream folks
687+ SetVector<int > slicedDims{};
688688 for (int i = 0 ; i < srcTy.getRank (); i++) {
689689 if (srcTy.getDimSize (i) != dstTy.getDimSize (i)) {
690- if (dim != -1 ) {
691- return emitError (
692- " We don't allow subviews that split along multiple dimensions" );
693- }
694- dim = i;
690+ slicedDims.insert (i);
695691 }
696692 }
697693 SmallVector<int64_t > offsets;
@@ -702,10 +698,24 @@ LogicalResult MemDescSubviewOp::verify() {
702698 offsets.push_back (value.getSExtValue ());
703699 }
704700 // Identity subview
705- if (dim == - 1 ) {
701+ if (slicedDims. empty () ) {
706702 return success ();
707703 }
708704
705+ for (auto [dim, offset] : llvm::enumerate (offsets)) {
706+ if (!slicedDims.contains (dim)) {
707+ if (offset != 0 ) {
708+ return emitError (" A non zero offset found in a dimension that is "
709+ " not being split" );
710+ }
711+ } else {
712+ if (offset & (dstTy.getDimSize (dim) - 1 )) {
713+ return emitError (" The split offset may not touch the tile" );
714+ }
715+ }
716+ }
717+
718+ /*
709719 for (auto [i, offset] : llvm::enumerate(offsets)) {
710720 if (i != dim) {
711721 if (offset != 0) {
@@ -718,22 +728,26 @@ LogicalResult MemDescSubviewOp::verify() {
718728 }
719729 }
720730 }
731+ */
721732 auto ctx = getContext ();
722733 // The order gives us the honest-to-goodness layout rank
723734 auto srcAllocShape = srcTy.getAllocShape ().take_back (getOrder (srcTy).size ());
724735 auto llInv =
725736 triton::gpu::toLinearLayout (srcAllocShape, srcTy.getEncoding ()).invert ();
726- auto kDim = mlir::StringAttr::get (ctx, " dim" + llvm::Twine (dim));
727- llvm::SmallVector<std::pair<mlir::StringAttr, int32_t >> namedOffsets;
728- for (auto d : standardOutDimNames (ctx, srcTy.getRank ())) {
729- namedOffsets.push_back ({d, 0 });
730- }
731- for (int dimSize = dstTy.getDimSize (dim); dimSize < srcTy.getDimSize (dim);
732- dimSize *= 2 ) {
733- namedOffsets[dim] = {kDim , dimSize};
734- if (!llvm::isPowerOf2_32 (llInv.apply (namedOffsets)[0 ].second )) {
735- return emitError (
736- " We don't support splitting along the swizzling pattern" );
737+
738+ for (auto dim : slicedDims) {
739+ auto kDim = mlir::StringAttr::get (ctx, " dim" + llvm::Twine (dim));
740+ llvm::SmallVector<std::pair<mlir::StringAttr, int32_t >> namedOffsets;
741+ for (auto d : standardOutDimNames (ctx, srcTy.getRank ())) {
742+ namedOffsets.push_back ({d, 0 });
743+ }
744+ for (int dimSize = dstTy.getDimSize (dim); dimSize < srcTy.getDimSize (dim);
745+ dimSize *= 2 ) {
746+ namedOffsets[dim] = {kDim , dimSize};
747+ if (!llvm::isPowerOf2_32 (llInv.apply (namedOffsets)[0 ].second )) {
748+ return emitError (
749+ " We don't support splitting along the swizzling pattern" );
750+ }
737751 }
738752 }
739753 return success ();
0 commit comments