Skip to content

Commit c054f16

Browse files
committed
Adds UnrollVectorOptions to ToElements and ForElements patterns
1 parent 70a667e commit c054f16

File tree

1 file changed

+22
-8
lines changed

1 file changed

+22
-8
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -811,7 +811,11 @@ struct UnrollBroadcastPattern : public OpRewritePattern<vector::BroadcastOp> {
811811
};
812812

813813
struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
814-
using OpRewritePattern::OpRewritePattern;
814+
UnrollToElements(MLIRContext *context,
815+
const vector::UnrollVectorOptions &options,
816+
PatternBenefit benefit = 1)
817+
: OpRewritePattern<vector::ToElementsOp>(context, benefit),
818+
options(options) {}
815819

816820
LogicalResult matchAndRewrite(vector::ToElementsOp op,
817821
PatternRewriter &rewriter) const override {
@@ -833,6 +837,9 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
833837
rewriter.replaceOp(op, results);
834838
return success();
835839
}
840+
841+
private:
842+
vector::UnrollVectorOptions options;
836843
};
837844

838845
/// Unrolls 2 or more dimensional `vector.from_elements` ops by unrolling the
@@ -852,7 +859,11 @@ struct UnrollToElements final : public OpRewritePattern<vector::ToElementsOp> {
852859
/// When applied exhaustively, this will produce a sequence of 1-d from_elements
853860
/// ops.
854861
struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
855-
using OpRewritePattern::OpRewritePattern;
862+
UnrollFromElements(MLIRContext *context,
863+
const vector::UnrollVectorOptions &options,
864+
PatternBenefit benefit = 1)
865+
: OpRewritePattern<vector::FromElementsOp>(context, benefit),
866+
options(options) {}
856867

857868
LogicalResult matchAndRewrite(vector::FromElementsOp op,
858869
PatternRewriter &rewriter) const override {
@@ -870,29 +881,32 @@ struct UnrollFromElements : OpRewritePattern<vector::FromElementsOp> {
870881

871882
return unrollVectorOp(op, rewriter, unrollFromElementsFn);
872883
}
884+
885+
private:
886+
vector::UnrollVectorOptions options;
873887
};
874888

875889
} // namespace
876890

877891
void mlir::vector::populateVectorUnrollPatterns(
878892
RewritePatternSet &patterns, const UnrollVectorOptions &options,
879893
PatternBenefit benefit) {
880-
patterns.add<UnrollFromElements, UnrollToElements>(patterns.getContext(),
881-
benefit);
882894
patterns.add<UnrollTransferReadPattern, UnrollTransferWritePattern,
883895
UnrollContractionPattern, UnrollElementwisePattern,
884896
UnrollReductionPattern, UnrollMultiReductionPattern,
885897
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
886-
UnrollStorePattern, UnrollBroadcastPattern>(
887-
patterns.getContext(), options, benefit);
898+
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
899+
UnrollToElements>(patterns.getContext(), options, benefit);
888900
}
889901

890902
void mlir::vector::populateVectorToElementsUnrollPatterns(
891903
RewritePatternSet &patterns, PatternBenefit benefit) {
892-
patterns.add<UnrollToElements>(patterns.getContext(), benefit);
904+
patterns.add<UnrollToElements>(patterns.getContext(), UnrollVectorOptions(),
905+
benefit);
893906
}
894907

895908
void mlir::vector::populateVectorFromElementsUnrollPatterns(
896909
RewritePatternSet &patterns, PatternBenefit benefit) {
897-
patterns.add<UnrollFromElements>(patterns.getContext(), benefit);
910+
patterns.add<UnrollFromElements>(patterns.getContext(), UnrollVectorOptions(),
911+
benefit);
898912
}

0 commit comments

Comments
 (0)