@@ -1140,13 +1140,36 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11401140 packOp, " not all outer dimensions of the result are 1s" );
11411141 }
11421142
1143+ ArrayRef<int64_t > innerDimsPos = packOp.getInnerDimsPos ();
1144+ auto outerDimsPerm = packOp.getOuterDimsPerm ();
1145+
1146+ // Verify that there are no non-unit un-tiled outer dims that are permuted.
1147+ // Supporting such cases will require refining the logic to generate the
1148+ // Transpose Op.
1149+ if (!llvm::all_of (outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
1150+ static int prev = 0 ;
1151+ // Tiled dims are not relevant here.
1152+ if (llvm::is_contained (innerDimsPos, dim))
1153+ return true ;
1154+ // Was this dim permuted? Note, permuting unit dims is fine.
1155+ if (dim < prev && (packOp.getType ().getShape ()[prev] != 1 ||
1156+ packOp.getType ().getShape ()[dim] != 1 ))
1157+ return false ;
1158+
1159+ prev = dim;
1160+ return true ;
1161+ })) {
1162+ return rewriter.notifyMatchFailure (
1163+ packOp, " At least one non-unit and un-tiled outer dim is permuted, "
1164+ " this is not supported ATM!" );
1165+ }
1166+
11431167 Attribute zeroIdxAttr = rewriter.getIndexAttr (0 );
11441168 Attribute oneIdxAttr = rewriter.getIndexAttr (1 );
11451169 Location loc = packOp.getLoc ();
11461170
11471171 int64_t srcRank = packOp.getSourceRank ();
11481172 int64_t destRank = packOp.getDestRank ();
1149- ArrayRef<int64_t > innerDimsPos = packOp.getInnerDimsPos ();
11501173
11511174 // 1. Get the input that is going to be packed. If the input requires padding,
11521175 // add a padding operation and return that as the input.
@@ -1185,8 +1208,10 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11851208 ShapedType inputTy = cast<ShapedType>(input.getType ());
11861209 SmallVector<OpFoldResult> shapeForEmptyOp;
11871210 for (int64_t i = 0 ; i < srcRank; i++) {
1188- if (llvm::is_contained (innerDimsPos, i))
1211+ if (llvm::is_contained (innerDimsPos, i)) {
1212+ // The tiled dims are appended after this loop.
11891213 continue ;
1214+ }
11901215 if (inputTy.isStaticDim (i))
11911216 shapeForEmptyOp.push_back (rewriter.getIndexAttr (inputTy.getShape ()[i]));
11921217 else
@@ -1231,15 +1256,15 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12311256 }
12321257
12331258 for (auto tileSize : packOp.getMixedTiles ()) {
1234- auto [tileSizeStatic , tileSizeOfr] =
1259+ auto [_ , tileSizeOfr] =
12351260 getSimplifiedOfrAndStaticSizePair (tileSize, rewriter);
12361261 writeSizes.push_back (tileSizeOfr);
12371262 }
12381263
12391264 SmallVector<OpFoldResult> writeStrides (destRank, oneIdxAttr);
12401265 SmallVector<OpFoldResult> writeOffsets (destRank, zeroIdxAttr);
12411266
1242- // TODO: A constructor that doesn't require strised nor offsets.
1267+ // TODO: A constructor that doesn't require strides nor offsets.
12431268 auto insert = tensor::InsertSliceOp::create (
12441269 rewriter, loc, transposedOp.getResult ()[0 ], packOp.getDest (),
12451270 writeOffsets, writeSizes, writeStrides);
0 commit comments