Skip to content

Commit e1f9605

Browse files
committed
Copy over UnrollToElements
1 parent c891a27 commit e1f9605

File tree

2 files changed

+37
-1
lines changed

2 files changed

+37
-1
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,12 @@ void populateVectorUnrollPatterns(RewritePatternSet &patterns,
322322
const UnrollVectorOptions &options,
323323
PatternBenefit benefit = 1);
324324

325+
/// Populate the pattern set with the following patterns:
326+
///
327+
/// [UnrollToElements]
328+
void populateVectorToElementsUnrollPatterns(RewritePatternSet &patterns,
329+
PatternBenefit benefit = 1);
330+
325331
/// Collect a set of leading one dimension removal patterns.
326332
///
327333
/// These patterns insert vector.shape_cast to remove leading one dimensions

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -810,17 +810,47 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
810810
vector::UnrollVectorOptions options;
811811
};
812812

813+
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
814+
using OpRewritePattern::OpRewritePattern;
815+
816+
LogicalResult matchAndRewrite(vector::ToElementsOp op,
817+
PatternRewriter &rewriter) const override {
818+
819+
TypedValue<VectorType> source = op.getSource();
820+
FailureOr<SmallVector<Value>> result =
821+
vector::unrollVectorValue(source, rewriter);
822+
if (failed(result)) {
823+
return failure();
824+
}
825+
SmallVector<Value> vectors = *result;
826+
827+
SmallVector<Value> results;
828+
for (const Value &vector : vectors) {
829+
auto subElements =
830+
vector::ToElementsOp::create(rewriter, op.getLoc(), vector);
831+
llvm::append_range(results, subElements.getResults());
832+
}
833+
rewriter.replaceOp(op, results);
834+
return success();
835+
}
836+
};
837+
813838
} // namespace
814839

815840
void mlir::vector::populateVectorUnrollPatterns(
816841
RewritePatternSet &patterns, const UnrollVectorOptions &options,
817842
PatternBenefit benefit) {
818-
populateVectorToElementsLoweringPatterns(patterns);
819843
populateVectorFromElementsLoweringPatterns(patterns);
844+
patterns.add<UnrollToElements>(patterns.getContext(), benefit);
820845
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
821846
UnrollContractionPattern, UnrollElementwisePattern,
822847
UnrollReductionPattern, UnrollMultiReductionPattern,
823848
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
824849
UnrollStorePattern, UnrollBroadcastPattern>(
825850
patterns.getContext(), options, benefit);
826851
}
852+
853+
void mlir::vector::populateVectorToElementsUnrollPatterns(
854+
RewritePatternSet &patterns, PatternBenefit benefit) {
855+
patterns.add<UnrollToElements>(patterns.getContext(), benefit);
856+
}

0 commit comments

Comments
 (0)