@@ -63,6 +63,9 @@ class ConvertToLLVMPass
6363 : public impl::ConvertToLLVMPassBase<ConvertToLLVMPass> {
6464 std::shared_ptr<const SmallVector<ConvertToLLVMPatternInterface *>>
6565 interfaces;
66+ std::shared_ptr<const FrozenRewritePatternSet> patterns;
67+ std::shared_ptr<const ConversionTarget> target;
68+ std::shared_ptr<const LLVMTypeConverter> typeConverter;
6669
6770public:
6871 using impl::ConvertToLLVMPassBase<ConvertToLLVMPass>::ConvertToLLVMPassBase;
@@ -72,8 +75,22 @@ class ConvertToLLVMPass
7275 }
7376
7477 LogicalResult initialize (MLIRContext *context) final {
75- auto interfaces =
76- std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
78+ std::shared_ptr<SmallVector<ConvertToLLVMPatternInterface *>> interfaces;
79+ std::shared_ptr<ConversionTarget> target;
80+ std::shared_ptr<LLVMTypeConverter> typeConverter;
81+ RewritePatternSet tempPatterns (context);
82+
83+ // Only collect the interfaces if `useConversionAttrs=true` as everything
84+ // else must be initialized in `runOnOperation`.
85+ if (useConversionAttrs) {
86+ interfaces =
87+ std::make_shared<SmallVector<ConvertToLLVMPatternInterface *>>();
88+ } else {
89+ target = std::make_shared<ConversionTarget>(*context);
90+ target->addLegalDialect <LLVM::LLVMDialect>();
91+ typeConverter = std::make_shared<LLVMTypeConverter>(context);
92+ }
93+
7794 if (!filterDialects.empty ()) {
7895 // Test mode: Populate only patterns from the specified dialects. Produce
7996 // an error if the dialect is not loaded or does not implement the
@@ -88,7 +105,12 @@ class ConvertToLLVMPass
88105 return emitError (UnknownLoc::get (context))
89106 << " dialect does not implement ConvertToLLVMPatternInterface: "
90107 << dialectName << " \n " ;
91- interfaces->push_back (iface);
108+ if (useConversionAttrs) {
109+ interfaces->push_back (iface);
110+ continue ;
111+ }
112+ iface->populateConvertToLLVMConversionPatterns (*target, *typeConverter,
113+ tempPatterns);
92114 }
93115 } else {
94116 // Normal mode: Populate all patterns from all dialects that implement the
@@ -99,15 +121,34 @@ class ConvertToLLVMPass
99121 auto *iface = dyn_cast<ConvertToLLVMPatternInterface>(dialect);
100122 if (!iface)
101123 continue ;
102- interfaces->push_back (iface);
124+ if (useConversionAttrs) {
125+ interfaces->push_back (iface);
126+ continue ;
127+ }
128+ iface->populateConvertToLLVMConversionPatterns (*target, *typeConverter,
129+ tempPatterns);
103130 }
104131 }
105132
106- this ->interfaces = interfaces;
133+ if (useConversionAttrs) {
134+ this ->interfaces = interfaces;
135+ } else {
136+ this ->patterns =
137+ std::make_unique<FrozenRewritePatternSet>(std::move (tempPatterns));
138+ this ->target = target;
139+ this ->typeConverter = typeConverter;
140+ }
107141 return success ();
108142 }
109143
110144 void runOnOperation () final {
145+ // Fast path:
146+ if (!useConversionAttrs) {
147+ if (failed (applyPartialConversion (getOperation (), *target, *patterns)))
148+ signalPassFailure ();
149+ return ;
150+ }
151+ // Slow path with conversion attributes.
111152 MLIRContext *context = &getContext ();
112153 RewritePatternSet patterns (context);
113154 ConversionTarget target (*context);
0 commit comments