diff --git a/mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/LayeredUnit.h b/mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/LayeredUnit.h index 804eb9a76b..d9dc4c9397 100644 --- a/mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/LayeredUnit.h +++ b/mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/LayeredUnit.h @@ -43,18 +43,13 @@ struct Layer { }; /// @brief A LayeredUnit traverses a program layer-by-layer. -class LayeredUnit : public Unit { +class LayeredUnit : public Unit { public: - using Layers = mlir::SmallVector; - [[nodiscard]] static LayeredUnit fromEntryPointFunction(mlir::func::FuncOp func, std::size_t nqubits); - LayeredUnit(Layout layout, mlir::Region* region, bool restore = false); + LayeredUnit(Layout layout, mlir::Region* region); - [[nodiscard]] mlir::SmallVector next(); - [[nodiscard]] Layers::const_iterator begin() const { return layers_.begin(); } - [[nodiscard]] Layers::const_iterator end() const { return layers_.end(); } [[nodiscard]] const Layer& operator[](std::size_t i) const { return layers_[i]; } @@ -65,6 +60,14 @@ class LayeredUnit : public Unit { #endif private: + friend class Unit; + using Layers = mlir::SmallVector; + using const_iterator = Layers::const_iterator; + + [[nodiscard]] mlir::SmallVector nextImpl(); + [[nodiscard]] const_iterator beginImpl() const { return layers_.begin(); } + [[nodiscard]] const_iterator endImpl() const { return layers_.end(); } + Layers layers_; }; } // namespace mqt::ir::opt diff --git a/mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/SequentialUnit.h b/mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/SequentialUnit.h index 07accbc648..2b6a79c7f5 100644 --- a/mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/SequentialUnit.h +++ b/mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/SequentialUnit.h @@ -22,23 +22,24 @@ namespace mqt::ir::opt { /// @brief A SequentialUnit traverses a program sequentially. -class SequentialUnit : public Unit { +class SequentialUnit : public Unit { public: [[nodiscard]] static SequentialUnit fromEntryPointFunction(mlir::func::FuncOp func, std::size_t nqubits); SequentialUnit(Layout layout, mlir::Region* region, - mlir::Region::OpIterator start, bool restore = false); + mlir::Region::OpIterator start); - SequentialUnit(Layout layout, mlir::Region* region, bool restore = false) - : SequentialUnit(std::move(layout), region, region->op_begin(), restore) { - } - - [[nodiscard]] mlir::SmallVector next(); - [[nodiscard]] mlir::Region::OpIterator begin() const { return start_; } - [[nodiscard]] mlir::Region::OpIterator end() const { return end_; } + SequentialUnit(Layout layout, mlir::Region* region) + : SequentialUnit(std::move(layout), region, region->op_begin()) {} private: + friend class Unit; + + [[nodiscard]] mlir::SmallVector nextImpl(); + [[nodiscard]] mlir::Region::OpIterator beginImpl() const { return start_; } + [[nodiscard]] mlir::Region::OpIterator endImpl() const { return end_; } + mlir::Region::OpIterator start_; mlir::Region::OpIterator end_; }; diff --git a/mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/Unit.h b/mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/Unit.h index dbcb35ecbc..372e05e19a 100644 --- a/mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/Unit.h +++ b/mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/Unit.h @@ -17,26 +17,36 @@ namespace mqt::ir::opt { /// @brief A Unit divides a quantum-classical program into routable sections. -class Unit { +template class Unit { public: - Unit(Layout layout, mlir::Region* region, bool restore = false) - : layout_(std::move(layout)), region_(region), restore_(restore) {} + /// @brief Compute and return subsequent units. + [[nodiscard]] mlir::SmallVector next() { + return static_cast(this)->nextImpl(); + } + + /// @returns an iterator pointing at the first element of the unit. + [[nodiscard]] auto begin() const { + return static_cast(this)->beginImpl(); + } + + /// @returns an iterator pointing at the past-the-end position. + [[nodiscard]] auto end() const { + return static_cast(this)->endImpl(); + } /// @returns the managed layout. [[nodiscard]] Layout& layout() { return layout_; } - /// @returns true iff. the unit has to be restored. - [[nodiscard]] bool restore() const { return restore_; } - protected: + Unit(Layout layout, mlir::Region* region) + : layout_(std::move(layout)), region_(region) {} + /// @brief The layout this unit manages. Layout layout_; /// @brief The region this unit belongs to. mlir::Region* region_; /// @brief Pointer to the next dividing operation. mlir::Operation* divider_{}; - /// @brief Whether to uncompute the inserted SWAP sequence. - bool restore_; }; } // namespace mqt::ir::opt diff --git a/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/LayeredUnit.cpp b/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/LayeredUnit.cpp index c976fec6bb..8c0e46b503 100644 --- a/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/LayeredUnit.cpp +++ b/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/LayeredUnit.cpp @@ -141,8 +141,8 @@ LayeredUnit LayeredUnit::fromEntryPointFunction(mlir::func::FuncOp func, return {std::move(layout), &func.getBody()}; } -LayeredUnit::LayeredUnit(Layout layout, mlir::Region* region, bool restore) - : Unit(std::move(layout), region, restore) { +LayeredUnit::LayeredUnit(Layout layout, mlir::Region* region) + : Unit(std::move(layout), region) { SynchronizationMap sync; mlir::SmallVector curr; @@ -272,7 +272,7 @@ LayeredUnit::LayeredUnit(Layout layout, mlir::Region* region, bool restore) }; } -mlir::SmallVector LayeredUnit::next() { +mlir::SmallVector LayeredUnit::nextImpl() { if (divider_ == nullptr) { return {}; } @@ -283,14 +283,14 @@ mlir::SmallVector LayeredUnit::next() { Layout forLayout(layout_); // Copy layout. forLayout.remapToLoopBody(op); layout_.remapToLoopResults(op); - units.emplace_back(std::move(layout_), region_, restore_); - units.emplace_back(std::move(forLayout), &op.getRegion(), true); + units.emplace_back(std::move(layout_), region_); + units.emplace_back(std::move(forLayout), &op.getRegion()); }) .Case([&](mlir::scf::IfOp op) { - units.emplace_back(layout_, &op.getThenRegion(), true); - units.emplace_back(layout_, &op.getElseRegion(), true); + units.emplace_back(layout_, &op.getThenRegion()); + units.emplace_back(layout_, &op.getElseRegion()); layout_.remapIfResults(op); - units.emplace_back(layout_, region_, restore_); + units.emplace_back(layout_, region_); }) .Default([](auto) { llvm_unreachable("invalid 'next' operation"); }); diff --git a/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/SequentialUnit.cpp b/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/SequentialUnit.cpp index 3169cdf7e6..f2719afe04 100644 --- a/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/SequentialUnit.cpp +++ b/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/SequentialUnit.cpp @@ -40,9 +40,8 @@ SequentialUnit::fromEntryPointFunction(mlir::func::FuncOp func, } SequentialUnit::SequentialUnit(Layout layout, mlir::Region* region, - mlir::Region::OpIterator start, bool restore) - : Unit(std::move(layout), region, restore), start_(start), - end_(region->op_end()) { + mlir::Region::OpIterator start) + : Unit(std::move(layout), region), start_(start), end_(region->op_end()) { mlir::Region::OpIterator it = start_; for (; it != end_; ++it) { mlir::Operation* op = &*it; @@ -54,7 +53,7 @@ SequentialUnit::SequentialUnit(Layout layout, mlir::Region* region, end_ = it; } -mlir::SmallVector SequentialUnit::next() { +mlir::SmallVector SequentialUnit::nextImpl() { if (divider_ == nullptr) { return {}; } @@ -65,16 +64,14 @@ mlir::SmallVector SequentialUnit::next() { Layout forLayout(layout_); // Copy layout. forLayout.remapToLoopBody(op); layout_.remapToLoopResults(op); - units.emplace_back(std::move(layout_), region_, std::next(end_), - restore_); - units.emplace_back(std::move(forLayout), &op.getRegion(), true); + units.emplace_back(std::move(layout_), region_, std::next(end_)); + units.emplace_back(std::move(forLayout), &op.getRegion()); }) .Case([&](mlir::scf::IfOp op) { - units.emplace_back(layout_, &op.getThenRegion(), true); - units.emplace_back(layout_, &op.getElseRegion(), true); + units.emplace_back(layout_, &op.getThenRegion()); + units.emplace_back(layout_, &op.getElseRegion()); layout_.remapIfResults(op); - units.emplace_back(std::move(layout_), region_, std::next(end_), - restore_); + units.emplace_back(std::move(layout_), region_, std::next(end_)); }) .Default([](auto) { llvm_unreachable("invalid 'next' operation"); }); diff --git a/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/AStarRoutingPass.cpp b/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/AStarRoutingPass.cpp index 3d3770d41e..1f1b5e3a5b 100644 --- a/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/AStarRoutingPass.cpp +++ b/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/AStarRoutingPass.cpp @@ -166,11 +166,9 @@ struct AStarRoutingPassSC final .Case([&](ResetOp op) { unit.layout().remap(op); }) .Case([&](MeasureOp op) { unit.layout().remap(op); }) .Case([&](scf::YieldOp op) { - if (unit.restore()) { - rewriter.setInsertionPoint(op); - insertSWAPs(op.getLoc(), llvm::reverse(history), - unit.layout(), rewriter); - } + rewriter.setInsertionPoint(op); + insertSWAPs(op.getLoc(), llvm::reverse(history), + unit.layout(), rewriter); }) .Default([](auto) { llvm_unreachable("unhandled 'curr' operation"); diff --git a/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/NaiveRoutingPass.cpp b/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/NaiveRoutingPass.cpp index 32d4cb5d99..6f917b46fa 100644 --- a/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/NaiveRoutingPass.cpp +++ b/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/NaiveRoutingPass.cpp @@ -141,11 +141,9 @@ struct NaiveRoutingPassSC final .Case([&](ResetOp op) { unit.layout().remap(op); }) .Case([&](MeasureOp op) { unit.layout().remap(op); }) .Case([&](scf::YieldOp op) { - if (unit.restore()) { - rewriter.setInsertionPointAfter(op->getPrevNode()); - insertSWAPs(op.getLoc(), llvm::reverse(history), - unit.layout(), rewriter); - } + rewriter.setInsertionPoint(op); + insertSWAPs(op.getLoc(), llvm::reverse(history), unit.layout(), + rewriter); }); } diff --git a/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/RoutingVerificationPass.cpp b/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/RoutingVerificationPass.cpp index 5f237a9307..42c4e3b3e0 100644 --- a/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/RoutingVerificationPass.cpp +++ b/mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/sc/RoutingVerificationPass.cpp @@ -125,10 +125,6 @@ struct RoutingVerificationPassSC final return success(); }) .Case([&](scf::YieldOp op) -> LogicalResult { - if (!unit.restore()) { - return success(); - } - /// Verify that the layouts match at the end. const auto mappingBefore = unmodified.getCurrentLayout(); const auto mappingNow = unit.layout().getCurrentLayout();