@@ -1003,6 +1003,195 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
10031003 vector::UnrollVectorOptions options;
10041004};
10051005
1006+ // / Checks whether extractShape is a contiguous slice of shape.
1007+ // / For extractShape to be contiguous in shape:
1008+ // / 1) All but the leading dimension of extractShape and shape must match
1009+ // / exactly. 2) The total number of elements in shape must be evenly divisible
1010+ // / by
1011+ // / the total number of elements in extractShape.
1012+ // / Examples:
1013+ // / isContiguous([4, 4], [8, 4]) == true
1014+ // / isContiguous([2, 4], [8, 4]) == true
1015+ // / isContiguous([2, 2], [8, 4]) == false
1016+ // / Removes leading unit dimensions to handle cases like:
1017+ // / isContiguous([1, 16], [1, 32]) == true
1018+ static bool isContiguous (ArrayRef<int64_t > extractShape,
1019+ ArrayRef<int64_t > shape) {
1020+
1021+ if (extractShape.size () > shape.size ())
1022+ return false ;
1023+
1024+ while (!extractShape.empty () && extractShape.front () == 1 ) {
1025+ extractShape = extractShape.drop_front ();
1026+ }
1027+
1028+ while (!shape.empty () && shape.front () == 1 ) {
1029+ shape = shape.drop_front ();
1030+ }
1031+
1032+ size_t rankDiff = shape.size () - extractShape.size ();
1033+ if (!llvm::equal (extractShape.drop_front (), shape.drop_front (rankDiff + 1 )))
1034+ return false ;
1035+
1036+ int64_t extractElements = ShapedType::getNumElements (extractShape);
1037+ int64_t shapeElements = ShapedType::getNumElements (shape);
1038+ return shapeElements % extractElements == 0 ;
1039+ }
1040+
1041+ // / Determines what shape to use with `vector.extract_strided_slice` to extract
1042+ // / a contiguous memory region from a source vector. The extraction must be
1043+ // / contiguous and contain exactly the specified number of elements. If such an
1044+ // / extraction shape cannot be determined, returns std::nullopt.
1045+ // / EXAMPLE 1:
1046+ // / sourceShape = [16], targetElements = 8
1047+ // / Working right-to-left:
1048+ // / - Take min(8, 16) = 8 from only dim → extractShape = [8],
1049+ // / remaining = 8/8 = 1
1050+ // / Result: [8]
1051+ // /
1052+ // / EXAMPLE 2:
1053+ // / sourceShape = [4, 4], targetElements = 8
1054+ // / Working right-to-left:
1055+ // / - Take min(8, 4) = 4 from last dim → extractShape = [4],
1056+ // / remaining = 8/4 = 2
1057+ // / - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
1058+ // / remaining = 2/2 = 1
1059+ // / Result: [2, 4]
1060+ static std::optional<SmallVector<int64_t >>
1061+ calculateSourceExtractShape (ArrayRef<int64_t > sourceShape,
1062+ int64_t targetElements) {
1063+ SmallVector<int64_t > extractShape;
1064+ int64_t remainingElements = targetElements;
1065+
1066+ // Build extract shape from innermost dimension outward to ensure contiguity.
1067+ for (int i = sourceShape.size () - 1 ; i >= 0 && remainingElements > 1 ; --i) {
1068+ int64_t takeFromDim = std::min (remainingElements, sourceShape[i]);
1069+ extractShape.insert (extractShape.begin (), takeFromDim);
1070+
1071+ if (remainingElements % takeFromDim != 0 )
1072+ return std::nullopt ; // Not evenly divisible.
1073+ remainingElements /= takeFromDim;
1074+ }
1075+
1076+ // Fill remaining dimensions with 1.
1077+ while (extractShape.size () < sourceShape.size ())
1078+ extractShape.insert (extractShape.begin (), 1 );
1079+
1080+ if (ShapedType::getNumElements (extractShape) != targetElements)
1081+ return std::nullopt ;
1082+
1083+ return extractShape;
1084+ }
1085+
1086+ // Convert result offsets to source offsets via linear position.
1087+ static SmallVector<int64_t >
1088+ calculateSourceOffsets (ArrayRef<int64_t > resultOffsets,
1089+ ArrayRef<int64_t > sourceShape,
1090+ ArrayRef<int64_t > resultShape) {
1091+ // Convert result offsets to linear position.
1092+ int64_t linearIndex = linearize (resultOffsets, computeStrides (resultShape));
1093+ // Convert linear position to source offsets.
1094+ return delinearize (linearIndex, computeStrides (sourceShape));
1095+ }
1096+
1097+ // / This pattern unrolls `vector.shape_cast` operations according to the
1098+ // / provided target unroll shape. It unrolls a large shape cast into smaller
1099+ // / shape casts by extracting contiguous slices from the source vector, casting
1100+ // / each slice to the target shape, and assembling the result by inserting each
1101+ // / computed segment into the appropriate offset of the result vector.
1102+ // /
1103+ // / This pattern only applies when contiguous slices can be extracted from the
1104+ // / source vector and inserted into the result vector such that each slice
1105+ // / remains a valid vector (and not decompose to scalars). In these cases, the
1106+ // / unrolling proceeds as:
1107+ // / vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
1108+ // / vector.insert_strided_slice.
1109+ // /
1110+ // / Example:
1111+ // / Given a shape cast operation:
1112+ // / %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
1113+ // /
1114+ // / and a target unroll shape of <2x4>, the pattern produces:
1115+ // /
1116+ // / %zero = arith.constant dense<0.0> : vector<4x4xf32>
1117+ // / %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
1118+ // / : vector<8x2xf32> to vector<4x2xf32>
1119+ // / %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
1120+ // / %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
1121+ // / : vector<2x4xf32> into vector<4x4xf32>
1122+ // / %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
1123+ // / : vector<8x2xf32> to vector<4x2xf32>
1124+ // / %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
1125+ // / %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
1126+ // / : vector<2x4xf32> into vector<4x4xf32>
1127+ // /
1128+ struct UnrollShapeCastPattern : public OpRewritePattern <vector::ShapeCastOp> {
1129+ UnrollShapeCastPattern (MLIRContext *context,
1130+ const vector::UnrollVectorOptions &options,
1131+ PatternBenefit benefit = 1 )
1132+ : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
1133+ options (options) {}
1134+
1135+ LogicalResult matchAndRewrite (vector::ShapeCastOp shapeCastOp,
1136+ PatternRewriter &rewriter) const override {
1137+ std::optional<SmallVector<int64_t >> targetShape =
1138+ getTargetShape (options, shapeCastOp);
1139+ if (!targetShape)
1140+ return failure ();
1141+
1142+ VectorType sourceType = shapeCastOp.getSourceVectorType ();
1143+ VectorType resultType = shapeCastOp.getResultVectorType ();
1144+ ArrayRef<int64_t > sourceShape = sourceType.getShape ();
1145+ ArrayRef<int64_t > resultShape = resultType.getShape ();
1146+
1147+ if (!isContiguous (*targetShape, resultShape))
1148+ return rewriter.notifyMatchFailure (
1149+ shapeCastOp, " Only supports cases where target shape is "
1150+ " contiguous in result vector shape" );
1151+
1152+ int64_t targetElements = ShapedType::getNumElements (*targetShape);
1153+
1154+ // Calculate the shape to extract from source.
1155+ std::optional<SmallVector<int64_t >> extractShape =
1156+ calculateSourceExtractShape (sourceShape, targetElements);
1157+ if (!extractShape)
1158+ return rewriter.notifyMatchFailure (
1159+ shapeCastOp,
1160+ " cannot extract target number of elements contiguously from source" );
1161+
1162+ Location loc = shapeCastOp.getLoc ();
1163+
1164+ // Create result vector initialized to zero.
1165+ Value result = arith::ConstantOp::create (rewriter, loc, resultType,
1166+ rewriter.getZeroAttr (resultType));
1167+
1168+ VectorType targetType =
1169+ VectorType::get (*targetShape, sourceType.getElementType ());
1170+
1171+ SmallVector<int64_t > extractStrides (extractShape->size (), 1 );
1172+ SmallVector<int64_t > insertStrides (targetShape->size (), 1 );
1173+
1174+ for (SmallVector<int64_t > resultOffsets :
1175+ StaticTileOffsetRange (resultShape, *targetShape)) {
1176+ SmallVector<int64_t > sourceOffsets =
1177+ calculateSourceOffsets (resultOffsets, sourceShape, resultShape);
1178+ Value sourceChunk = rewriter.createOrFold <vector::ExtractStridedSliceOp>(
1179+ loc, shapeCastOp.getSource (), sourceOffsets, *extractShape,
1180+ extractStrides);
1181+ Value targetChunk = rewriter.createOrFold <vector::ShapeCastOp>(
1182+ loc, targetType, sourceChunk);
1183+ result = rewriter.createOrFold <vector::InsertStridedSliceOp>(
1184+ loc, targetChunk, result, resultOffsets, insertStrides);
1185+ }
1186+
1187+ rewriter.replaceOp (shapeCastOp, result);
1188+ return success ();
1189+ }
1190+
1191+ private:
1192+ vector::UnrollVectorOptions options;
1193+ };
1194+
10061195} // namespace
10071196
10081197void mlir::vector::populateVectorUnrollPatterns (
@@ -1013,8 +1202,8 @@ void mlir::vector::populateVectorUnrollPatterns(
10131202 UnrollReductionPattern, UnrollMultiReductionPattern,
10141203 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
10151204 UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
1016- UnrollToElements, UnrollStepPattern>(patterns. getContext (),
1017- options, benefit);
1205+ UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
1206+ patterns. getContext (), options, benefit);
10181207}
10191208
10201209void mlir::vector::populateVectorToElementsUnrollPatterns (
0 commit comments