@@ -31,7 +31,8 @@ namespace {
31
31
class ConvertToLLVMPassInterface {
32
32
public:
33
33
ConvertToLLVMPassInterface (MLIRContext *context,
34
- ArrayRef<std::string> filterDialects);
34
+ ArrayRef<std::string> filterDialects,
35
+ bool allowPatternRollback = true );
35
36
virtual ~ConvertToLLVMPassInterface () = default ;
36
37
37
38
// / Get the dependent dialects used by `convert-to-llvm`.
@@ -60,6 +61,9 @@ class ConvertToLLVMPassInterface {
60
61
MLIRContext *context;
61
62
// / List of dialects names to use as filters.
62
63
ArrayRef<std::string> filterDialects;
64
+ // / An experimental flag to disallow pattern rollback. This is more efficient
65
+ // / but not supported by all lowering patterns.
66
+ bool allowPatternRollback;
63
67
};
64
68
65
69
// / This DialectExtension can be attached to the context, which will invoke the
@@ -128,7 +132,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
128
132
129
133
// / Apply the conversion driver.
130
134
LogicalResult transform (Operation *op, AnalysisManager manager) const final {
131
- if (failed (applyPartialConversion (op, *target, *patterns)))
135
+ ConversionConfig config;
136
+ config.allowPatternRollback = allowPatternRollback;
137
+ if (failed (applyPartialConversion (op, *target, *patterns, config)))
132
138
return failure ();
133
139
return success ();
134
140
}
@@ -179,7 +185,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
179
185
patterns);
180
186
181
187
// Apply the conversion.
182
- if (failed (applyPartialConversion (op, target, std::move (patterns))))
188
+ ConversionConfig config;
189
+ config.allowPatternRollback = allowPatternRollback;
190
+ if (failed (applyPartialConversion (op, target, std::move (patterns), config)))
183
191
return failure ();
184
192
return success ();
185
193
}
@@ -206,9 +214,11 @@ class ConvertToLLVMPass
206
214
std::shared_ptr<ConvertToLLVMPassInterface> impl;
207
215
// Choose the pass implementation.
208
216
if (useDynamic)
209
- impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
217
+ impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
218
+ allowPatternRollback);
210
219
else
211
- impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
220
+ impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
221
+ allowPatternRollback);
212
222
if (failed (impl->initialize ()))
213
223
return failure ();
214
224
this ->impl = impl;
@@ -228,8 +238,10 @@ class ConvertToLLVMPass
228
238
// ===----------------------------------------------------------------------===//
229
239
230
240
ConvertToLLVMPassInterface::ConvertToLLVMPassInterface (
231
- MLIRContext *context, ArrayRef<std::string> filterDialects)
232
- : context(context), filterDialects(filterDialects) {}
241
+ MLIRContext *context, ArrayRef<std::string> filterDialects,
242
+ bool allowPatternRollback)
243
+ : context(context), filterDialects(filterDialects),
244
+ allowPatternRollback(allowPatternRollback) {}
233
245
234
246
void ConvertToLLVMPassInterface::getDependentDialects (
235
247
DialectRegistry ®istry) {
0 commit comments