diff --git a/.github/workflows/reusable-mlir-tests.yml b/.github/workflows/reusable-mlir-tests.yml index 5875780a00..02327778a7 100644 --- a/.github/workflows/reusable-mlir-tests.yml +++ b/.github/workflows/reusable-mlir-tests.yml @@ -83,8 +83,8 @@ jobs: - name: Build MLIR lit target run: cmake --build build --config ${{ matrix.coverage && 'Debug' || 'Release' }} --target mqt-core-mlir-lit-test-build-only - - name: Build MLIR unittests - run: cmake --build build --config ${{ matrix.coverage && 'Debug' || 'Release' }} --target mqt-core-mlir-translation-test + - name: Build MLIR translation unittests + run: cmake --build build --config ${{ matrix.coverage && 'Debug' || 'Release' }} --target mqt-core-mlir-translation-test --target mqt-core-mlir-wireiterator-test # Test - name: Run lit tests diff --git a/CHANGELOG.md b/CHANGELOG.md index dcf32d9932..6b478dc03c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ This project adheres to [Semantic Versioning], with the exception that minor rel ## [Unreleased] +### Added + +- ✨ Add bi-directional iterator that traverses the def-use chain of a qubit value ([#1310]) ([**@MatthiasReumann**]) + ### Changed - 👷 Use `munich-quantum-software/setup-mlir` to set up MLIR ([#1294]) ([**@denialhaag**]) @@ -250,6 +254,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool [#1327]: https://github.com/munich-quantum-toolkit/core/pull/1327 +[#1310]: https://github.com/munich-quantum-toolkit/core/pull/1310 [#1300]: https://github.com/munich-quantum-toolkit/core/pull/1300 [#1299]: https://github.com/munich-quantum-toolkit/core/pull/1299 [#1294]: https://github.com/munich-quantum-toolkit/core/pull/1294 diff --git a/mlir/include/mlir/Dialect/MQTOpt/IR/WireIterator.h b/mlir/include/mlir/Dialect/MQTOpt/IR/WireIterator.h new file mode 100644 index 0000000000..79b664a31d --- /dev/null +++ b/mlir/include/mlir/Dialect/MQTOpt/IR/WireIterator.h @@ -0,0 +1,340 @@ +/* + * Copyright (c) 2023 - 2025 Chair for Design Automation, TUM + * Copyright (c) 2025 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include "mlir/Dialect/MQTOpt/IR/MQTOptDialect.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mqt::ir::opt { + +/** + * @brief A bidirectional_iterator traversing the def-use chain of a qubit wire. + * + * The iterator follows the flow of a qubit through a sequence of quantum + * operations in a given region. It respects the semantics of the respective + * quantum operation including control flow constructs (scf::ForOp and + * scf::IfOp). + * + * It treats control flow constructs as a single operation that consumes and + * yields a corresponding number of qubits, without descending into their nested + * regions. + */ +class WireIterator { + /// @returns a view of all input qubits. + [[nodiscard]] static auto getAllInQubits(UnitaryInterface op) { + return llvm::concat(op.getInQubits(), op.getPosCtrlInQubits(), + op.getNegCtrlInQubits()); + } + + /// @returns a view of all output qubits. + [[nodiscard]] static auto getAllOutQubits(UnitaryInterface op) { + return llvm::concat( + op.getOutQubits(), op.getPosCtrlOutQubits(), op.getNegCtrlOutQubits()); + } + + /** + * @brief Find corresponding output from input value for a unitary (Forward). + * + * @note That we don't use the interface method here because + * it creates temporary std::vectors instead of using views. + */ + [[nodiscard]] static mlir::Value findOutput(UnitaryInterface op, + mlir::Value in) { + const auto ins = getAllInQubits(op); + const auto outs = getAllOutQubits(op); + const auto it = llvm::find(ins, in); + assert(it != ins.end() && "input qubit not found in operation"); + const auto index = std::distance(ins.begin(), it); + return *(std::next(outs.begin(), index)); + } + + /** + * @brief Find corresponding input from output value for a unitary (Backward). + * + * @note That we don't use the interface method here because + * it creates temporary std::vectors instead of using views. + */ + [[nodiscard]] static mlir::Value findInput(UnitaryInterface op, + mlir::Value out) { + const auto ins = getAllInQubits(op); + const auto outs = getAllOutQubits(op); + const auto it = llvm::find(outs, out); + assert(it != outs.end() && "output qubit not found in operation"); + const auto index = std::distance(outs.begin(), it); + return *(std::next(ins.begin(), index)); + } + + /** + * @brief Find corresponding result from init argument value (Forward). + */ + [[nodiscard]] static mlir::Value findResult(mlir::scf::ForOp op, + mlir::Value initArg) { + const auto initArgs = op.getInitArgs(); + const auto it = llvm::find(initArgs, initArg); + assert(it != initArgs.end() && "init arg qubit not found in operation"); + const auto index = std::distance(initArgs.begin(), it); + return op->getResult(index); + } + + /** + * @brief Find corresponding init argument from result value (Backward). + */ + [[nodiscard]] static mlir::Value findInitArg(mlir::scf::ForOp op, + mlir::Value res) { + return op.getInitArgs()[cast(res).getResultNumber()]; + } + + /** + * @brief Find corresponding result value from input qubit value (Forward). + * + * @details Recursively traverses the IR "downwards" until the respective + * yield is found. Requires that each branch takes and returns the same + * (possibly modified) qubits. Hence, we can just traverse the then-branch. + */ + [[nodiscard]] static mlir::Value findResult(mlir::scf::IfOp op, + mlir::Value q) { + /// Use the branch with fewer ops. + /// Note: LLVM doesn't guarantee that range_size is in O(1). + /// Might effect performance. + const auto szThen = llvm::range_size(op.getThenRegion().getOps()); + const auto szElse = llvm::range_size(op.getElseRegion().getOps()); + mlir::Region& region = + szElse >= szThen ? op.getThenRegion() : op.getElseRegion(); + + WireIterator it(q, ®ion); + + /// Assumptions: + /// First, there must be a yield. + /// Second, yield is a sentinel. + /// Then: Advance until the yield before the sentinel. + + it = std::prev(std::ranges::next(it, std::default_sentinel)); + assert(isa(*it) && "expected yield op"); + auto yield = cast(*it); + + /// Get the corresponding result. + + const auto results = yield.getResults(); + const auto yieldIt = llvm::find(results, it.q); + assert(yieldIt != results.end() && "yielded qubit not found in operation"); + const auto index = std::distance(results.begin(), yieldIt); + return op->getResult(index); + } + + /** + * @brief Find the first value outside the branch region for a given result + * value (Backward). + * + * @details Recursively traverses the IR "upwards" until a value outside the + * branch region is found. If the iterator's operation does not change during + * backward traversal, it indicates that the def-use chain starts within the + * branch region and does not extend into the parent region. + */ + [[nodiscard]] static mlir::Value findValue(mlir::scf::IfOp op, + mlir::Value q) { + const auto num = cast(q).getResultNumber(); + mlir::Operation* term = op.thenBlock()->getTerminator(); + mlir::scf::YieldOp yield = llvm::cast(term); + mlir::Value v = yield.getResults()[num]; + assert(v != nullptr && "expected yielded value"); + + mlir::Operation* prev{}; + WireIterator it(v, &op.getThenRegion()); + while (it.qubit().getParentRegion() != op->getParentRegion()) { + /// Since the definingOp of q might be a nullptr (BlockArgument), don't + /// immediately dereference the iterator here. + mlir::Operation* curr = it.qubit().getDefiningOp(); + if (curr == prev || curr == nullptr) { + break; + } + prev = *it; + --it; + } + + return it.qubit(); + } + + /** + * @brief Return the first user of a value in a given region. + * @param v The value. + * @param region The targeted region. + * @return A pointer to the user, or nullptr if none exists. + */ + [[nodiscard]] static mlir::Operation* getUserInRegion(mlir::Value v, + mlir::Region* region) { + for (mlir::Operation* user : v.getUsers()) { + if (user->getParentRegion() == region) { + return user; + } + } + return nullptr; + } + +public: + using iterator_category = std::bidirectional_iterator_tag; + using difference_type = std::ptrdiff_t; + using value_type = mlir::Operation*; + + explicit WireIterator() = default; + explicit WireIterator(mlir::Value q, mlir::Region* region) + : currOp(q.getDefiningOp()), q(q), region(region) {} + + [[nodiscard]] mlir::Operation* operator*() const { + assert(!sentinel && "Dereferencing sentinel iterator"); + assert(currOp && "Dereferencing null operation"); + return currOp; + } + + [[nodiscard]] mlir::Value qubit() const { return q; } + + WireIterator& operator++() { + advanceForward(); + return *this; + } + + WireIterator operator++(int) { + auto tmp = *this; + ++*this; + return tmp; + } + + WireIterator& operator--() { + advanceBackward(); + return *this; + } + + WireIterator operator--(int) { + auto tmp = *this; + --*this; + return tmp; + } + + bool operator==(const WireIterator& other) const { + return other.q == q && other.currOp == currOp && other.sentinel == sentinel; + } + + bool operator==([[maybe_unused]] std::default_sentinel_t s) const { + return sentinel; + } + +private: + void advanceForward() { + /// If we are already at the sentinel, there is nothing to do. + if (sentinel) { + return; + } + + /// Find output from input qubit. + /// If there is no output qubit, set `sentinel` to true. + if (q.getDefiningOp() != currOp) { + mlir::TypeSwitch(currOp) + .Case( + [&](UnitaryInterface op) { q = findOutput(op, q); }) + .Case([&](ResetOp op) { q = op.getOutQubit(); }) + .Case([&](MeasureOp op) { q = op.getOutQubit(); }) + .Case( + [&](mlir::scf::ForOp op) { q = findResult(op, q); }) + .Case( + [&](mlir::scf::IfOp op) { q = findResult(op, q); }) + .Case( + [&](auto) { sentinel = true; }) + .Default([&](mlir::Operation* op) { + report_fatal_error("unknown op in def-use chain: " + + op->getName().getStringRef()); + }); + } + + /// Find the next operation. + /// If it is a sentinel there are no more ops. + if (sentinel) { + return; + } + + /// If there are no more uses, set `sentinel` to true. + if (q.use_empty()) { + sentinel = true; + return; + } + + /// Otherwise, search the user in the targeted region. + currOp = getUserInRegion(q, getRegion()); + if (currOp == nullptr) { + /// Since !q.use_empty: must be a branching op. + currOp = q.getUsers().begin()->getParentOp(); + /// For now, just check if it's a scf::IfOp. + /// Theoretically this could also be an scf::IndexSwitch, etc. + assert(isa(currOp)); + } + } + + void advanceBackward() { + /// If we are at the sentinel and move backwards, "revive" the + /// qubit value and operation. + if (sentinel) { + sentinel = false; + return; + } + + /// Get the operation that produces the qubit value. + currOp = q.getDefiningOp(); + + /// If q is a BlockArgument (no defining op), hold. + if (currOp == nullptr) { + return; + } + + /// Find input from output qubit. + /// If there is no input qubit, hold. + mlir::TypeSwitch(currOp) + .Case( + [&](UnitaryInterface op) { q = findInput(op, q); }) + .Case([&](auto op) { q = op.getInQubit(); }) + .Case([&](DeallocQubitOp op) { q = op.getQubit(); }) + .Case( + [&](mlir::scf::ForOp op) { q = findInitArg(op, q); }) + .Case( + [&](mlir::scf::IfOp op) { q = findValue(op, q); }) + .Case([&](auto) { /* hold (no-op) */ }) + .Default([&](mlir::Operation* op) { + report_fatal_error("unknown op in def-use chain: " + + op->getName().getStringRef()); + }); + } + + /** + * @brief Return the active region this iterator uses. + * @return A pointer to the region. + */ + [[nodiscard]] mlir::Region* getRegion() { + return region != nullptr ? region : q.getParentRegion(); + } + + mlir::Operation* currOp{}; + mlir::Value q; + mlir::Region* region{}; + bool sentinel{false}; +}; + +static_assert(std::bidirectional_iterator); +static_assert(std::sentinel_for, + "std::default_sentinel_t must be a sentinel for WireIterator."); +} // namespace mqt::ir::opt diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index fb4775eaa4..46421d9bc1 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -7,3 +7,4 @@ # Licensed under the MIT License add_subdirectory(translation) +add_subdirectory(dialect) diff --git a/mlir/unittests/dialect/CMakeLists.txt b/mlir/unittests/dialect/CMakeLists.txt new file mode 100644 index 0000000000..db26229f6f --- /dev/null +++ b/mlir/unittests/dialect/CMakeLists.txt @@ -0,0 +1,21 @@ +# Copyright (c) 2023 - 2025 Chair for Design Automation, TUM +# Copyright (c) 2025 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +set(testname "mqt-core-mlir-wireiterator-test") +file(GLOB_RECURSE WIREITERATOR_TEST_SOURCES *.cpp) + +if(NOT TARGET ${testname}) + # create an executable in which the tests will be stored + add_executable(${testname} ${WIREITERATOR_TEST_SOURCES}) + # link the Google test infrastructure and a default main function to the test executable. + target_link_libraries(${testname} PRIVATE GTest::gtest_main MLIRParser MLIRMQTOpt MLIRSCFDialect + MLIRArithDialect MLIRIndexDialect) + # discover tests + gtest_discover_tests(${testname} DISCOVERY_TIMEOUT 60) + set_target_properties(${testname} PROPERTIES FOLDER unittests) +endif() diff --git a/mlir/unittests/dialect/test_wireiterator.cpp b/mlir/unittests/dialect/test_wireiterator.cpp new file mode 100644 index 0000000000..562dd4a535 --- /dev/null +++ b/mlir/unittests/dialect/test_wireiterator.cpp @@ -0,0 +1,354 @@ +/* + * Copyright (c) 2023 - 2025 Chair for Design Automation, TUM + * Copyright (c) 2025 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/MQTOpt/IR/MQTOptDialect.h" +#include "mlir/Dialect/MQTOpt/IR/WireIterator.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; +using namespace mqt::ir::opt; + +namespace { +/** @returns a module containing the circuit from the "Tackling the Qubit + * Mapping Problem for NISQ-Era Quantum Devices" paper by Li et al. + */ +OwningOpRef getModule(MLIRContext& ctx) { + const char* ir = R"mlir( +module { + %0 = mqtopt.allocQubit + %1 = mqtopt.allocQubit + %out_qubits = mqtopt.h() %0 : !mqtopt.Qubit + %out_qubits_0 = mqtopt.h() %1 : !mqtopt.Qubit + %out_qubits_1 = mqtopt.z() %out_qubits : !mqtopt.Qubit + %out_qubits_2, %pos_ctrl_out_qubits = mqtopt.x() %out_qubits_0 ctrl %out_qubits_1 : !mqtopt.Qubit ctrl !mqtopt.Qubit + %out_qubits_3 = mqtopt.h() %out_qubits_2 : !mqtopt.Qubit + %out_qubits_4, %pos_ctrl_out_qubits_5 = mqtopt.x() %pos_ctrl_out_qubits ctrl %out_qubits_3 : !mqtopt.Qubit ctrl !mqtopt.Qubit + %false = arith.constant false + %2:2 = scf.if %false -> (!mqtopt.Qubit, !mqtopt.Qubit) { + %out_qubits_7 = mqtopt.y() %out_qubits_4 : !mqtopt.Qubit + scf.yield %out_qubits_7, %pos_ctrl_out_qubits_5 : !mqtopt.Qubit, !mqtopt.Qubit + } else { + scf.yield %out_qubits_4, %pos_ctrl_out_qubits_5 : !mqtopt.Qubit, !mqtopt.Qubit + } + %idx0 = index.constant 0 + %idx8 = index.constant 8 + %idx1 = index.constant 1 + %3:2 = scf.for %arg0 = %idx0 to %idx8 step %idx1 iter_args(%arg1 = %2#0, %arg2 = %2#1) -> (!mqtopt.Qubit, !mqtopt.Qubit) { + %out_qubits_7 = mqtopt.h() %arg1 : !mqtopt.Qubit + %out_qubits_8 = mqtopt.h() %arg2 : !mqtopt.Qubit + scf.yield %out_qubits_7, %out_qubits_8 : !mqtopt.Qubit, !mqtopt.Qubit + } + mqtopt.deallocQubit %3#0 + mqtopt.deallocQubit %3#1 + + %4 = mqtopt.qubit 42 + %5 = mqtopt.reset %4 + %out_qubits_6 = mqtopt.h() %5 : !mqtopt.Qubit +} +)mlir"; + return parseSourceString(ir, &ctx); +} + +std::string toString(Operation* op) { + std::string opStr; + llvm::raw_string_ostream os(opStr); + os << *op; + os.flush(); + return opStr; +} + +void checkOperationEqual(Operation* op, const std::string& expected) { + ASSERT_EQ(expected, toString(op)); +} + +void checkOperationStartsWith(Operation* op, const std::string& prefix) { + ASSERT_TRUE(toString(op).starts_with(prefix)); +} +} // namespace + +class WireIteratorTest : public ::testing::Test { +protected: + std::unique_ptr context; + + void SetUp() override { + DialectRegistry registry; + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + + context = std::make_unique(); + context->appendDialectRegistry(registry); + context->loadAllAvailableDialects(); + } +}; + +TEST_F(WireIteratorTest, TestForward) { + + /// + /// Test the forward iteration. + /// + + auto module = getModule(*context); + auto alloc = *(module->getOps().begin()); + auto q = alloc.getQubit(); + WireIterator it(q, q.getParentRegion()); + + checkOperationEqual(*it, "%0 = mqtopt.allocQubit"); + + ++it; + checkOperationEqual(*it, "%out_qubits = mqtopt.h() %0 : !mqtopt.Qubit"); + + ++it; + checkOperationEqual(*it, + "%out_qubits_1 = mqtopt.z() %out_qubits : !mqtopt.Qubit"); + + ++it; + checkOperationEqual( + *it, "%out_qubits_2, %pos_ctrl_out_qubits = mqtopt.x() %out_qubits_0 " + "ctrl %out_qubits_1 : !mqtopt.Qubit ctrl !mqtopt.Qubit"); + + ++it; + checkOperationEqual( + *it, + "%out_qubits_4, %pos_ctrl_out_qubits_5 = mqtopt.x() %pos_ctrl_out_qubits " + "ctrl %out_qubits_3 : !mqtopt.Qubit ctrl !mqtopt.Qubit"); + + ++it; + checkOperationStartsWith( + *it, "%2:2 = scf.if %false -> (!mqtopt.Qubit, !mqtopt.Qubit)"); + + ++it; + checkOperationStartsWith(*it, + "%3:2 = scf.for %arg0 = %idx0 to %idx8 step %idx1"); + + ++it; + checkOperationEqual(*it, "mqtopt.deallocQubit %3#0"); + + ++it; + ASSERT_EQ(it, std::default_sentinel); + + ++it; + ASSERT_EQ(it, std::default_sentinel); +} + +TEST_F(WireIteratorTest, TestBackward) { + + /// + /// Test the backward iteration. + /// + + auto module = getModule(*context); + auto allocs = module->getOps(); + const auto allocRng = llvm::make_range(allocs.begin(), allocs.end()); + const auto allocVec = llvm::to_vector(allocRng); + auto alloc = allocVec[1]; + auto q = alloc.getQubit(); + WireIterator it(q, q.getParentRegion()); + const WireIterator begin(it); + + ASSERT_EQ(it, begin); + + for (; it != std::default_sentinel; ++it) { + llvm::dbgs() << **it << '\n'; /// Keep for debugging purposes. + } + + ASSERT_EQ(it, std::default_sentinel); + + --it; + checkOperationEqual(*it, "mqtopt.deallocQubit %3#1"); + + --it; + checkOperationStartsWith(*it, + "%3:2 = scf.for %arg0 = %idx0 to %idx8 step %idx1"); + + --it; + checkOperationStartsWith( + *it, "%2:2 = scf.if %false -> (!mqtopt.Qubit, !mqtopt.Qubit)"); + + --it; + checkOperationEqual( + *it, + "%out_qubits_4, %pos_ctrl_out_qubits_5 = mqtopt.x() %pos_ctrl_out_qubits " + "ctrl %out_qubits_3 : !mqtopt.Qubit ctrl !mqtopt.Qubit"); + + --it; + checkOperationEqual( + *it, "%out_qubits_3 = mqtopt.h() %out_qubits_2 : !mqtopt.Qubit"); + + --it; + checkOperationEqual( + *it, "%out_qubits_2, %pos_ctrl_out_qubits = mqtopt.x() %out_qubits_0 " + "ctrl %out_qubits_1 : !mqtopt.Qubit ctrl !mqtopt.Qubit"); + + --it; + checkOperationEqual(*it, "%out_qubits_0 = mqtopt.h() %1 : !mqtopt.Qubit"); + + --it; + checkOperationEqual(*it, "%1 = mqtopt.allocQubit"); + + ASSERT_EQ(it, begin); + + --it; + checkOperationEqual(*it, "%1 = mqtopt.allocQubit"); + + ASSERT_EQ(it, begin); +} + +TEST_F(WireIteratorTest, TestForwardAndBackward) { + + /// + /// Test the forward as well as the backward iteration. + /// + + auto module = getModule(*context); + auto alloc = *(module->getOps().begin()); + auto q = alloc.getQubit(); + WireIterator it(q, q.getParentRegion()); + const WireIterator begin(it); + + checkOperationEqual(*it, "%0 = mqtopt.allocQubit"); + + ++it; + checkOperationEqual(*it, "%out_qubits = mqtopt.h() %0 : !mqtopt.Qubit"); + + ++it; + checkOperationEqual(*it, + "%out_qubits_1 = mqtopt.z() %out_qubits : !mqtopt.Qubit"); + + ++it; + checkOperationEqual( + *it, "%out_qubits_2, %pos_ctrl_out_qubits = mqtopt.x() %out_qubits_0 " + "ctrl %out_qubits_1 : !mqtopt.Qubit ctrl !mqtopt.Qubit"); + + --it; + checkOperationEqual(*it, + "%out_qubits_1 = mqtopt.z() %out_qubits : !mqtopt.Qubit"); + + --it; + checkOperationEqual(*it, "%out_qubits = mqtopt.h() %0 : !mqtopt.Qubit"); + + --it; + checkOperationEqual(*it, "%0 = mqtopt.allocQubit"); + + ASSERT_EQ(it, begin); + + for (; it != std::default_sentinel; ++it) { + llvm::dbgs() << **it << '\n'; /// Keep for debugging purposes. + } + + ASSERT_EQ(it, std::default_sentinel); + + it = std::prev(it); // Back to last non-sentinel item. + + for (; it != begin; --it) { + llvm::dbgs() << **it << '\n'; /// Keep for debugging purposes. + } + + ASSERT_EQ(it, begin); +} + +TEST_F(WireIteratorTest, TestRecursiveUse) { + + /// + /// Test the recursive use of the iterator. + /// + + auto module = getModule(*context); + auto alloc = *(module->getOps().begin()); + auto q = alloc.getQubit(); + WireIterator it(q, q.getParentRegion()); + + /// Advance until 'scf.for'. + for (; it != std::default_sentinel; ++it) { + if (isa(*it)) { + break; + } + llvm::dbgs() << **it << '\n'; /// Keep for debugging purposes. + } + + auto loop = cast(*it); + for (auto [iter, init] : + llvm::zip(loop.getRegionIterArgs(), loop.getInitArgs())) { + if (init == it.qubit()) { + WireIterator rec(iter, &loop.getRegion()); + const WireIterator recBegin(rec); + rec--; + + ASSERT_EQ(rec, recBegin); // Test blockargument handling. + + rec++; + checkOperationEqual(*rec, + "%out_qubits_7 = mqtopt.h() %arg1 : !mqtopt.Qubit"); + + rec++; + checkOperationEqual(*rec, "scf.yield %out_qubits_7, %out_qubits_8 : " + "!mqtopt.Qubit, !mqtopt.Qubit"); + } + } +} + +TEST_F(WireIteratorTest, TestStaticQubit) { + + /// + /// Test the iteration with a static qubit. + /// + + auto module = getModule(*context); + auto qubit = *(module->getOps().begin()); + auto q = qubit.getQubit(); + WireIterator it(q, q.getParentRegion()); + const WireIterator begin(it); + + checkOperationEqual(*it, "%4 = mqtopt.qubit 42"); + + ++it; + checkOperationEqual(*it, "%5 = mqtopt.reset %4"); + + ++it; + checkOperationEqual(*it, "%out_qubits_6 = mqtopt.h() %5 : !mqtopt.Qubit"); + + ++it; + ASSERT_EQ(it, std::default_sentinel); + + --it; + checkOperationEqual(*it, "%out_qubits_6 = mqtopt.h() %5 : !mqtopt.Qubit"); + ASSERT_EQ(it.qubit(), (*it)->getResult(0)); // q = %out_qubits_6 + + --it; + checkOperationEqual(*it, "%out_qubits_6 = mqtopt.h() %5 : !mqtopt.Qubit"); + ASSERT_EQ(it.qubit(), (*it)->getOperand(0)); // q = %5 + + --it; + checkOperationEqual(*it, "%5 = mqtopt.reset %4"); + + --it; + checkOperationEqual(*it, "%4 = mqtopt.qubit 42"); + + ASSERT_EQ(it, begin); +}