@@ -1134,22 +1134,45 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,
11341134
11351135LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite (
11361136 linalg::PackOp packOp, PatternRewriter &rewriter) const {
1137- // TODO: support the case that outer dimensions are not all 1s. A
1138- // tensor.expand_shape will be generated in this case.
1139- if (llvm::any_of (packOp.getAllOuterDims (),
1137+ if (llvm::any_of (packOp.getTiledOuterDims (),
11401138 [](int64_t dim) { return dim != 1 ; })) {
11411139 return rewriter.notifyMatchFailure (
11421140 packOp, " not all outer dimensions of the result are 1s" );
11431141 }
11441142
1143+ ArrayRef<int64_t > innerDimsPos = packOp.getInnerDimsPos ();
1144+ auto outerDimsPerm = packOp.getOuterDimsPerm ();
1145+
1146+ // Verify that there are no outer dims that are:
1147+ // * both non-unit and un-tiled,
1148+ // and that are also permuted. Supporting such cases would require refining
1149+ // the logic that generates the Transpose Op.
1150+ if (!llvm::all_of (outerDimsPerm, [&innerDimsPos, &packOp](int64_t dim) {
1151+ static int prev = 0 ;
1152+ // Skip tiled dims - these can be permuted.
1153+ if (llvm::is_contained (innerDimsPos, dim))
1154+ return true ;
1155+
1156+ // Check whether this dim has been permuted. Permuting unit dims is fine
1157+ // as that's effectively a no-op.
1158+ if (dim < prev && (packOp.getType ().getShape ()[prev] != 1 ||
1159+ packOp.getType ().getShape ()[dim] != 1 ))
1160+ return false ;
1161+
1162+ prev = dim;
1163+ return true ;
1164+ })) {
1165+ return rewriter.notifyMatchFailure (
1166+ packOp, " At least one non-unit and un-tiled outer dim is permuted, "
1167+ " this is not supported ATM!" );
1168+ }
1169+
11451170 Attribute zeroIdxAttr = rewriter.getIndexAttr (0 );
11461171 Attribute oneIdxAttr = rewriter.getIndexAttr (1 );
11471172 Location loc = packOp.getLoc ();
11481173
11491174 int64_t srcRank = packOp.getSourceRank ();
11501175 int64_t destRank = packOp.getDestRank ();
1151- ArrayRef<int64_t > innerDimsPos = packOp.getInnerDimsPos ();
1152- int64_t numberOfTiles = innerDimsPos.size ();
11531176
11541177 // 1. Get the input that is going to be packed. If the input requires padding,
11551178 // add a padding operation and return that as the input.
@@ -1160,10 +1183,13 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11601183 // %transposed_tile = linalg.transpose ins(%source_or_padded_source),
11611184 // outs(%init)
11621185 // Assumptions made:
1163- // - All outer dims are 1 - the corresponding transposition order doesn't
1164- // matter, but requires all dim indices to be present.
1186+ // - All tiled outer dims are 1 - the corresponding transposition order
1187+ // doesn't matter, but requires all dim indices to be present.
1188+ // - Un-tiled outer dims remain un-permuted.
11651189
1166- // 2.1 Get the permutation for linalg.transpose
1190+ // 2.1 Get the permutation for linalg.transpose:
1191+ // [ untiled-dims, inner-dims-pos ]
1192+ // Note, this logic assumes that the untiled dims are not permuted.
11671193 SmallVector<int64_t > srcPermForTranspose;
11681194 for (int64_t i = 0 ; i < srcRank; i++) {
11691195 // We assume the `k` dimensions of the inner dim position, where `k` is the
@@ -1179,9 +1205,21 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
11791205 }
11801206 srcPermForTranspose.append (innerDimsPos.begin (), innerDimsPos.end ());
11811207
1182- // 2.2 Create the init tensor for linalg.transpose with the correct shape
1183- SmallVector<OpFoldResult> shapeForEmptyOp (srcRank - numberOfTiles,
1184- oneIdxAttr);
1208+ // 2.2 Create the init tensor for linalg.transpose with the correct shape:
1209+ // [ untiled-dims, tiled-dims ]
1210+ ShapedType inputTy = cast<ShapedType>(input.getType ());
1211+ SmallVector<OpFoldResult> shapeForEmptyOp;
1212+ for (int64_t i = 0 ; i < srcRank; i++) {
1213+ if (llvm::is_contained (innerDimsPos, i)) {
1214+ // The tiled dims are appended after this loop.
1215+ continue ;
1216+ }
1217+ if (inputTy.isStaticDim (i))
1218+ shapeForEmptyOp.push_back (rewriter.getIndexAttr (inputTy.getShape ()[i]));
1219+ else
1220+ shapeForEmptyOp.emplace_back (
1221+ tensor::DimOp::create (rewriter, loc, input, i).getResult ());
1222+ }
11851223 shapeForEmptyOp.append (packOp.getMixedTiles ());
11861224
11871225 // getMixedTiles() may contain Values pointing to constant ops, not the
@@ -1204,25 +1242,36 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
12041242 auto transposedOp = linalg::TransposeOp::create (rewriter, loc, input, empty,
12051243 srcPermForTranspose);
12061244
1207- // 3. Insert the inner tile to the destination:
1245+ // 3. Insert the inner tile into the destination tensor :
12081246 // %inserted_tile = tensor.insert_slice(%transposed_tile)
1209- SmallVector<OpFoldResult> writeStrides (destRank, oneIdxAttr);
1210- SmallVector<OpFoldResult> writeOffsets (destRank, zeroIdxAttr);
1211- // Outer dims are all 1s!
1212- SmallVector<OpFoldResult> writeSizes (destRank - numberOfTiles, oneIdxAttr);
1213- SmallVector<int64_t > writeShape;
1247+
1248+ // Compute the sizes attribute:
1249+ // [ outer-dims, tile-sizes ]
1250+ // Note that the output from the transpose Op excludes the tiled outer dims.
1251+ // However, given the assumption that:
1252+ // * all tiled outer dims == 1,
1253+ // we can just use a rank-expanding tensor.insert_slice.
1254+ SmallVector<OpFoldResult> writeSizes;
1255+ for (auto size : packOp.getAllOuterDims ()) {
1256+ writeSizes.push_back (rewriter.getIndexAttr (size));
1257+ }
12141258
12151259 for (auto tileSize : packOp.getMixedTiles ()) {
1216- auto [tileSizeStatic , tileSizeOfr] =
1260+ auto [_ , tileSizeOfr] =
12171261 getSimplifiedOfrAndStaticSizePair (tileSize, rewriter);
12181262 writeSizes.push_back (tileSizeOfr);
1219- writeShape.push_back (tileSizeStatic);
12201263 }
12211264
1222- // 4. Replace tensor.packOp with tensor.insert_slice created above
1265+ // TODO: Add a constructor for tensor.insert_slice that doesn't require
1266+ // strides nor offsets.
1267+ SmallVector<OpFoldResult> writeStrides (destRank, oneIdxAttr);
1268+ SmallVector<OpFoldResult> writeOffsets (destRank, zeroIdxAttr);
1269+
12231270 auto insert = tensor::InsertSliceOp::create (
12241271 rewriter, loc, transposedOp.getResult ()[0 ], packOp.getDest (),
12251272 writeOffsets, writeSizes, writeStrides);
1273+
1274+ // 4. Replace tensor.packOp with tensor.insert_slice created above
12261275 rewriter.replaceOp (packOp, insert.getResult ());
12271276
12281277 return success ();
0 commit comments