@@ -810,17 +810,47 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
810810 vector::UnrollVectorOptions options;
811811};
812812
813+ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
814+ using OpRewritePattern::OpRewritePattern;
815+
816+ LogicalResult matchAndRewrite (vector::ToElementsOp op,
817+ PatternRewriter &rewriter) const override {
818+
819+ TypedValue<VectorType> source = op.getSource ();
820+ FailureOr<SmallVector<Value>> result =
821+ vector::unrollVectorValue (source, rewriter);
822+ if (failed (result)) {
823+ return failure ();
824+ }
825+ SmallVector<Value> vectors = *result;
826+
827+ SmallVector<Value> results;
828+ for (const Value &vector : vectors) {
829+ auto subElements =
830+ vector::ToElementsOp::create (rewriter, op.getLoc (), vector);
831+ llvm::append_range (results, subElements.getResults ());
832+ }
833+ rewriter.replaceOp (op, results);
834+ return success ();
835+ }
836+ };
837+
813838} // namespace
814839
815840void mlir::vector::populateVectorUnrollPatterns (
816841 RewritePatternSet &patterns, const UnrollVectorOptions &options,
817842 PatternBenefit benefit) {
818- populateVectorToElementsLoweringPatterns (patterns);
819843 populateVectorFromElementsLoweringPatterns (patterns);
844+ patterns.add <UnrollToElements>(patterns.getContext (), benefit);
820845 patterns.add <UnrollTransferReadPattern, UnrollTransferWritePattern,
821846 UnrollContractionPattern, UnrollElementwisePattern,
822847 UnrollReductionPattern, UnrollMultiReductionPattern,
823848 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
824849 UnrollStorePattern, UnrollBroadcastPattern>(
825850 patterns.getContext (), options, benefit);
826851}
852+
853+ void mlir::vector::populateVectorToElementsUnrollPatterns (
854+ RewritePatternSet &patterns, PatternBenefit benefit) {
855+ patterns.add <UnrollToElements>(patterns.getContext (), benefit);
856+ }
0 commit comments