Skip to content

Commit 73512fd

Browse files
committed
Address feedback
1 parent cd8b818 commit 73512fd

File tree

1 file changed

+24
-35
lines changed

1 file changed

+24
-35
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 24 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,67 +1005,57 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
10051005

10061006
static bool isContiguousExtract(ArrayRef<int64_t> targetShape,
10071007
ArrayRef<int64_t> resultShape) {
1008-
if (targetShape.size() > resultShape.size()) {
1008+
if (targetShape.size() > resultShape.size())
10091009
return false;
1010-
}
10111010

10121011
size_t rankDiff = resultShape.size() - targetShape.size();
10131012
// Inner dimensions must match exactly & total resultElements should be
10141013
// evenly divisible by targetElements.
1015-
for (size_t i = 1; i < targetShape.size(); ++i) {
1016-
if (targetShape[i] != resultShape[rankDiff + i]) {
1017-
return false;
1018-
}
1019-
}
1014+
if (!llvm::equal(targetShape.drop_front(),
1015+
resultShape.drop_front(rankDiff + 1)))
1016+
return false;
10201017

10211018
int64_t targetElements = ShapedType::getNumElements(targetShape);
10221019
int64_t resultElements = ShapedType::getNumElements(resultShape);
1023-
if (resultElements % targetElements != 0) {
1024-
return false;
1025-
}
1026-
return true;
1020+
return resultElements % targetElements == 0;
10271021
}
10281022

1029-
// Calculate the shape to extract from source
1023+
// Calculate the shape to extract from source.
10301024
static std::optional<SmallVector<int64_t>>
10311025
calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
10321026
int64_t targetElements) {
10331027
SmallVector<int64_t> extractShape;
10341028
int64_t remainingElements = targetElements;
10351029

1036-
// Build extract shape from innermost dimension outward to ensure contiguity
1030+
// Build extract shape from innermost dimension outward to ensure contiguity.
10371031
for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
10381032
int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
10391033
extractShape.insert(extractShape.begin(), takeFromDim);
10401034

1041-
if (remainingElements % takeFromDim != 0) {
1042-
return std::nullopt; // Not evenly divisible
1043-
}
1035+
if (remainingElements % takeFromDim != 0)
1036+
return std::nullopt; // Not evenly divisible.
10441037
remainingElements /= takeFromDim;
10451038
}
10461039

1047-
// Fill remaining dimensions with 1
1048-
while (extractShape.size() < sourceShape.size()) {
1040+
// Fill remaining dimensions with 1.
1041+
while (extractShape.size() < sourceShape.size())
10491042
extractShape.insert(extractShape.begin(), 1);
1050-
}
10511043

1052-
if (ShapedType::getNumElements(extractShape) != targetElements) {
1044+
if (ShapedType::getNumElements(extractShape) != targetElements)
10531045
return std::nullopt;
1054-
}
10551046

10561047
return extractShape;
10571048
}
10581049

1059-
// Convert result offsets to source offsets via linear position
1050+
// Convert result offsets to source offsets via linear position.
10601051
static SmallVector<int64_t>
10611052
calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
10621053
ArrayRef<int64_t> sourceStrides,
10631054
ArrayRef<int64_t> resultStrides) {
1064-
// Convert result offsets to linear position
1055+
// Convert result offsets to linear position.
10651056
int64_t linearIndex = linearize(resultOffsets, resultStrides);
1066-
// Convert linear position to source offsets
1067-
SmallVector<int64_t> sourceOffsets = delinearize(linearIndex, sourceStrides);
1068-
return sourceOffsets;
1057+
// Convert linear position to source offsets.
1058+
return delinearize(linearIndex, sourceStrides);
10691059
}
10701060

10711061
/// This pattern unrolls `vector.shape_cast` operations according to the
@@ -1079,7 +1069,7 @@ calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
10791069
/// remains a valid vector (and not decompose to scalars). In these cases, the
10801070
/// unrolling proceeds as:
10811071
/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
1082-
/// vector.insert_strided_slice
1072+
/// vector.insert_strided_slice.
10831073
///
10841074
/// Example:
10851075
/// Given a shape cast operation:
@@ -1108,7 +1098,8 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
11081098

11091099
LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
11101100
PatternRewriter &rewriter) const override {
1111-
auto targetShape = getTargetShape(options, shapeCastOp);
1101+
std::optional<SmallVector<int64_t>> targetShape =
1102+
getTargetShape(options, shapeCastOp);
11121103
if (!targetShape)
11131104
return failure();
11141105

@@ -1117,26 +1108,24 @@ struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
11171108
ArrayRef<int64_t> sourceShape = sourceType.getShape();
11181109
ArrayRef<int64_t> resultShape = resultType.getShape();
11191110

1120-
if (!isContiguousExtract(*targetShape, resultShape)) {
1111+
if (!isContiguousExtract(*targetShape, resultShape))
11211112
return rewriter.notifyMatchFailure(shapeCastOp,
11221113
"Only supports cases where contiguous "
11231114
"extraction is possible");
1124-
}
11251115

11261116
int64_t targetElements = ShapedType::getNumElements(*targetShape);
11271117

1128-
// Calculate the shape to extract from source
1129-
auto extractShape =
1118+
// Calculate the shape to extract from source.
1119+
std::optional<SmallVector<int64_t>> extractShape =
11301120
calculateSourceExtractShape(sourceShape, targetElements);
1131-
if (!extractShape) {
1121+
if (!extractShape)
11321122
return rewriter.notifyMatchFailure(
11331123
shapeCastOp,
11341124
"cannot extract target number of elements contiguously from source");
1135-
}
11361125

11371126
Location loc = shapeCastOp.getLoc();
11381127

1139-
// Create result vector initialized to zero
1128+
// Create result vector initialized to zero.
11401129
Value result = arith::ConstantOp::create(rewriter, loc, resultType,
11411130
rewriter.getZeroAttr(resultType));
11421131

0 commit comments

Comments
 (0)