Skip to content

Commit 371998e

Browse files
[mlir][Transforms] One-Shot Dialect Conversion
1 parent 160f5ca commit 371998e

30 files changed

+501
-104
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> {
5252
"Test conversion patterns of only the specified dialects">,
5353
Option<"useDynamic", "dynamic", "bool", "false",
5454
"Use op conversion attributes to configure the conversion">,
55+
Option<"allowPatternRollback", "allow-pattern-rollback", "bool", "true",
56+
"Experimental performance flag to disallow pattern rollback">
5557
];
5658
}
5759

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,16 +1231,16 @@ struct ConversionConfig {
12311231
/// 2. Pattern produces IR (in-place modification or new IR) that is illegal
12321232
/// and cannot be legalized by subsequent foldings / pattern applications.
12331233
///
1234-
/// If set to "false", the conversion driver will produce an LLVM fatal error
1235-
/// instead of rolling back IR modifications. Moreover, in case of a failed
1236-
/// conversion, the original IR is not restored. The resulting IR may be a
1237-
/// mix of original and rewritten IR. (Same as a failed greedy pattern
1238-
/// rewrite.)
1234+
/// Experimental: If set to "false", the conversion driver will produce an
1235+
/// LLVM fatal error instead of rolling back IR modifications. Moreover, in
1236+
/// case of a failed conversion, the original IR is not restored. The
1237+
/// resulting IR may be a mix of original and rewritten IR. (Same as a failed
1238+
/// greedy pattern rewrite.) Use MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
1239+
/// with ASAN to detect invalid pattern API usage.
12391240
///
1240-
/// Note: This flag was added in preparation of the One-Shot Dialect
1241-
/// Conversion refactoring, which will remove the ability to roll back IR
1242-
/// modifications from the conversion driver. Use this flag to ensure that
1243-
/// your patterns do not trigger any IR rollbacks. For details, see
1241+
/// When pattern rollback is disabled, the conversion driver has to maintain
1242+
/// less internal state. This is more efficient, but not supported by all
1243+
/// lowering patterns. For details, see
12441244
/// https://discourse.llvm.org/t/rfc-a-new-one-shot-dialect-conversion-driver/79083.
12451245
bool allowPatternRollback = true;
12461246
};

mlir/lib/Conversion/ConvertToLLVM/ConvertToLLVMPass.cpp

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@ namespace {
3131
class ConvertToLLVMPassInterface {
3232
public:
3333
ConvertToLLVMPassInterface(MLIRContext *context,
34-
ArrayRef<std::string> filterDialects);
34+
ArrayRef<std::string> filterDialects,
35+
bool allowPatternRollback = true);
3536
virtual ~ConvertToLLVMPassInterface() = default;
3637

3738
/// Get the dependent dialects used by `convert-to-llvm`.
@@ -60,6 +61,9 @@ class ConvertToLLVMPassInterface {
6061
MLIRContext *context;
6162
/// List of dialects names to use as filters.
6263
ArrayRef<std::string> filterDialects;
64+
/// An experimental flag to disallow pattern rollback. This is more efficient
65+
/// but not supported by all lowering patterns.
66+
bool allowPatternRollback;
6367
};
6468

6569
/// This DialectExtension can be attached to the context, which will invoke the
@@ -128,7 +132,9 @@ struct StaticConvertToLLVM : public ConvertToLLVMPassInterface {
128132

129133
/// Apply the conversion driver.
130134
LogicalResult transform(Operation *op, AnalysisManager manager) const final {
131-
if (failed(applyPartialConversion(op, *target, *patterns)))
135+
ConversionConfig config;
136+
config.allowPatternRollback = allowPatternRollback;
137+
if (failed(applyPartialConversion(op, *target, *patterns, config)))
132138
return failure();
133139
return success();
134140
}
@@ -179,7 +185,9 @@ struct DynamicConvertToLLVM : public ConvertToLLVMPassInterface {
179185
patterns);
180186

181187
// Apply the conversion.
182-
if (failed(applyPartialConversion(op, target, std::move(patterns))))
188+
ConversionConfig config;
189+
config.allowPatternRollback = allowPatternRollback;
190+
if (failed(applyPartialConversion(op, target, std::move(patterns), config)))
183191
return failure();
184192
return success();
185193
}
@@ -206,9 +214,11 @@ class ConvertToLLVMPass
206214
std::shared_ptr<ConvertToLLVMPassInterface> impl;
207215
// Choose the pass implementation.
208216
if (useDynamic)
209-
impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects);
217+
impl = std::make_shared<DynamicConvertToLLVM>(context, filterDialects,
218+
allowPatternRollback);
210219
else
211-
impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects);
220+
impl = std::make_shared<StaticConvertToLLVM>(context, filterDialects,
221+
allowPatternRollback);
212222
if (failed(impl->initialize()))
213223
return failure();
214224
this->impl = impl;
@@ -228,8 +238,10 @@ class ConvertToLLVMPass
228238
//===----------------------------------------------------------------------===//
229239

230240
ConvertToLLVMPassInterface::ConvertToLLVMPassInterface(
231-
MLIRContext *context, ArrayRef<std::string> filterDialects)
232-
: context(context), filterDialects(filterDialects) {}
241+
MLIRContext *context, ArrayRef<std::string> filterDialects,
242+
bool allowPatternRollback)
243+
: context(context), filterDialects(filterDialects),
244+
allowPatternRollback(allowPatternRollback) {}
233245

234246
void ConvertToLLVMPassInterface::getDependentDialects(
235247
DialectRegistry &registry) {

mlir/lib/Dialect/SparseTensor/Transforms/SparseTensorConversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -730,9 +730,9 @@ class SparseTensorCompressConverter : public OpConversionPattern<CompressOp> {
730730
{tensor, lvlCoords, values, filled, added, count},
731731
EmitCInterface::On);
732732
Operation *parent = getTop(op);
733+
rewriter.setInsertionPointAfter(parent);
733734
rewriter.replaceOp(op, adaptor.getTensor());
734735
// Deallocate the buffers on exit of the loop nest.
735-
rewriter.setInsertionPointAfter(parent);
736736
memref::DeallocOp::create(rewriter, loc, values);
737737
memref::DeallocOp::create(rewriter, loc, filled);
738738
memref::DeallocOp::create(rewriter, loc, added);

0 commit comments

Comments
 (0)