@@ -811,7 +811,11 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
811811};
812812
813813struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
814- using OpRewritePattern::OpRewritePattern;
814+ UnrollToElements (MLIRContext *context,
815+ const vector::UnrollVectorOptions &options,
816+ PatternBenefit benefit = 1 )
817+ : OpRewritePattern<vector::ToElementsOp>(context, benefit),
818+ options (options) {}
815819
816820 LogicalResult matchAndRewrite (vector::ToElementsOp op,
817821 PatternRewriter &rewriter) const override {
@@ -833,6 +837,9 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
833837 rewriter.replaceOp (op, results);
834838 return success ();
835839 }
840+
841+ private:
842+ vector::UnrollVectorOptions options;
836843};
837844
838845// / Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
@@ -852,7 +859,11 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
852859// / When applied exhaustively, this will produce a sequence of 1-d from_elements
853860// / ops.
854861struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
855- using OpRewritePattern::OpRewritePattern;
862+ UnrollFromElements (MLIRContext *context,
863+ const vector::UnrollVectorOptions &options,
864+ PatternBenefit benefit = 1 )
865+ : OpRewritePattern<vector::FromElementsOp>(context, benefit),
866+ options (options) {}
856867
857868 LogicalResult matchAndRewrite (vector::FromElementsOp op,
858869 PatternRewriter &rewriter) const override {
@@ -870,29 +881,32 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
870881
871882 return unrollVectorOp (op, rewriter, unrollFromElementsFn);
872883 }
884+
885+ private:
886+ vector::UnrollVectorOptions options;
873887};
874888
875889} // namespace
876890
877891void mlir::vector::populateVectorUnrollPatterns (
878892 RewritePatternSet &patterns, const UnrollVectorOptions &options,
879893 PatternBenefit benefit) {
880- patterns.add <UnrollFromElements, UnrollToElements>(patterns.getContext (),
881- benefit);
882894 patterns.add <UnrollTransferReadPattern, UnrollTransferWritePattern,
883895 UnrollContractionPattern, UnrollElementwisePattern,
884896 UnrollReductionPattern, UnrollMultiReductionPattern,
885897 UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
886- UnrollStorePattern, UnrollBroadcastPattern>(
887- patterns.getContext (), options, benefit);
898+ UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
899+ UnrollToElements>( patterns.getContext (), options, benefit);
888900}
889901
890902void mlir::vector::populateVectorToElementsUnrollPatterns (
891903 RewritePatternSet &patterns, PatternBenefit benefit) {
892- patterns.add <UnrollToElements>(patterns.getContext (), benefit);
904+ patterns.add <UnrollToElements>(patterns.getContext (), UnrollVectorOptions (),
905+ benefit);
893906}
894907
895908void mlir::vector::populateVectorFromElementsUnrollPatterns (
896909 RewritePatternSet &patterns, PatternBenefit benefit) {
897- patterns.add <UnrollFromElements>(patterns.getContext (), benefit);
910+ patterns.add <UnrollFromElements>(patterns.getContext (), UnrollVectorOptions (),
911+ benefit);
898912}
0 commit comments