Skip to content

Commit 5c1a19d

Browse files
committed
Spell out the type for getTargetShape in the file
1 parent 521aec0 commit 5c1a19d

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ struct UnrollTransferReadPattern
161161
return failure();
162162
if (readOp.getMask())
163163
return failure();
164-
auto targetShape = getTargetShape(options, readOp);
164+
std::optional<SmallVector<int64_t>> targetShape =
165+
getTargetShape(options, readOp);
165166
if (!targetShape)
166167
return failure();
167168
auto sourceVectorType = readOp.getVectorType();
@@ -216,7 +217,8 @@ struct UnrollTransferWritePattern
216217

217218
if (writeOp.getMask())
218219
return failure();
219-
auto targetShape = getTargetShape(options, writeOp);
220+
std::optional<SmallVector<int64_t>> targetShape =
221+
getTargetShape(options, writeOp);
220222
if (!targetShape)
221223
return failure();
222224
auto sourceVectorType = writeOp.getVectorType();
@@ -287,7 +289,8 @@ struct UnrollContractionPattern
287289

288290
LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
289291
PatternRewriter &rewriter) const override {
290-
auto targetShape = getTargetShape(options, contractOp);
292+
std::optional<SmallVector<int64_t>> targetShape =
293+
getTargetShape(options, contractOp);
291294
if (!targetShape)
292295
return failure();
293296
auto dstVecType = cast<VectorType>(contractOp.getResultType());
@@ -462,7 +465,8 @@ struct UnrollElementwisePattern : public RewritePattern {
462465
PatternRewriter &rewriter) const override {
463466
if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
464467
return failure();
465-
auto targetShape = getTargetShape(options, op);
468+
std::optional<SmallVector<int64_t>> targetShape =
469+
getTargetShape(options, op);
466470
if (!targetShape)
467471
return failure();
468472
int64_t targetShapeRank = targetShape->size();
@@ -590,7 +594,8 @@ struct UnrollTransposePattern : public OpRewritePattern<vector::TransposeOp> {
590594
PatternRewriter &rewriter) const override {
591595
if (transposeOp.getResultVectorType().getRank() == 0)
592596
return failure();
593-
auto targetShape = getTargetShape(options, transposeOp);
597+
std::optional<SmallVector<int64_t>> targetShape =
598+
getTargetShape(options, transposeOp);
594599
if (!targetShape)
595600
return failure();
596601
auto originalVectorType = transposeOp.getResultVectorType();
@@ -643,7 +648,8 @@ struct UnrollGatherPattern : public OpRewritePattern<vector::GatherOp> {
643648
VectorType sourceVectorType = gatherOp.getVectorType();
644649
if (sourceVectorType.getRank() == 0)
645650
return failure();
646-
auto targetShape = getTargetShape(options, gatherOp);
651+
std::optional<SmallVector<int64_t>> targetShape =
652+
getTargetShape(options, gatherOp);
647653
if (!targetShape)
648654
return failure();
649655
SmallVector<int64_t> strides(targetShape->size(), 1);
@@ -697,7 +703,8 @@ struct UnrollLoadPattern : public OpRewritePattern<vector::LoadOp> {
697703
PatternRewriter &rewriter) const override {
698704
VectorType vecType = loadOp.getVectorType();
699705

700-
auto targetShape = getTargetShape(options, loadOp);
706+
std::optional<SmallVector<int64_t>> targetShape =
707+
getTargetShape(options, loadOp);
701708
if (!targetShape)
702709
return failure();
703710

@@ -741,7 +748,8 @@ struct UnrollStorePattern : public OpRewritePattern<vector::StoreOp> {
741748
PatternRewriter &rewriter) const override {
742749
VectorType vecType = storeOp.getVectorType();
743750

744-
auto targetShape = getTargetShape(options, storeOp);
751+
std::optional<SmallVector<int64_t>> targetShape =
752+
getTargetShape(options, storeOp);
745753
if (!targetShape)
746754
return failure();
747755

@@ -780,7 +788,8 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
780788

781789
LogicalResult matchAndRewrite(vector::BroadcastOp broadcastOp,
782790
PatternRewriter &rewriter) const override {
783-
auto targetShape = getTargetShape(options, broadcastOp);
791+
std::optional<SmallVector<int64_t>> targetShape =
792+
getTargetShape(options, broadcastOp);
784793
if (!targetShape)
785794
return failure();
786795

@@ -863,7 +872,8 @@ struct ToElementsToTargetShape final
863872

864873
LogicalResult matchAndRewrite(vector::ToElementsOp op,
865874
PatternRewriter &rewriter) const override {
866-
auto targetShape = getTargetShape(options, op);
875+
std::optional<SmallVector<int64_t>> targetShape =
876+
getTargetShape(options, op);
867877
if (!targetShape)
868878
return failure();
869879

0 commit comments

Comments
 (0)