@@ -1005,67 +1005,57 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
10051005
10061006static bool isContiguousExtract (ArrayRef<int64_t > targetShape,
10071007 ArrayRef<int64_t > resultShape) {
1008- if (targetShape.size () > resultShape.size ()) {
1008+ if (targetShape.size () > resultShape.size ())
10091009 return false ;
1010- }
10111010
10121011 size_t rankDiff = resultShape.size () - targetShape.size ();
10131012 // Inner dimensions must match exactly & total resultElements should be
10141013 // evenly divisible by targetElements.
1015- for (size_t i = 1 ; i < targetShape.size (); ++i) {
1016- if (targetShape[i] != resultShape[rankDiff + i]) {
1017- return false ;
1018- }
1019- }
1014+ if (!llvm::equal (targetShape.drop_front (),
1015+ resultShape.drop_front (rankDiff + 1 )))
1016+ return false ;
10201017
10211018 int64_t targetElements = ShapedType::getNumElements (targetShape);
10221019 int64_t resultElements = ShapedType::getNumElements (resultShape);
1023- if (resultElements % targetElements != 0 ) {
1024- return false ;
1025- }
1026- return true ;
1020+ return resultElements % targetElements == 0 ;
10271021}
10281022
1029- // Calculate the shape to extract from source
1023+ // Calculate the shape to extract from source.
10301024static std::optional<SmallVector<int64_t >>
10311025calculateSourceExtractShape (ArrayRef<int64_t > sourceShape,
10321026 int64_t targetElements) {
10331027 SmallVector<int64_t > extractShape;
10341028 int64_t remainingElements = targetElements;
10351029
1036- // Build extract shape from innermost dimension outward to ensure contiguity
1030+ // Build extract shape from innermost dimension outward to ensure contiguity.
10371031 for (int i = sourceShape.size () - 1 ; i >= 0 && remainingElements > 1 ; --i) {
10381032 int64_t takeFromDim = std::min (remainingElements, sourceShape[i]);
10391033 extractShape.insert (extractShape.begin (), takeFromDim);
10401034
1041- if (remainingElements % takeFromDim != 0 ) {
1042- return std::nullopt ; // Not evenly divisible
1043- }
1035+ if (remainingElements % takeFromDim != 0 )
1036+ return std::nullopt ; // Not evenly divisible.
10441037 remainingElements /= takeFromDim;
10451038 }
10461039
1047- // Fill remaining dimensions with 1
1048- while (extractShape.size () < sourceShape.size ()) {
1040+ // Fill remaining dimensions with 1.
1041+ while (extractShape.size () < sourceShape.size ())
10491042 extractShape.insert (extractShape.begin (), 1 );
1050- }
10511043
1052- if (ShapedType::getNumElements (extractShape) != targetElements) {
1044+ if (ShapedType::getNumElements (extractShape) != targetElements)
10531045 return std::nullopt ;
1054- }
10551046
10561047 return extractShape;
10571048}
10581049
1059- // Convert result offsets to source offsets via linear position
1050+ // Convert result offsets to source offsets via linear position.
10601051static SmallVector<int64_t >
10611052calculateSourceOffsets (ArrayRef<int64_t > resultOffsets,
10621053 ArrayRef<int64_t > sourceStrides,
10631054 ArrayRef<int64_t > resultStrides) {
1064- // Convert result offsets to linear position
1055+ // Convert result offsets to linear position.
10651056 int64_t linearIndex = linearize (resultOffsets, resultStrides);
1066- // Convert linear position to source offsets
1067- SmallVector<int64_t > sourceOffsets = delinearize (linearIndex, sourceStrides);
1068- return sourceOffsets;
1057+ // Convert linear position to source offsets.
1058+ return delinearize (linearIndex, sourceStrides);
10691059}
10701060
10711061// / This pattern unrolls `vector.shape_cast` operations according to the
@@ -1079,7 +1069,7 @@ calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
10791069// / remains a valid vector (and not decompose to scalars). In these cases, the
10801070// / unrolling proceeds as:
10811071// / vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
1082- // / vector.insert_strided_slice
1072+ // / vector.insert_strided_slice.
10831073// /
10841074// / Example:
10851075// / Given a shape cast operation:
@@ -1108,7 +1098,8 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
11081098
11091099 LogicalResult matchAndRewrite (vector::ShapeCastOp shapeCastOp,
11101100 PatternRewriter &rewriter) const override {
1111- auto targetShape = getTargetShape (options, shapeCastOp);
1101+ std::optional<SmallVector<int64_t >> targetShape =
1102+ getTargetShape (options, shapeCastOp);
11121103 if (!targetShape)
11131104 return failure ();
11141105
@@ -1117,26 +1108,24 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
11171108 ArrayRef<int64_t > sourceShape = sourceType.getShape ();
11181109 ArrayRef<int64_t > resultShape = resultType.getShape ();
11191110
1120- if (!isContiguousExtract (*targetShape, resultShape)) {
1111+ if (!isContiguousExtract (*targetShape, resultShape))
11211112 return rewriter.notifyMatchFailure (shapeCastOp,
11221113 " Only supports cases where contiguous "
11231114 " extraction is possible" );
1124- }
11251115
11261116 int64_t targetElements = ShapedType::getNumElements (*targetShape);
11271117
1128- // Calculate the shape to extract from source
1129- auto extractShape =
1118+ // Calculate the shape to extract from source.
1119+ std::optional<SmallVector< int64_t >> extractShape =
11301120 calculateSourceExtractShape (sourceShape, targetElements);
1131- if (!extractShape) {
1121+ if (!extractShape)
11321122 return rewriter.notifyMatchFailure (
11331123 shapeCastOp,
11341124 " cannot extract target number of elements contiguously from source" );
1135- }
11361125
11371126 Location loc = shapeCastOp.getLoc ();
11381127
1139- // Create result vector initialized to zero
1128+ // Create result vector initialized to zero.
11401129 Value result = arith::ConstantOp::create (rewriter, loc, resultType,
11411130 rewriter.getZeroAttr (resultType));
11421131
0 commit comments