Skip to content

Commit 11dfecd

Browse files
[draft] Dialect Conversion without Rollback
This commit adds a dialect conversion driver without rollback: `OneShotDialectConversionDriver` The new driver reuses some functionality of the greedy pattern rewrite driver. Just a proof of concept, code is not polished yet. `OneShotConversionPatternRewriter` is a rewriter that materializes all IR changes immediately.
1 parent 6f5ef38 commit 11dfecd

File tree

7 files changed

+336
-47
lines changed

7 files changed

+336
-47
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ class TypeConverter {
247247
/// Attempts a 1-1 type conversion, expecting the result type to be
248248
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
249249
/// and a null type on conversion or cast failure.
250-
template <typename TargetType> TargetType convertType(Type t) const {
250+
template <typename TargetType>
251+
TargetType convertType(Type t) const {
251252
return dyn_cast_or_null<TargetType>(convertType(t));
252253
}
253254

@@ -657,7 +658,7 @@ struct ConversionPatternRewriterImpl;
657658
/// This class implements a pattern rewriter for use with ConversionPatterns. It
658659
/// extends the base PatternRewriter and provides special conversion specific
659660
/// hooks.
660-
class ConversionPatternRewriter final : public PatternRewriter {
661+
class ConversionPatternRewriter : public PatternRewriter {
661662
public:
662663
~ConversionPatternRewriter() override;
663664

@@ -711,6 +712,17 @@ class ConversionPatternRewriter final : public PatternRewriter {
711712
LogicalResult getRemappedValues(ValueRange keys,
712713
SmallVectorImpl<Value> &results);
713714

715+
virtual void setCurrentTypeConverter(const TypeConverter *converter);
716+
717+
virtual const TypeConverter *getCurrentTypeConverter() const;
718+
719+
/// Populate the operands that are used for constructing the adapter into
720+
/// `remapped`.
721+
virtual LogicalResult getAdapterOperands(StringRef valueDiagTag,
722+
std::optional<Location> inputLoc,
723+
ValueRange values,
724+
SmallVector<Value> &remapped);
725+
714726
//===--------------------------------------------------------------------===//
715727
// PatternRewriter Hooks
716728
//===--------------------------------------------------------------------===//
@@ -755,6 +767,14 @@ class ConversionPatternRewriter final : public PatternRewriter {
755767
/// Return a reference to the internal implementation.
756768
detail::ConversionPatternRewriterImpl &getImpl();
757769

770+
protected:
771+
/// Protected constructor for `OneShotConversionPatternRewriter`. Does not
772+
/// initialize `impl`.
773+
explicit ConversionPatternRewriter(MLIRContext *ctx);
774+
775+
// Hide unsupported pattern rewriter API.
776+
using OpBuilder::setListener;
777+
758778
private:
759779
// Allow OperationConverter to construct new rewriters.
760780
friend struct OperationConverter;
@@ -765,9 +785,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
765785
explicit ConversionPatternRewriter(MLIRContext *ctx,
766786
const ConversionConfig &config);
767787

768-
// Hide unsupported pattern rewriter API.
769-
using OpBuilder::setListener;
770-
771788
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
772789
};
773790

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
namespace mlir {
2020

21+
class ConversionTarget;
22+
2123
/// This enum controls which ops are put on the worklist during a greedy
2224
/// pattern rewrite.
2325
enum class GreedyRewriteStrictness {
@@ -188,6 +190,10 @@ applyOpPatternsAndFold(ArrayRef<Operation *> ops,
188190
GreedyRewriteConfig config = GreedyRewriteConfig(),
189191
bool *changed = nullptr, bool *allErased = nullptr);
190192

193+
LogicalResult
194+
applyPartialOneShotConversion(Operation *op, const ConversionTarget &target,
195+
const FrozenRewritePatternSet &patterns);
196+
191197
} // namespace mlir
192198

193199
#endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/IR/PatternMatch.h"
1616
#include "mlir/Pass/Pass.h"
1717
#include "mlir/Transforms/DialectConversion.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1819
#include <memory>
1920
#include <type_traits>
2021

@@ -1346,8 +1347,8 @@ void ConvertComplexToStandardPass::runOnOperation() {
13461347
ConversionTarget target(getContext());
13471348
target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
13481349
target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1349-
if (failed(
1350-
applyPartialConversion(getOperation(), target, std::move(patterns))))
1350+
if (failed(applyPartialOneShotConversion(getOperation(), target,
1351+
std::move(patterns))))
13511352
signalPassFailure();
13521353
}
13531354
} // namespace

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/IR/TypeUtilities.h"
2626
#include "mlir/IR/Value.h"
2727
#include "mlir/Pass/Pass.h"
28+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2829
#include "llvm/Support/Debug.h"
2930
#include "llvm/Support/ErrorHandling.h"
3031
#include "llvm/Support/raw_ostream.h"
@@ -475,9 +476,10 @@ struct ConvertNVGPUToNVVMPass
475476
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
476477
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
477478
converter, patterns, target);
478-
if (failed(applyPartialConversion(getOperation(), target,
479-
std::move(patterns))))
479+
if (failed(applyPartialOneShotConversion(getOperation(), target,
480+
std::move(patterns))))
480481
signalPassFailure();
482+
// applyPartialConversion
481483
}
482484
};
483485

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1633,6 +1633,9 @@ ConversionPatternRewriter::ConversionPatternRewriter(
16331633
setListener(impl.get());
16341634
}
16351635

1636+
ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
1637+
: PatternRewriter(ctx), impl(nullptr) {}
1638+
16361639
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
16371640

16381641
void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
@@ -1819,6 +1822,22 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
18191822
return *impl;
18201823
}
18211824

1825+
void ConversionPatternRewriter::setCurrentTypeConverter(
1826+
const TypeConverter *converter) {
1827+
impl->currentTypeConverter = converter;
1828+
}
1829+
1830+
const TypeConverter *
1831+
ConversionPatternRewriter::getCurrentTypeConverter() const {
1832+
return impl->currentTypeConverter;
1833+
}
1834+
1835+
LogicalResult ConversionPatternRewriter::getAdapterOperands(
1836+
StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
1837+
SmallVector<Value> &remapped) {
1838+
return impl->remapValues(valueDiagTag, inputLoc, *this, values, remapped);
1839+
}
1840+
18221841
//===----------------------------------------------------------------------===//
18231842
// ConversionPattern
18241843
//===----------------------------------------------------------------------===//
@@ -1827,16 +1846,18 @@ LogicalResult
18271846
ConversionPattern::matchAndRewrite(Operation *op,
18281847
PatternRewriter &rewriter) const {
18291848
auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1830-
auto &rewriterImpl = dialectRewriter.getImpl();
18311849

18321850
// Track the current conversion pattern type converter in the rewriter.
1833-
llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1834-
getTypeConverter());
1851+
const TypeConverter *currentTypeConverter =
1852+
dialectRewriter.getCurrentTypeConverter();
1853+
auto resetTypeConverter = llvm::make_scope_exit(
1854+
[&] { dialectRewriter.setCurrentTypeConverter(currentTypeConverter); });
1855+
dialectRewriter.setCurrentTypeConverter(getTypeConverter());
18351856

18361857
// Remap the operands of the operation.
1837-
SmallVector<Value, 4> operands;
1838-
if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1839-
op->getOperands(), operands))) {
1858+
SmallVector<Value> operands;
1859+
if (failed(dialectRewriter.getAdapterOperands("operand", op->getLoc(),
1860+
op->getOperands(), operands))) {
18401861
return failure();
18411862
}
18421863
return matchAndRewrite(op, operands, dialectRewriter);

0 commit comments

Comments
 (0)