Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,13 @@ struct Layer {
};

/// @brief A LayeredUnit traverses a program layer-by-layer.
class LayeredUnit : public Unit {
class LayeredUnit : public Unit<LayeredUnit> {
public:
using Layers = mlir::SmallVector<Layer, 0>;

[[nodiscard]] static LayeredUnit
fromEntryPointFunction(mlir::func::FuncOp func, std::size_t nqubits);

LayeredUnit(Layout layout, mlir::Region* region, bool restore = false);
LayeredUnit(Layout layout, mlir::Region* region);

[[nodiscard]] mlir::SmallVector<LayeredUnit, 3> next();
[[nodiscard]] Layers::const_iterator begin() const { return layers_.begin(); }
[[nodiscard]] Layers::const_iterator end() const { return layers_.end(); }
[[nodiscard]] const Layer& operator[](std::size_t i) const {
return layers_[i];
}
Expand All @@ -65,6 +60,14 @@ class LayeredUnit : public Unit {
#endif

private:
friend class Unit<LayeredUnit>;
using Layers = mlir::SmallVector<Layer, 0>;
using const_iterator = Layers::const_iterator;

[[nodiscard]] mlir::SmallVector<LayeredUnit, 3> nextImpl();
[[nodiscard]] const_iterator beginImpl() const { return layers_.begin(); }
[[nodiscard]] const_iterator endImpl() const { return layers_.end(); }

Layers layers_;
};
} // namespace mqt::ir::opt
Original file line number Diff line number Diff line change
Expand Up @@ -22,23 +22,24 @@
namespace mqt::ir::opt {

/// @brief A SequentialUnit traverses a program sequentially.
class SequentialUnit : public Unit {
class SequentialUnit : public Unit<SequentialUnit> {
public:
[[nodiscard]] static SequentialUnit
fromEntryPointFunction(mlir::func::FuncOp func, std::size_t nqubits);

SequentialUnit(Layout layout, mlir::Region* region,
mlir::Region::OpIterator start, bool restore = false);
mlir::Region::OpIterator start);

SequentialUnit(Layout layout, mlir::Region* region, bool restore = false)
: SequentialUnit(std::move(layout), region, region->op_begin(), restore) {
}

[[nodiscard]] mlir::SmallVector<SequentialUnit, 3> next();
[[nodiscard]] mlir::Region::OpIterator begin() const { return start_; }
[[nodiscard]] mlir::Region::OpIterator end() const { return end_; }
SequentialUnit(Layout layout, mlir::Region* region)
: SequentialUnit(std::move(layout), region, region->op_begin()) {}

private:
friend class Unit<SequentialUnit>;

[[nodiscard]] mlir::SmallVector<SequentialUnit, 3> nextImpl();
[[nodiscard]] mlir::Region::OpIterator beginImpl() const { return start_; }
[[nodiscard]] mlir::Region::OpIterator endImpl() const { return end_; }

mlir::Region::OpIterator start_;
mlir::Region::OpIterator end_;
};
Expand Down
26 changes: 18 additions & 8 deletions mlir/include/mlir/Dialect/MQTOpt/Transforms/Transpilation/Unit.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,26 +17,36 @@
namespace mqt::ir::opt {

/// @brief A Unit divides a quantum-classical program into routable sections.
class Unit {
template <class Derived> class Unit {
public:
Unit(Layout layout, mlir::Region* region, bool restore = false)
: layout_(std::move(layout)), region_(region), restore_(restore) {}
/// @brief Compute and return subsequent units.
[[nodiscard]] mlir::SmallVector<Derived, 3> next() {
return static_cast<Derived*>(this)->nextImpl();
}

/// @returns an iterator pointing at the first element of the unit.
[[nodiscard]] auto begin() const {
return static_cast<const Derived*>(this)->beginImpl();
}

/// @returns an iterator pointing at the past-the-end position.
[[nodiscard]] auto end() const {
return static_cast<const Derived*>(this)->endImpl();
}

/// @returns the managed layout.
[[nodiscard]] Layout& layout() { return layout_; }

/// @returns true iff. the unit has to be restored.
[[nodiscard]] bool restore() const { return restore_; }

protected:
Unit(Layout layout, mlir::Region* region)
: layout_(std::move(layout)), region_(region) {}

/// @brief The layout this unit manages.
Layout layout_;
/// @brief The region this unit belongs to.
mlir::Region* region_;
/// @brief Pointer to the next dividing operation.
mlir::Operation* divider_{};
/// @brief Whether to uncompute the inserted SWAP sequence.
bool restore_;
};

} // namespace mqt::ir::opt
16 changes: 8 additions & 8 deletions mlir/lib/Dialect/MQTOpt/Transforms/Transpilation/LayeredUnit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ LayeredUnit LayeredUnit::fromEntryPointFunction(mlir::func::FuncOp func,
return {std::move(layout), &func.getBody()};
}

LayeredUnit::LayeredUnit(Layout layout, mlir::Region* region, bool restore)
: Unit(std::move(layout), region, restore) {
LayeredUnit::LayeredUnit(Layout layout, mlir::Region* region)
: Unit(std::move(layout), region) {
SynchronizationMap sync;

mlir::SmallVector<Wire, 0> curr;
Expand Down Expand Up @@ -272,7 +272,7 @@ LayeredUnit::LayeredUnit(Layout layout, mlir::Region* region, bool restore)
};
}

mlir::SmallVector<LayeredUnit, 3> LayeredUnit::next() {
mlir::SmallVector<LayeredUnit, 3> LayeredUnit::nextImpl() {
if (divider_ == nullptr) {
return {};
}
Expand All @@ -283,14 +283,14 @@ mlir::SmallVector<LayeredUnit, 3> LayeredUnit::next() {
Layout forLayout(layout_); // Copy layout.
forLayout.remapToLoopBody(op);
layout_.remapToLoopResults(op);
units.emplace_back(std::move(layout_), region_, restore_);
units.emplace_back(std::move(forLayout), &op.getRegion(), true);
units.emplace_back(std::move(layout_), region_);
units.emplace_back(std::move(forLayout), &op.getRegion());
})
.Case<mlir::scf::IfOp>([&](mlir::scf::IfOp op) {
units.emplace_back(layout_, &op.getThenRegion(), true);
units.emplace_back(layout_, &op.getElseRegion(), true);
units.emplace_back(layout_, &op.getThenRegion());
units.emplace_back(layout_, &op.getElseRegion());
layout_.remapIfResults(op);
units.emplace_back(layout_, region_, restore_);
units.emplace_back(layout_, region_);
})
.Default([](auto) { llvm_unreachable("invalid 'next' operation"); });

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ SequentialUnit::fromEntryPointFunction(mlir::func::FuncOp func,
}

SequentialUnit::SequentialUnit(Layout layout, mlir::Region* region,
mlir::Region::OpIterator start, bool restore)
: Unit(std::move(layout), region, restore), start_(start),
end_(region->op_end()) {
mlir::Region::OpIterator start)
: Unit(std::move(layout), region), start_(start), end_(region->op_end()) {
mlir::Region::OpIterator it = start_;
for (; it != end_; ++it) {
mlir::Operation* op = &*it;
Expand All @@ -54,7 +53,7 @@ SequentialUnit::SequentialUnit(Layout layout, mlir::Region* region,
end_ = it;
}

mlir::SmallVector<SequentialUnit, 3> SequentialUnit::next() {
mlir::SmallVector<SequentialUnit, 3> SequentialUnit::nextImpl() {
if (divider_ == nullptr) {
return {};
}
Expand All @@ -65,16 +64,14 @@ mlir::SmallVector<SequentialUnit, 3> SequentialUnit::next() {
Layout forLayout(layout_); // Copy layout.
forLayout.remapToLoopBody(op);
layout_.remapToLoopResults(op);
units.emplace_back(std::move(layout_), region_, std::next(end_),
restore_);
units.emplace_back(std::move(forLayout), &op.getRegion(), true);
units.emplace_back(std::move(layout_), region_, std::next(end_));
units.emplace_back(std::move(forLayout), &op.getRegion());
})
.Case<mlir::scf::IfOp>([&](mlir::scf::IfOp op) {
units.emplace_back(layout_, &op.getThenRegion(), true);
units.emplace_back(layout_, &op.getElseRegion(), true);
units.emplace_back(layout_, &op.getThenRegion());
units.emplace_back(layout_, &op.getElseRegion());
layout_.remapIfResults(op);
units.emplace_back(std::move(layout_), region_, std::next(end_),
restore_);
units.emplace_back(std::move(layout_), region_, std::next(end_));
})
.Default([](auto) { llvm_unreachable("invalid 'next' operation"); });

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,9 @@ struct AStarRoutingPassSC final
.Case<ResetOp>([&](ResetOp op) { unit.layout().remap(op); })
.Case<MeasureOp>([&](MeasureOp op) { unit.layout().remap(op); })
.Case<scf::YieldOp>([&](scf::YieldOp op) {
if (unit.restore()) {
rewriter.setInsertionPoint(op);
insertSWAPs(op.getLoc(), llvm::reverse(history),
unit.layout(), rewriter);
}
rewriter.setInsertionPoint(op);
insertSWAPs(op.getLoc(), llvm::reverse(history),
unit.layout(), rewriter);
})
.Default([](auto) {
llvm_unreachable("unhandled 'curr' operation");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,9 @@ struct NaiveRoutingPassSC final
.Case<ResetOp>([&](ResetOp op) { unit.layout().remap(op); })
.Case<MeasureOp>([&](MeasureOp op) { unit.layout().remap(op); })
.Case<scf::YieldOp>([&](scf::YieldOp op) {
if (unit.restore()) {
rewriter.setInsertionPointAfter(op->getPrevNode());
insertSWAPs(op.getLoc(), llvm::reverse(history),
unit.layout(), rewriter);
}
rewriter.setInsertionPoint(op);
insertSWAPs(op.getLoc(), llvm::reverse(history), unit.layout(),
rewriter);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,6 @@ struct RoutingVerificationPassSC final
return success();
})
.Case<scf::YieldOp>([&](scf::YieldOp op) -> LogicalResult {
if (!unit.restore()) {
return success();
}

/// Verify that the layouts match at the end.
const auto mappingBefore = unmodified.getCurrentLayout();
const auto mappingNow = unit.layout().getCurrentLayout();
Expand Down
Loading