Skip to content

Commit 6c04a06

Browse files
[draft] Dialect Conversion without Rollback
This commit adds a dialect conversion driver without rollback: `OneShotDialectConversionDriver` The new driver reuses some functionality of the greedy pattern rewrite driver. Just a proof of concept, code is not polished yet. `OneShotConversionPatternRewriter` is a rewriter that materializes all IR changes immediately.
1 parent 6b3e000 commit 6c04a06

File tree

15 files changed

+379
-85
lines changed

15 files changed

+379
-85
lines changed

mlir/include/mlir/Transforms/DialectConversion.h

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ class TypeConverter {
247247
/// Attempts a 1-1 type conversion, expecting the result type to be
248248
/// `TargetType`. Returns the converted type cast to `TargetType` on success,
249249
/// and a null type on conversion or cast failure.
250-
template <typename TargetType> TargetType convertType(Type t) const {
250+
template <typename TargetType>
251+
TargetType convertType(Type t) const {
251252
return dyn_cast_or_null<TargetType>(convertType(t));
252253
}
253254

@@ -657,7 +658,7 @@ struct ConversionPatternRewriterImpl;
657658
/// This class implements a pattern rewriter for use with ConversionPatterns. It
658659
/// extends the base PatternRewriter and provides special conversion specific
659660
/// hooks.
660-
class ConversionPatternRewriter final : public PatternRewriter {
661+
class ConversionPatternRewriter : public PatternRewriter {
661662
public:
662663
~ConversionPatternRewriter() override;
663664

@@ -708,8 +709,18 @@ class ConversionPatternRewriter final : public PatternRewriter {
708709
/// Return the converted values that replace 'keys' with types defined by the
709710
/// type converter of the currently executing pattern. Returns failure if the
710711
/// remap failed, success otherwise.
711-
LogicalResult getRemappedValues(ValueRange keys,
712-
SmallVectorImpl<Value> &results);
712+
LogicalResult getRemappedValues(ValueRange keys, SmallVector<Value> &results);
713+
714+
virtual void setCurrentTypeConverter(const TypeConverter *converter);
715+
716+
virtual const TypeConverter *getCurrentTypeConverter() const;
717+
718+
/// Populate the operands that are used for constructing the adapter into
719+
/// `remapped`.
720+
virtual LogicalResult getAdapterOperands(StringRef valueDiagTag,
721+
std::optional<Location> inputLoc,
722+
ValueRange values,
723+
SmallVector<Value> &remapped);
713724

714725
//===--------------------------------------------------------------------===//
715726
// PatternRewriter Hooks
@@ -755,6 +766,14 @@ class ConversionPatternRewriter final : public PatternRewriter {
755766
/// Return a reference to the internal implementation.
756767
detail::ConversionPatternRewriterImpl &getImpl();
757768

769+
protected:
770+
/// Protected constructor for `OneShotConversionPatternRewriter`. Does not
771+
/// initialize `impl`.
772+
explicit ConversionPatternRewriter(MLIRContext *ctx);
773+
774+
// Hide unsupported pattern rewriter API.
775+
using OpBuilder::setListener;
776+
758777
private:
759778
// Allow OperationConverter to construct new rewriters.
760779
friend struct OperationConverter;
@@ -765,9 +784,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
765784
explicit ConversionPatternRewriter(MLIRContext *ctx,
766785
const ConversionConfig &config);
767786

768-
// Hide unsupported pattern rewriter API.
769-
using OpBuilder::setListener;
770-
771787
std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
772788
};
773789

mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818

1919
namespace mlir {
2020

21+
class ConversionTarget;
22+
2123
/// This enum controls which ops are put on the worklist during a greedy
2224
/// pattern rewrite.
2325
enum class GreedyRewriteStrictness {
@@ -78,6 +80,8 @@ class GreedyRewriteConfig {
7880
/// excluded.
7981
GreedyRewriteStrictness strictMode = GreedyRewriteStrictness::AnyOp;
8082

83+
bool enableOperationDce = true;
84+
8185
/// An optional listener that should be notified about IR modifications.
8286
RewriterBase::Listener *listener = nullptr;
8387
};
@@ -188,6 +192,10 @@ applyOpPatternsAndFold(ArrayRef<Operation *> ops,
188192
GreedyRewriteConfig config = GreedyRewriteConfig(),
189193
bool *changed = nullptr, bool *allErased = nullptr);
190194

195+
LogicalResult
196+
applyPartialOneShotConversion(Operation *op, const ConversionTarget &target,
197+
const FrozenRewritePatternSet &patterns);
198+
191199
} // namespace mlir
192200

193201
#endif // MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_

mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "mlir/IR/IntegerSet.h"
2424
#include "mlir/IR/MLIRContext.h"
2525
#include "mlir/Transforms/DialectConversion.h"
26+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2627
#include "mlir/Transforms/Passes.h"
2728

2829
namespace mlir {
@@ -563,8 +564,8 @@ class LowerAffinePass
563564
ConversionTarget target(getContext());
564565
target.addLegalDialect<arith::ArithDialect, memref::MemRefDialect,
565566
scf::SCFDialect, VectorDialect>();
566-
if (failed(applyPartialConversion(getOperation(), target,
567-
std::move(patterns))))
567+
if (failed(applyPartialOneShotConversion(getOperation(), target,
568+
std::move(patterns))))
568569
signalPassFailure();
569570
}
570571
};

mlir/lib/Conversion/ArithToLLVM/ArithToLLVM.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1818
#include "mlir/IR/TypeUtilities.h"
1919
#include "mlir/Pass/Pass.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2021
#include <type_traits>
2122

2223
namespace mlir {
@@ -479,8 +480,8 @@ struct ArithToLLVMConversionPass
479480
LLVMTypeConverter converter(&getContext(), options);
480481
mlir::arith::populateArithToLLVMConversionPatterns(converter, patterns);
481482

482-
if (failed(applyPartialConversion(getOperation(), target,
483-
std::move(patterns))))
483+
if (failed(applyPartialOneShotConversion(getOperation(), target,
484+
std::move(patterns))))
484485
signalPassFailure();
485486
}
486487
};

mlir/lib/Conversion/ComplexToStandard/ComplexToStandard.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/IR/PatternMatch.h"
1616
#include "mlir/Pass/Pass.h"
1717
#include "mlir/Transforms/DialectConversion.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1819
#include <memory>
1920
#include <type_traits>
2021

@@ -1346,8 +1347,8 @@ void ConvertComplexToStandardPass::runOnOperation() {
13461347
ConversionTarget target(getContext());
13471348
target.addLegalDialect<arith::ArithDialect, math::MathDialect>();
13481349
target.addLegalOp<complex::CreateOp, complex::ImOp, complex::ReOp>();
1349-
if (failed(
1350-
applyPartialConversion(getOperation(), target, std::move(patterns))))
1350+
if (failed(applyPartialOneShotConversion(getOperation(), target,
1351+
std::move(patterns))))
13511352
signalPassFailure();
13521353
}
13531354
} // namespace

mlir/lib/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/IR/PatternMatch.h"
2626
#include "mlir/Pass/Pass.h"
2727
#include "mlir/Transforms/DialectConversion.h"
28+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2829
#include "llvm/ADT/StringRef.h"
2930
#include <functional>
3031

@@ -240,8 +241,8 @@ struct ConvertControlFlowToLLVM
240241
LLVMTypeConverter converter(&getContext(), options);
241242
mlir::cf::populateControlFlowToLLVMConversionPatterns(converter, patterns);
242243

243-
if (failed(applyPartialConversion(getOperation(), target,
244-
std::move(patterns))))
244+
if (failed(applyPartialOneShotConversion(getOperation(), target,
245+
std::move(patterns))))
245246
signalPassFailure();
246247
}
247248
};

mlir/lib/Conversion/MathToLLVM/MathToLLVM.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "mlir/Dialect/Math/IR/Math.h"
1818
#include "mlir/IR/TypeUtilities.h"
1919
#include "mlir/Pass/Pass.h"
20+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2021

2122
namespace mlir {
2223
#define GEN_PASS_DEF_CONVERTMATHTOLLVMPASS
@@ -291,8 +292,8 @@ struct ConvertMathToLLVMPass
291292
LLVMTypeConverter converter(&getContext());
292293
populateMathToLLVMConversionPatterns(converter, patterns, approximateLog1p);
293294
LLVMConversionTarget target(getContext());
294-
if (failed(applyPartialConversion(getOperation(), target,
295-
std::move(patterns))))
295+
if (failed(applyPartialOneShotConversion(getOperation(), target,
296+
std::move(patterns))))
296297
signalPassFailure();
297298
}
298299
};

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mlir/IR/TypeUtilities.h"
2626
#include "mlir/IR/Value.h"
2727
#include "mlir/Pass/Pass.h"
28+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
2829
#include "llvm/Support/Debug.h"
2930
#include "llvm/Support/ErrorHandling.h"
3031
#include "llvm/Support/raw_ostream.h"
@@ -475,9 +476,10 @@ struct ConvertNVGPUToNVVMPass
475476
target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
476477
mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
477478
converter, patterns, target);
478-
if (failed(applyPartialConversion(getOperation(), target,
479-
std::move(patterns))))
479+
if (failed(applyPartialOneShotConversion(getOperation(), target,
480+
std::move(patterns))))
480481
signalPassFailure();
482+
// applyPartialConversion
481483
}
482484
};
483485

mlir/lib/Transforms/Utils/DialectConversion.cpp

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1633,6 +1633,9 @@ ConversionPatternRewriter::ConversionPatternRewriter(
16331633
setListener(impl.get());
16341634
}
16351635

1636+
ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
1637+
: PatternRewriter(ctx), impl(nullptr) {}
1638+
16361639
ConversionPatternRewriter::~ConversionPatternRewriter() = default;
16371640

16381641
void ConversionPatternRewriter::replaceOp(Operation *op, Operation *newOp) {
@@ -1717,19 +1720,17 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
17171720

17181721
Value ConversionPatternRewriter::getRemappedValue(Value key) {
17191722
SmallVector<Value> remappedValues;
1720-
if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
1721-
remappedValues)))
1723+
if (failed(getRemappedValues(key, remappedValues)))
17221724
return nullptr;
17231725
return remappedValues.front();
17241726
}
17251727

17261728
LogicalResult
17271729
ConversionPatternRewriter::getRemappedValues(ValueRange keys,
1728-
SmallVectorImpl<Value> &results) {
1730+
SmallVector<Value> &results) {
17291731
if (keys.empty())
17301732
return success();
1731-
return impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
1732-
results);
1733+
return getAdapterOperands("value", /*inputLoc=*/std::nullopt, keys, results);
17331734
}
17341735

17351736
void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
@@ -1819,6 +1820,22 @@ detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
18191820
return *impl;
18201821
}
18211822

1823+
void ConversionPatternRewriter::setCurrentTypeConverter(
1824+
const TypeConverter *converter) {
1825+
impl->currentTypeConverter = converter;
1826+
}
1827+
1828+
const TypeConverter *
1829+
ConversionPatternRewriter::getCurrentTypeConverter() const {
1830+
return impl->currentTypeConverter;
1831+
}
1832+
1833+
LogicalResult ConversionPatternRewriter::getAdapterOperands(
1834+
StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
1835+
SmallVector<Value> &remapped) {
1836+
return impl->remapValues(valueDiagTag, inputLoc, *this, values, remapped);
1837+
}
1838+
18221839
//===----------------------------------------------------------------------===//
18231840
// ConversionPattern
18241841
//===----------------------------------------------------------------------===//
@@ -1827,16 +1844,18 @@ LogicalResult
18271844
ConversionPattern::matchAndRewrite(Operation *op,
18281845
PatternRewriter &rewriter) const {
18291846
auto &dialectRewriter = static_cast<ConversionPatternRewriter &>(rewriter);
1830-
auto &rewriterImpl = dialectRewriter.getImpl();
18311847

18321848
// Track the current conversion pattern type converter in the rewriter.
1833-
llvm::SaveAndRestore currentConverterGuard(rewriterImpl.currentTypeConverter,
1834-
getTypeConverter());
1849+
const TypeConverter *currentTypeConverter =
1850+
dialectRewriter.getCurrentTypeConverter();
1851+
auto resetTypeConverter = llvm::make_scope_exit(
1852+
[&] { dialectRewriter.setCurrentTypeConverter(currentTypeConverter); });
1853+
dialectRewriter.setCurrentTypeConverter(getTypeConverter());
18351854

18361855
// Remap the operands of the operation.
1837-
SmallVector<Value, 4> operands;
1838-
if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
1839-
op->getOperands(), operands))) {
1856+
SmallVector<Value> operands;
1857+
if (failed(dialectRewriter.getAdapterOperands("operand", op->getLoc(),
1858+
op->getOperands(), operands))) {
18401859
return failure();
18411860
}
18421861
return matchAndRewrite(op, operands, dialectRewriter);

0 commit comments

Comments
 (0)