@@ -65,6 +65,9 @@ class ConvertToLLVMPass
6565 : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
6666 std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
6767 interfaces;
68+ std::shared_ptr<const FrozenRewritePatternSet> patterns;
69+ std::shared_ptr<const ConversionTarget> target;
70+ std::shared_ptr<const LLVMTypeConverter> typeConverter;
6871
6972public:
7073 using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -74,8 +77,22 @@ class ConvertToLLVMPass
7477 }
7578
7679 LogicalResult initialize (MLIRContext *context) final {
77- auto interfaces =
78- std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
80+ std::shared_ptr<SmallVector<ConvertToLLVMPatternInterface *>> interfaces;
81+ std::shared_ptr<ConversionTarget> target;
82+ std::shared_ptr<LLVMTypeConverter> typeConverter;
83+ RewritePatternSet tempPatterns (context);
84+
85+ // Only collect the interfaces if `useConversionAttrs=true` as everything
86+ // else must be initialized in `runOnOperation`.
87+ if (useConversionAttrs) {
88+ interfaces =
89+ std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
90+ } else {
91+ target = std::make_shared<ConversionTarget>(*context);
92+ target->addLegalDialect <LLVM::LLVMDialect>();
93+ typeConverter = std::make_shared<LLVMTypeConverter>(context);
94+ }
95+
7996 if (!filterDialects.empty ()) {
8097 // Test mode: Populate only patterns from the specified dialects. Produce
8198 // an error if the dialect is not loaded or does not implement the
@@ -90,7 +107,12 @@ class ConvertToLLVMPass
90107 return emitError (UnknownLoc::get (context))
91108 << " dialect does not implement ConvertToLLVMPatternInterface: "
92109 << dialectName << " \n " ;
93- interfaces->push_back (iface);
110+ if (useConversionAttrs) {
111+ interfaces->push_back (iface);
112+ continue ;
113+ }
114+ iface->populateConvertToLLVMConversionPatterns (*target, *typeConverter,
115+ tempPatterns);
94116 }
95117 } else {
96118 // Normal mode: Populate all patterns from all dialects that implement the
@@ -101,15 +123,34 @@ class ConvertToLLVMPass
101123 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
102124 if (!iface)
103125 continue ;
104- interfaces->push_back (iface);
126+ if (useConversionAttrs) {
127+ interfaces->push_back (iface);
128+ continue ;
129+ }
130+ iface->populateConvertToLLVMConversionPatterns (*target, *typeConverter,
131+ tempPatterns);
105132 }
106133 }
107134
108- this ->interfaces = interfaces;
135+ if (useConversionAttrs) {
136+ this ->interfaces = interfaces;
137+ } else {
138+ this ->patterns =
139+ std::make_unique<FrozenRewritePatternSet>(std::move (tempPatterns));
140+ this ->target = target;
141+ this ->typeConverter = typeConverter;
142+ }
109143 return success ();
110144 }
111145
112146 void runOnOperation () final {
147+ // Fast path:
148+ if (!useConversionAttrs) {
149+ if (failed (applyPartialConversion (getOperation (), *target, *patterns)))
150+ signalPassFailure ();
151+ return ;
152+ }
153+ // Slow path with conversion attributes.
113154 MLIRContext *context = &getContext ();
114155 RewritePatternSet patterns (context);
115156 ConversionTarget target (*context);
0 commit comments