@@ -1003,6 +1003,172 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
   vector::UnrollVectorOptions options;
 };

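+// Returns true if a tile of `targetShape` can be taken as a contiguous slice
+// of a vector of `resultShape` (for example, target <2x4> is a contiguous
+// slice of result <4x4>, while target <4x2> is not).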
+static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
+                                ArrayRef<int64_t> resultShape) {
+  if (targetShape.size() > resultShape.size()) {
+    return false;
+  }
+
+  size_t rankDiff = resultShape.size() - targetShape.size();
+  // Inner dimensions must match exactly, and the total number of result
+  // elements must be evenly divisible by the number of target elements.
+  for (size_t i = 1; i < targetShape.size(); ++i) {
+    if (targetShape[i] != resultShape[rankDiff + i]) {
+      return false;
+    }
+  }
+
+  int64_t targetElements = ShapedType::getNumElements(targetShape);
+  int64_t resultElements = ShapedType::getNumElements(resultShape);
+  if (resultElements % targetElements != 0) {
+    return false;
+  }
+  return true;
+}
+
+// Calculate the shape to extract from source
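+// (for example, with sourceShape = [8, 2] and targetElements = 8, as in the
+// 8x2 -> 4x4 case documented below, this yields an extract shape of [4, 2]).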
+static std::optional<SmallVector<int64_t>>
+calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
+                            int64_t targetElements) {
+  SmallVector<int64_t> extractShape;
+  int64_t remainingElements = targetElements;
+
+  // Build extract shape from innermost dimension outward to ensure contiguity.
+  for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
+    int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
+    extractShape.insert(extractShape.begin(), takeFromDim);
+
+    if (remainingElements % takeFromDim != 0) {
+      return std::nullopt; // Not evenly divisible.
+    }
+    remainingElements /= takeFromDim;
+  }
+
+  // Fill remaining dimensions with 1.
+  while (extractShape.size() < sourceShape.size()) {
+    extractShape.insert(extractShape.begin(), 1);
+  }
+
+  if (ShapedType::getNumElements(extractShape) != targetElements) {
+    return std::nullopt;
+  }
+
+  return extractShape;
+}
+
+// Convert result offsets to source offsets via linear position
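+// (for example, in the 8x2 -> 4x4 case documented below, resultOffsets [2, 0]
+// linearize with resultStrides [4, 1] to index 8, which delinearizes with
+// sourceStrides [2, 1] to sourceOffsets [4, 0]).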
+static SmallVector<int64_t>
+calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
+                       ArrayRef<int64_t> sourceStrides,
+                       ArrayRef<int64_t> resultStrides) {
+  // Convert result offsets to linear position.
+  int64_t linearIndex = linearize(resultOffsets, resultStrides);
+  // Convert linear position to source offsets.
+  SmallVector<int64_t> sourceOffsets = delinearize(linearIndex, sourceStrides);
+  return sourceOffsets;
+}
+
+/// This pattern unrolls `vector.shape_cast` operations according to the
+/// provided target unroll shape. It unrolls a large shape cast into smaller
+/// shape casts by extracting contiguous slices from the source vector, casting
+/// each slice to the target shape, and assembling the result by inserting each
+/// computed segment into the appropriate offset of the result vector.
+///
+/// This pattern only applies when contiguous slices can be extracted from the
+/// source vector and inserted into the result vector such that each slice
+/// remains a valid vector (rather than decomposing to scalars). In these
+/// cases, the unrolling proceeds as:
+///   vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
+///   vector.insert_strided_slice
+///
+/// Example:
+/// Given a shape cast operation:
+///   %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
+///
+/// and a target unroll shape of <2x4>, the pattern produces:
+///
+///   %zero = arith.constant dense<0.0> : vector<4x4xf32>
+///   %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
+///       : vector<8x2xf32> to vector<4x2xf32>
+///   %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
+///   %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
+///       : vector<2x4xf32> into vector<4x4xf32>
+///   %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
+///       : vector<8x2xf32> to vector<4x2xf32>
+///   %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
+///   %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
+///       : vector<2x4xf32> into vector<4x4xf32>
+///
+struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
+  UnrollShapeCastPattern(MLIRContext *context,
+                         const vector::UnrollVectorOptions &options,
+                         PatternBenefit benefit = 1)
+      : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
+        options(options) {}
+
+  LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+                                PatternRewriter &rewriter) const override {
+    auto targetShape = getTargetShape(options, shapeCastOp);
+    if (!targetShape)
+      return failure();
+
+    VectorType sourceType = shapeCastOp.getSourceVectorType();
+    VectorType resultType = shapeCastOp.getResultVectorType();
+    ArrayRef<int64_t> sourceShape = sourceType.getShape();
+    ArrayRef<int64_t> resultShape = resultType.getShape();
+
+    if (!isContiguousExtract(*targetShape, resultShape)) {
+      return rewriter.notifyMatchFailure(shapeCastOp,
+                                         "Only supports cases where contiguous "
+                                         "extraction is possible");
+    }
+
+    int64_t targetElements = ShapedType::getNumElements(*targetShape);
+
+    // Calculate the shape to extract from source.
+    auto extractShape =
+        calculateSourceExtractShape(sourceShape, targetElements);
+    if (!extractShape) {
+      return rewriter.notifyMatchFailure(
+          shapeCastOp,
+          "cannot extract target number of elements contiguously from source");
+    }
+
+    Location loc = shapeCastOp.getLoc();
+
+    // Create result vector initialized to zero.
+    Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+                                             rewriter.getZeroAttr(resultType));
+
+    VectorType targetType =
+        VectorType::get(*targetShape, sourceType.getElementType());
+
+    SmallVector<int64_t> extractStrides(extractShape->size(), 1);
+    SmallVector<int64_t> insertStrides(targetShape->size(), 1);
+    SmallVector<int64_t> sourceStrides = computeStrides(sourceShape);
+    SmallVector<int64_t> resultStrides = computeStrides(resultShape);
+
+    for (SmallVector<int64_t> resultOffsets :
+         StaticTileOffsetRange(resultShape, *targetShape)) {
+      SmallVector<int64_t> sourceOffsets =
+          calculateSourceOffsets(resultOffsets, sourceStrides, resultStrides);
+      Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+          loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
+          extractStrides);
+      Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
+          loc, targetType, sourceChunk);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, targetChunk, result, resultOffsets, insertStrides);
+    }
+
+    rewriter.replaceOp(shapeCastOp, result);
+    return success();
+  }
+
+private:
+  vector::UnrollVectorOptions options;
+};
+
 } // namespace

 void mlir::vector::populateVectorUnrollPatterns(
@@ -1013,8 +1179,8 @@ void mlir::vector::populateVectorUnrollPatterns(
                UnrollReductionPattern, UnrollMultiReductionPattern,
                UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
                UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
-               UnrollToElements, UnrollStepPattern>(patterns.getContext(),
-                                                    options, benefit);
+               UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
+      patterns.getContext(), options, benefit);
 }

 void mlir::vector::populateVectorToElementsUnrollPatterns(
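For reference, a minimal sketch of how the newly registered pattern could be driven, not taken from this patch: the helper name unrollShapeCasts and the 2x4 tile size are invented for illustration, the UnrollVectorOptions setters (setNativeShape, setFilterConstraint) are the existing ones declared in VectorRewritePatterns.h, and the greedy driver entry point is assumed to be applyPatternsGreedily (older trees spell it applyPatternsAndFoldGreedily).

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Unroll every vector.shape_cast under `root` to 2x4 tiles; the filter
// constraint keeps the other unroll patterns from firing on unrelated ops.
static LogicalResult unrollShapeCasts(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::UnrollVectorOptions options;
  options.setNativeShape({2, 4});
  options.setFilterConstraint([](Operation *op) {
    return success(isa<vector::ShapeCastOp>(op));
  });
  vector::populateVectorUnrollPatterns(patterns, options);
  return applyPatternsGreedily(root, std::move(patterns));
}

With this configuration, the 8x2 -> 4x4 shape_cast from the pattern's documentation would be rewritten into the extract_strided_slice / shape_cast / insert_strided_slice sequence shown there.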