diff --git a/mlir/include/mlir/Compiler/CompilerPipeline.h b/mlir/include/mlir/Compiler/CompilerPipeline.h index e831ad907..043446888 100644 --- a/mlir/include/mlir/Compiler/CompilerPipeline.h +++ b/mlir/include/mlir/Compiler/CompilerPipeline.h @@ -39,6 +39,9 @@ struct QuantumCompilerConfig { /// Print IR after each stage bool printIRAfterAllStages = false; + + /// Enable quaternion-based rotation gate merging + bool mergeRotationGates = false; }; /** diff --git a/mlir/include/mlir/Dialect/QCO/CMakeLists.txt b/mlir/include/mlir/Dialect/QCO/CMakeLists.txt index b181a84fe..3b0a561d0 100644 --- a/mlir/include/mlir/Dialect/QCO/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/QCO/CMakeLists.txt @@ -7,3 +7,4 @@ # Licensed under the MIT License add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/include/mlir/Dialect/QCO/Transforms/CMakeLists.txt b/mlir/include/mlir/Dialect/QCO/Transforms/CMakeLists.txt new file mode 100644 index 000000000..66bcfa12b --- /dev/null +++ b/mlir/include/mlir/Dialect/QCO/Transforms/CMakeLists.txt @@ -0,0 +1,12 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name QCO) +add_public_tablegen_target(MLIRQCOTransformsIncGen) +add_mlir_doc(Passes QCOPasses Passes/ -gen-pass-doc) diff --git a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.h b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.h new file mode 100644 index 000000000..8c1286b7e --- /dev/null +++ b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.h @@ -0,0 +1,21 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#pragma once + +#include // from @llvm-project + +namespace mlir::qco { +#define GEN_PASS_DECL +#include "mlir/Dialect/QCO/Transforms/Passes.h.inc" // IWYU pragma: export + +#define GEN_PASS_REGISTRATION +#include "mlir/Dialect/QCO/Transforms/Passes.h.inc" // IWYU pragma: export +} // namespace mlir::qco diff --git a/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td new file mode 100644 index 000000000..8c2c18082 --- /dev/null +++ b/mlir/include/mlir/Dialect/QCO/Transforms/Passes.td @@ -0,0 +1,20 @@ +// Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +// Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +// All rights reserved. +// +// SPDX-License-Identifier: MIT +// +// Licensed under the MIT License + +include "mlir/Pass/PassBase.td" + +def MergeRotationGates : Pass<"merge-rotation-gates", +"mlir::ModuleOp"> { + let summary = "Merge rotation gates using quaternion-based fusion"; + let description = [{ + Merges consecutive rotation gates of different types (rx, ry, rz, u) + using quaternion representation and the Hamilton product. + }]; + + let dependentDialects = ["::mlir::math::MathDialect"]; +} diff --git a/mlir/lib/Compiler/CMakeLists.txt b/mlir/lib/Compiler/CMakeLists.txt index 97b700b5c..6d0b33a9f 100644 --- a/mlir/lib/Compiler/CMakeLists.txt +++ b/mlir/lib/Compiler/CMakeLists.txt @@ -20,6 +20,7 @@ add_mlir_library( QCToQCO QCOToQC QCToQIR + MLIRQCOTransforms MQT::MLIRSupport MQT::ProjectOptions DISABLE_INSTALL) diff --git a/mlir/lib/Compiler/CompilerPipeline.cpp b/mlir/lib/Compiler/CompilerPipeline.cpp index 02183790f..2659055b0 100644 --- a/mlir/lib/Compiler/CompilerPipeline.cpp +++ b/mlir/lib/Compiler/CompilerPipeline.cpp @@ -13,6 +13,7 @@ #include "mlir/Conversion/QCOToQC/QCOToQC.h" #include "mlir/Conversion/QCToQCO/QCToQCO.h" #include "mlir/Conversion/QCToQIR/QCToQIR.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" #include "mlir/Support/PrettyPrinting.h" #include @@ -161,10 +162,16 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, // Stage 5: Optimization passes // TODO: Add optimization passes - addCleanupPasses(pm); - if (failed(pm.run(module))) { - return failure(); + + // quaternion gate merging pass + if (config_.mergeRotationGates) { + pm.addPass(mlir::qco::createMergeRotationGates()); + if (failed(pm.run(module))) { + return failure(); + } + pm.clear(); } + if (record != nullptr && config_.recordIntermediates) { record->afterOptimization = captureIR(module); if (config_.printIRAfterAllStages) { @@ -172,7 +179,6 @@ QuantumCompilerPipeline::runPipeline(ModuleOp module, totalStages); } } - pm.clear(); // Stage 6: QCO canonicalization addCleanupPasses(pm); diff --git a/mlir/lib/Dialect/QCO/CMakeLists.txt b/mlir/lib/Dialect/QCO/CMakeLists.txt index e87794ff1..1dad5ba80 100644 --- a/mlir/lib/Dialect/QCO/CMakeLists.txt +++ b/mlir/lib/Dialect/QCO/CMakeLists.txt @@ -8,3 +8,4 @@ add_subdirectory(Builder) add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt b/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt new file mode 100644 index 000000000..dd8af2abc --- /dev/null +++ b/mlir/lib/Dialect/QCO/Transforms/CMakeLists.txt @@ -0,0 +1,40 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS) +set(LIBRARIES ${dialect_libs} MQT::CoreIR) +add_compile_options(-fexceptions) + +message(STATUS "MLIR_DIALECT_LIBS contains: ${dialect_libs}") + +file(GLOB_RECURSE TRANSFORMS_SOURCES *.cpp) + +add_mlir_library(MLIRQCOTransforms ${TRANSFORMS_SOURCES} LINK_LIBS ${LIBRARIES} DEPENDS + MLIRQCOTransformsIncGen) + +# collect header files +file(GLOB_RECURSE TRANSFORMS_HEADERS_SOURCE + ${MQT_MLIR_SOURCE_INCLUDE_DIR}/mlir/Dialect/QCO/Transforms/*.h) +file(GLOB_RECURSE TRANSFORMS_HEADERS_BUILD + ${MQT_MLIR_BUILD_INCLUDE_DIR}/mlir/Dialect/QCO/Transforms/*.inc) + +# add public headers using file sets +target_sources( + MLIRQCOTransforms + PUBLIC FILE_SET + HEADERS + BASE_DIRS + ${MQT_MLIR_SOURCE_INCLUDE_DIR} + FILES + ${TRANSFORMS_HEADERS_SOURCE} + FILE_SET + HEADERS + BASE_DIRS + ${MQT_MLIR_BUILD_INCLUDE_DIR} + FILES + ${TRANSFORMS_HEADERS_BUILD}) diff --git a/mlir/lib/Dialect/QCO/Transforms/QuaternionMergeRotationGates.cpp b/mlir/lib/Dialect/QCO/Transforms/QuaternionMergeRotationGates.cpp new file mode 100644 index 000000000..139af0c99 --- /dev/null +++ b/mlir/lib/Dialect/QCO/Transforms/QuaternionMergeRotationGates.cpp @@ -0,0 +1,421 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace mlir::qco { + +#define GEN_PASS_DEF_MERGEROTATIONGATES +#include "mlir/Dialect/QCO/Transforms/Passes.h.inc" + +/** + * @brief Pattern that merges consecutive rotation gates using quaternion + * multiplication. + */ +struct MergeRotationGatesPattern final + : mlir::OpInterfaceRewritePattern { + explicit MergeRotationGatesPattern(mlir::MLIRContext* context) + : OpInterfaceRewritePattern(context) {} + + struct Quaternion { + mlir::Value w; + mlir::Value x; + mlir::Value y; + mlir::Value z; + }; + + enum class RotationAxis { X, Y, Z }; + + static constexpr std::array MERGEABLE_GATES = { + "u", "rx", "ry", "rz"}; + + /** + * @brief Checks if an operation is a mergeable rotation gate (rx, ry, rz, u). + * + * @param name Name of the operation to check + * @return True if mergeable, false otherwise + */ + static bool isMergeable(std::string_view name) { + return std::ranges::find(MERGEABLE_GATES, name) != MERGEABLE_GATES.end(); + } + + /** + * @brief Checks if two gates require quaternion-based merging. + * + * Returns true for different gate types (e.g., RX+RY) or two U-gates. + * Same-axis rotations (e.g., RX+RX) use angle addition and aren't handled + * here. + * + * @param a The first gate + * @param b The second gate + * @return True if quaternion-based merging should be used, false otherwise + */ + [[nodiscard]] static bool areQuaternionMergeable(mlir::Operation& a, + mlir::Operation& b) { + const auto aName = a.getName().stripDialect().str(); + const auto bName = b.getName().stripDialect().str(); + + if (!(isMergeable(aName) && isMergeable(bName))) { + return false; + } + return (aName != bName) || (aName == "u" && bName == "u"); + } + + /** + * @brief Converts a single-axis rotation to quaternion representation. + * + * Uses half-angle formulas: + * RX(a) = Q(cos(a/2), sin(a/2), 0, 0) + * RY(a) = Q(cos(a/2), 0, sin(a/2), 0) + * RZ(a) = Q(cos(a/2), 0, 0, sin(a/2)) + * + * @see + * https://en.wikipedia.org/wiki/Conversion_between_quaternions_and_Euler_angles + * @param angle The rotation angle + * @param axis The rotation axis (X, Y, or Z) + * @param loc Location in the IR + * @param rewriter Pattern rewriter for creating new operations + * @return Quaternion representing the rotation + */ + static Quaternion createAxisQuaternion(mlir::Value angle, RotationAxis axis, + mlir::Location loc, + mlir::PatternRewriter& rewriter) { + auto floatType = angle.getType(); + + // constant 0.0 + auto zeroAttr = rewriter.getFloatAttr(floatType, 0.0); + auto zero = rewriter.create(loc, zeroAttr); + + // constant 2.0 + auto twoAttr = rewriter.getFloatAttr(floatType, 2.0); + auto two = rewriter.create(loc, twoAttr); + + auto half = rewriter.create(loc, angle, two); + // cos(angle/2) + auto cos = rewriter.create(loc, floatType, half); + // sin(angle/2) + auto sin = rewriter.create(loc, floatType, half); + + switch (axis) { + case RotationAxis::X: + return {.w = cos, .x = sin, .y = zero, .z = zero}; + case RotationAxis::Y: + return {.w = cos, .x = zero, .y = sin, .z = zero}; + case RotationAxis::Z: + return {.w = cos, .x = zero, .y = zero, .z = sin}; + } + } + + /** + * @brief Converts a rotation gate (RX, RY, RZ, or U) to quaternion + * representation. + * + * @param op The rotation gate to convert + * @param rewriter Pattern rewriter for creating new operations + * @return Quaternion representing the rotation gate + */ + static Quaternion quaternionFromRotation(UnitaryOpInterface op, + mlir::PatternRewriter& rewriter) { + auto const type = op->getName().stripDialect().str(); + + if (type == "u") { + return quaternionFromUGate(op, rewriter); + } + + auto loc = op->getLoc(); + auto angle = op.getParameter(0); + + if (type == "rx") { + return createAxisQuaternion(angle, RotationAxis::X, loc, rewriter); + } + if (type == "ry") { + return createAxisQuaternion(angle, RotationAxis::Y, loc, rewriter); + } + if (type == "rz") { + return createAxisQuaternion(angle, RotationAxis::Z, loc, rewriter); + } + llvm_unreachable("Unsupported operation type"); + } + + /** + * @brief Computes the Hamilton product of two quaternions (q1 * q2). + * + * For q1 = w1 + x1*i + y1*j + z1*k and q2 = w2 + x2*i + y2*j + z2*k: + * + * q1 * q2 = (w1w2 - x1x2 - y1y2 - z1z2) + * + (w1x2 + x1w2 + y1z2 - z1y2) * i + * + (w1y2 - x1z2 + y1w2 + z1x2) * j + * + (w1z2 + x1y2 - y1x2 + z1w2) * k + * + * @see https://en.wikipedia.org/wiki/Quaternion#Hamilton_product + * @param q1 First quaternion + * @param q2 Second quaternion + * @param op Current operation (used for location) + * @param rewriter Pattern rewriter for creating arithmetic operations + * @return The product quaternion + */ + static Quaternion hamiltonProduct(Quaternion q1, Quaternion q2, + UnitaryOpInterface op, + mlir::PatternRewriter& rewriter) { + auto loc = op->getLoc(); + + // wRes = w1w2 - x1x2 - y1y2 - z1z2 + auto w1w2 = rewriter.create(loc, q1.w, q2.w); + auto x1x2 = rewriter.create(loc, q1.x, q2.x); + auto y1y2 = rewriter.create(loc, q1.y, q2.y); + auto z1z2 = rewriter.create(loc, q1.z, q2.z); + auto wTemp1 = rewriter.create(loc, w1w2, x1x2); + auto wTemp2 = rewriter.create(loc, wTemp1, y1y2); + auto wRes = rewriter.create(loc, wTemp2, z1z2); + + // xRes = w1x2 + x1w2 + y1z2 - z1y2 + auto w1x2 = rewriter.create(loc, q1.w, q2.x); + auto x1w2 = rewriter.create(loc, q1.x, q2.w); + auto y1z2 = rewriter.create(loc, q1.y, q2.z); + auto z1y2 = rewriter.create(loc, q1.z, q2.y); + auto xTemp1 = rewriter.create(loc, w1x2, x1w2); + auto xTemp2 = rewriter.create(loc, xTemp1, y1z2); + auto xRes = rewriter.create(loc, xTemp2, z1y2); + + // yRes = w1y2 - x1z2 + y1w2 + z1x2 + auto w1y2 = rewriter.create(loc, q1.w, q2.y); + auto x1z2 = rewriter.create(loc, q1.x, q2.z); + auto y1w2 = rewriter.create(loc, q1.y, q2.w); + auto z1x2 = rewriter.create(loc, q1.z, q2.x); + auto yTemp1 = rewriter.create(loc, w1y2, x1z2); + auto yTemp2 = rewriter.create(loc, yTemp1, y1w2); + auto yRes = rewriter.create(loc, yTemp2, z1x2); + + // zRes = w1z2 + x1y2 - y1x2 + z1w2 + auto w1z2 = rewriter.create(loc, q1.w, q2.z); + auto x1y2 = rewriter.create(loc, q1.x, q2.y); + auto y1x2 = rewriter.create(loc, q1.y, q2.x); + auto z1w2 = rewriter.create(loc, q1.z, q2.w); + auto zTemp1 = rewriter.create(loc, w1z2, x1y2); + auto zTemp2 = rewriter.create(loc, zTemp1, y1x2); + auto zRes = rewriter.create(loc, zTemp2, z1w2); + + return {.w = wRes, .x = xRes, .y = yRes, .z = zRes}; + } + + /** + * @brief Converts a u-gate to quaternion representation. + * + * U(alpha, beta, gamma) uses ZYZ decomposition: RZ(alpha) -> RY(beta) -> + * RZ(gamma). + * + * When composing rotations, quaternion multiplication follows matrix + * multiplication order (right-to-left), which is the reverse of the + * application sequence: + * Sequential application: RZ(alpha), then RY(beta), then RZ(gamma) + * Quaternion product: Qgamma * Qbeta * Qalpha + * + * @param op The u-gate operation to convert + * @param rewriter Pattern rewriter for creating new operations + * @return Quaternion representing the u-gate + */ + static Quaternion quaternionFromUGate(UnitaryOpInterface op, + mlir::PatternRewriter& rewriter) { + auto loc = op->getLoc(); + + // U gate uses ZYZ decomposition: + // U(alpha, beta, gamma) = Rz(alpha) -> Ry(beta) -> Rz(gamma) + auto qAlpha = createAxisQuaternion(op.getParameter(0), RotationAxis::Z, loc, + rewriter); + auto qBeta = createAxisQuaternion(op.getParameter(1), RotationAxis::Y, loc, + rewriter); + auto qGamma = createAxisQuaternion(op.getParameter(2), RotationAxis::Z, loc, + rewriter); + + // qGamma * qBeta * qAlpha (multiplication in reverse order!) + auto temp = hamiltonProduct(qGamma, qBeta, op, rewriter); + return hamiltonProduct(temp, qAlpha, op, rewriter); + } + + /** + * @brief Converts a quaternion to a u-gate using ZYZ Euler angle extraction. + * + * For unit quaternion q = w + x*i + y*j + z*k, extracts u-gate parameters: + * alpha = atan2(z, w) - atan2(-x, y) + * beta = acos(2 * (w^2 + z^2) - 1) + * gamma = atan2(z, w) + atan2(-x, y) + * + * Based on Bernardes & Viollet (2022), simplified for unit quaternions and + * proper ZYZ Euler angles (Chapter 3.3): + * https://doi.org/10.1371/journal.pone.0276302 + * + * Reference implementation: + * https://github.com/evbernardes/quaternion_to_euler + * SymPy also implements this paper: + * https://docs.sympy.org/latest/modules/algebras.html#sympy.algebras.Quaternion.to_euler + * + * @note Floating-point errors may accumulate when merging many gates. + * @param q The quaternion to convert + * @param op The current operation (used for location and type information) + * @param rewriter Pattern rewriter for creating new operations + * @return U-gate equivalent to the quaternion rotation + */ + static UnitaryOpInterface + uGateFromQuaternion(Quaternion q, UnitaryOpInterface op, + mlir::PatternRewriter& rewriter) { + auto loc = op->getLoc(); + + auto floatType = op.getParameter(0).getType(); + // constant 1.0 + auto oneAttr = rewriter.getFloatAttr(floatType, 1.0); + auto one = rewriter.create(loc, oneAttr); + // constant 2.0 + auto twoAttr = rewriter.getFloatAttr(floatType, 2.0); + auto two = rewriter.create(loc, twoAttr); + + // calculate angle beta (for y-rotation) + // beta = acos(2 * (w^2 + z^2) - 1) + auto ww = rewriter.create(loc, q.w, q.w); + auto zz = rewriter.create(loc, q.z, q.z); + auto bTemp1 = rewriter.create(loc, ww, zz); + auto bTemp2 = rewriter.create(loc, two, bTemp1); + auto bTemp3 = rewriter.create(loc, bTemp2, one); + auto beta = rewriter.create(loc, bTemp3); + + // intermediate angles for z-rotations alpha and gamma + // theta+ = atan2(z, w) + // theta- = atan2(-x, y) + auto xMinus = rewriter.create(loc, q.x); + auto thetaPlus = rewriter.create(loc, q.z, q.w); + auto thetaMinus = rewriter.create(loc, xMinus, q.y); + + // z-rotations alpha and gamma + // alpha = theta+ - theta- + // gamma = theta+ + theta- + auto alpha = + rewriter.create(loc, thetaPlus, thetaMinus); + auto gamma = + rewriter.create(loc, thetaPlus, thetaMinus); + + return rewriter.create(loc, op.getInputQubit(0), alpha.getResult(), + beta.getResult(), gamma.getResult()); + } + + /** + * @brief Creates a u-gate by merging two rotation gates. + * + * Converts both gates to quaternions, multiplies them using the Hamilton + * product (in reverse circuit order), and converts back to a u-gate. + * + * @param op The first rotation gate + * @param user The second rotation gate + * @param rewriter Pattern rewriter for creating the merged gate + * @return A u-gate representing the merged rotation + */ + static UnitaryOpInterface + createOpQuaternionMergedAngle(UnitaryOpInterface op, UnitaryOpInterface user, + mlir::PatternRewriter& rewriter) { + auto q1 = quaternionFromRotation(op, rewriter); + auto q2 = quaternionFromRotation(user, rewriter); + auto qHam = hamiltonProduct(q2, q1, op, rewriter); + auto newUser = uGateFromQuaternion(qHam, op, rewriter); + + return newUser; + } + + /** + * @brief Matches and merges consecutive rotation gates on the same qubit. + * + * Merges two gates using quaternion multiplication when the first gate has + * exactly one use, replacing both with an equivalent u-gate. + * + * @param op The rotation gate to match + * @param rewriter Pattern rewriter for applying transformations + * @return success() if gates were merged, failure() otherwise + */ + mlir::LogicalResult + matchAndRewrite(UnitaryOpInterface op, + mlir::PatternRewriter& rewriter) const override { + // QCO operations cannot contain control qubits, so no need to check for + // them + if (!op->hasOneUse()) { + return mlir::failure(); + } + + const auto& users = op->getUsers(); + auto* userOP = *users.begin(); + + if (!areQuaternionMergeable(*op, *userOP)) { + return mlir::failure(); + } + auto user = mlir::dyn_cast(userOP); + if (!user) { + return mlir::failure(); + } + + UnitaryOpInterface newUser = + createOpQuaternionMergedAngle(op, user, rewriter); + + // Replace user with newUser + rewriter.replaceOp(user, newUser); + + // Erase op + rewriter.eraseOp(op); + return mlir::success(); + } +}; + +/** + * @brief Populates the given pattern set with the `MergeRotationGatesPattern`. + * + * @param patterns The pattern set to populate + */ +static void +populateMergeRotationGatesPatterns(mlir::RewritePatternSet& patterns) { + patterns.add(patterns.getContext()); +} + +/** + * @brief Pass that merges consecutive rotation gates using quaternion + * multiplication. + */ +struct MergeRotationGates final + : impl::MergeRotationGatesBase { + using impl::MergeRotationGatesBase< + MergeRotationGates>::MergeRotationGatesBase; + + void runOnOperation() override { + // Get the current operation being operated on. + auto op = getOperation(); + auto* ctx = &getContext(); + + // Define the set of patterns to use. + mlir::RewritePatternSet patterns(ctx); + populateMergeRotationGatesPatterns(patterns); + + // Apply patterns in an iterative and greedy manner. + if (mlir::failed(mlir::applyPatternsGreedily(op, std::move(patterns)))) { + signalPassFailure(); + } + } +}; + +} // namespace mlir::qco diff --git a/mlir/tools/mqt-cc/mqt-cc.cpp b/mlir/tools/mqt-cc/mqt-cc.cpp index fd25e8d5e..9b551f563 100644 --- a/mlir/tools/mqt-cc/mqt-cc.cpp +++ b/mlir/tools/mqt-cc/mqt-cc.cpp @@ -72,6 +72,10 @@ const cl::opt cl::desc("Print IR after each compiler stage"), cl::init(false)); +const cl::opt MERGE_ROTATION_GATES( + "mlir-merge-rotation-gates", + cl::desc("Enable quaternion-based rotation gate merging"), cl::init(false)); + } // namespace /** @@ -168,6 +172,7 @@ int main(int argc, char** argv) { config.enableTiming = ENABLE_TIMING; config.enableStatistics = ENABLE_STATISTICS; config.printIRAfterAllStages = PRINT_IR_AFTER_ALL_STAGES; + config.mergeRotationGates = MERGE_ROTATION_GATES; // Run the compilation pipeline CompilationRecord record; diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index 86b0e5916..fddb033ad 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -13,4 +13,5 @@ add_custom_target(mqt-core-mlir-unittests) add_dependencies( mqt-core-mlir-unittests mqt-core-mlir-compiler-pipeline-test mqt-core-mlir-qco-dialect-test - mqt-core-mlir-dialect-qco-ir-modifiers-test mqt-core-mlir-dialect-utils-test) + mqt-core-mlir-dialect-qco-ir-modifiers-test mqt-core-mlir-dialect-utils-test + mqt-core-mlir-dialect-qco-transforms-test) diff --git a/mlir/unittests/Dialect/QCO/CMakeLists.txt b/mlir/unittests/Dialect/QCO/CMakeLists.txt index b181a84fe..3b0a561d0 100644 --- a/mlir/unittests/Dialect/QCO/CMakeLists.txt +++ b/mlir/unittests/Dialect/QCO/CMakeLists.txt @@ -7,3 +7,4 @@ # Licensed under the MIT License add_subdirectory(IR) +add_subdirectory(Transforms) diff --git a/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt b/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt new file mode 100644 index 000000000..200ecccea --- /dev/null +++ b/mlir/unittests/Dialect/QCO/Transforms/CMakeLists.txt @@ -0,0 +1,22 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +add_executable(mqt-core-mlir-dialect-qco-transforms-test test_qco_quaternion_merge.cpp) + +target_link_libraries( + mqt-core-mlir-dialect-qco-transforms-test + # TODO figure out correct dependencies + PRIVATE GTest::gtest_main + MLIRQCOProgramBuilder + MLIRQCOTransforms + MLIRIR + MLIRPass # for PassManager + MLIRSupport + LLVMSupport) + +gtest_discover_tests(mqt-core-mlir-dialect-qco-transforms-test) diff --git a/mlir/unittests/Dialect/QCO/Transforms/test_qco_quaternion_merge.cpp b/mlir/unittests/Dialect/QCO/Transforms/test_qco_quaternion_merge.cpp new file mode 100644 index 000000000..e8202ec62 --- /dev/null +++ b/mlir/unittests/Dialect/QCO/Transforms/test_qco_quaternion_merge.cpp @@ -0,0 +1,605 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/QCO/Builder/QCOProgramBuilder.h" +#include "mlir/Dialect/QCO/IR/QCODialect.h" +#include "mlir/Dialect/QCO/Transforms/Passes.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace { + +using namespace mlir; +using namespace mlir::qco; + +class QCOQuaternionMergeTest : public ::testing::Test { +protected: + MLIRContext context; + QCOProgramBuilder builder; + OwningOpRef module; + + /** + * @brief Struct to easily construct a rotation gate inline. + * opName uses the getOperationName() mnemonic. + */ + struct RotationGate { + llvm::StringLiteral opName; + std::vector angles; + }; + + QCOQuaternionMergeTest() : builder(&context) {} + + void SetUp() override { + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + + builder.initialize(); + } + + /** + * @brief Counts the amount of operations the current module/circuit + * contains. + */ + template int countOps() { + int count = 0; + module->walk([&](OpTy) { ++count; }); + return count; + } + + /** + * @brief Extract constant floating point value from a mlir::Value + */ + std::optional toDouble(mlir::Value v) { + if (auto constOp = v.getDefiningOp()) { + if (auto floatAttr = + mlir::dyn_cast(constOp.getValue())) { + return floatAttr.getValueAsDouble(); + } + } + return std::nullopt; + } + + /** + * @brief Find the first occurrence of a u-gate in the current module and get + * the numeric value of its parameters. This assumes that parameters + * are constant and can be extracted. + */ + std::optional> getUGateParams() { + UOp uOp = nullptr; + module->walk([&](UOp op) { + uOp = op; + // stop after finding first UOp + return mlir::WalkResult::interrupt(); + }); + + if (!uOp) { + return std::nullopt; + } + + auto theta = toDouble(uOp.getTheta()); + auto phi = toDouble(uOp.getPhi()); + auto lambda = toDouble(uOp.getLambda()); + + if (!theta || !phi || !lambda) { + return std::nullopt; + } + + return std::make_tuple(*theta, *phi, *lambda); + } + + /** + * @brief Gets the first u-gate of a module and tests whether its + * angle parameters are equal the the expected ones. + */ + void expectUGateParams(double expectedTheta, double expectedPhi, + double expectedLambda, double tolerance = 1e-8) { + auto params = getUGateParams(); + ASSERT_TRUE(params.has_value()); + + auto [theta, phi, lambda] = *params; + EXPECT_NEAR(theta, expectedTheta, tolerance); + EXPECT_NEAR(phi, expectedPhi, tolerance); + EXPECT_NEAR(lambda, expectedLambda, tolerance); + } + + /** + * @brief Takes a list of rotation gates (rx, ry, rz and u) and uses the + * builder api to build a small quantum circuit, where a qubit is feed through + * all rotations in the list. + */ + LogicalResult testGateMerge(const std::vector& rotations) { + + auto q = builder.allocQubitRegister(1); + + Value qubit = q[0]; + + for (const auto& gate : rotations) { + if (gate.opName == RXOp::getOperationName()) { + qubit = builder.rx(gate.angles[0], qubit); + } else if (gate.opName == RYOp::getOperationName()) { + qubit = builder.ry(gate.angles[0], qubit); + } else if (gate.opName == RZOp::getOperationName()) { + qubit = builder.rz(gate.angles[0], qubit); + } else if (gate.opName == UOp::getOperationName()) { + qubit = + builder.u(gate.angles[0], gate.angles[1], gate.angles[2], qubit); + } + } + + module = builder.finalize(); + return runMergePass(module.get()); + } + + /** + * @brief Adds the mergeRotationGates Pass to the current context and runs it. + */ + LogicalResult runMergePass(ModuleOp module) { + PassManager pm(module.getContext()); + pm.addPass(qco::createMergeRotationGates()); + return pm.run(module); + } +}; + +} // namespace + +// ################################################## +// # Two Gate Merging Tests +// ################################################## + +/** + * @brief Test: RX->RY should merge into a single U gate + */ +TEST_F(QCOQuaternionMergeTest, quaternionMergeRXRYGates) { + ASSERT_TRUE(testGateMerge({{RXOp::getOperationName(), {1.}}, + {RYOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); +} + +/** + * @brief Test: RX->RZ should merge into a single U gate + */ +TEST_F(QCOQuaternionMergeTest, quaternionMergeRXRZGates) { + ASSERT_TRUE(testGateMerge({{RXOp::getOperationName(), {1.}}, + {RZOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); +} + +/** + * @brief Test: RY->RX should merge into a single U gate + */ +TEST_F(QCOQuaternionMergeTest, quaternionMergeRYRXGates) { + ASSERT_TRUE(testGateMerge({{RYOp::getOperationName(), {1.}}, + {RXOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); +} + +/** + * @brief Test: RY->RZ should merge into a single U gate + */ +TEST_F(QCOQuaternionMergeTest, quaternionMergeRYRZGates) { + ASSERT_TRUE(testGateMerge({{RYOp::getOperationName(), {1.}}, + {RZOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); +} + +/** + * @brief Test: RZ->RX should merge into a single U gate + */ +TEST_F(QCOQuaternionMergeTest, quaternionMergeRZRXGates) { + ASSERT_TRUE(testGateMerge({{RZOp::getOperationName(), {1.}}, + {RXOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); +} + +/** + * @brief Test: RZ->RY should merge into a single U gate + */ +TEST_F(QCOQuaternionMergeTest, quaternionMergeRZRYGates) { + ASSERT_TRUE(testGateMerge({{RZOp::getOperationName(), {1.}}, + {RYOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); +} + +/** + * @brief Test: U->U should merge into a single U gate + */ +TEST_F(QCOQuaternionMergeTest, quaternionMergeUUGates) { + ASSERT_TRUE(testGateMerge({{UOp::getOperationName(), {1., 2., .3}}, + {UOp::getOperationName(), {4., 5., 6.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: U->RX should merge into a single U gate + */ +TEST_F(QCOQuaternionMergeTest, quaternionMergeURXGates) { + ASSERT_TRUE(testGateMerge({{UOp::getOperationName(), {1., 2., .3}}, + {RXOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); +} + +/** + * @brief Test: U->RY should merge into a single U gate + */ +TEST_F(QCOQuaternionMergeTest, quaternionMergeURYGates) { + ASSERT_TRUE(testGateMerge({{UOp::getOperationName(), {1., 2., .3}}, + {RYOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); +} + +/** + * @brief Test: U->RZ should merge into a single U gate + */ +TEST_F(QCOQuaternionMergeTest, quaternionMergeURZGates) { + ASSERT_TRUE(testGateMerge({{UOp::getOperationName(), {1., 2., .3}}, + {RZOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); +} + +/** + * @brief Test: RX->U should merge into a single U gate + */ +TEST_F(QCOQuaternionMergeTest, quaternionMergeRXUGates) { + ASSERT_TRUE(testGateMerge({{RXOp::getOperationName(), {1.}}, + {UOp::getOperationName(), {1., 2., .3}}}) + .succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: RY->U should merge into a single U gate + */ +TEST_F(QCOQuaternionMergeTest, quaternionMergeRYUGates) { + ASSERT_TRUE(testGateMerge({{RYOp::getOperationName(), {1.}}, + {UOp::getOperationName(), {1., 2., .3}}}) + .succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: RZ->U should merge into a single U gate + */ +TEST_F(QCOQuaternionMergeTest, quaternionMergeRZUGates) { + ASSERT_TRUE(testGateMerge({{RZOp::getOperationName(), {1.}}, + {UOp::getOperationName(), {1., 2., .3}}}) + .succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +// ################################################## +// # Not Merging Tests +// ################################################## + +/** + * @brief Test: RX->RX should not merge + */ +TEST_F(QCOQuaternionMergeTest, quaternionNoMergeRXRXGates) { + ASSERT_TRUE(testGateMerge({{RXOp::getOperationName(), {1.}}, + {RXOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 2); +} + +/** + * @brief Test: RY->RY should not merge + */ +TEST_F(QCOQuaternionMergeTest, quaternionNoMergeRYRYGates) { + ASSERT_TRUE(testGateMerge({{RYOp::getOperationName(), {1.}}, + {RYOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 2); +} + +/** + * @brief Test: RZ->RZ should not merge + */ +TEST_F(QCOQuaternionMergeTest, quaternionNoMergeRZRZGates) { + ASSERT_TRUE(testGateMerge({{RZOp::getOperationName(), {1.}}, + {RZOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 2); +} + +/** + * @brief Test: single RX should not convert to U + */ +TEST_F(QCOQuaternionMergeTest, quaternionNoMergeSingleRXGate) { + ASSERT_TRUE(testGateMerge({{RXOp::getOperationName(), {1.}}}).succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: single RY should not convert to U + */ +TEST_F(QCOQuaternionMergeTest, quaternionNoMergeSingleRYGate) { + ASSERT_TRUE(testGateMerge({{RYOp::getOperationName(), {1.}}}).succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: single RZ should not convert to U + */ +TEST_F(QCOQuaternionMergeTest, quaternionNoMergeSingleRZGate) { + ASSERT_TRUE(testGateMerge({{RZOp::getOperationName(), {1.}}}).succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: Gates on different qubits should not merge + */ +TEST_F(QCOQuaternionMergeTest, dontMergeGatesFromDifferentQubits) { + auto q = builder.allocQubitRegister(2); + + Value qubit1 = q[0]; + Value qubit2 = q[1]; + builder.rx(1.0, qubit1); + builder.rx(1.0, qubit2); + module = builder.finalize(); + + ASSERT_TRUE(runMergePass(module.get()).succeeded()); + EXPECT_EQ(countOps(), 2); +} + +/** + * @brief Test: Non-consecutive gates should not merge + */ +TEST_F(QCOQuaternionMergeTest, dontMergeNonConsecutiveGates) { + auto q = builder.allocQubitRegister(1); + + auto q1 = builder.rx(1.0, q[0]); + auto q2 = builder.h(q1); + builder.ry(1.0, q2); + + module = builder.finalize(); + + ASSERT_TRUE(runMergePass(module.get()).succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 1); +} + +// ################################################## +// # Special Cases Tests +// ################################################## + +/** + * @brief Test: Consecutive gates with another gate in between should merge + */ +TEST_F(QCOQuaternionMergeTest, mergeConsecutiveWithGateInBetween) { + auto q = builder.allocQubitRegister(2); + + auto q1 = builder.rx(1.0, q[0]); + builder.h(q[1]); + builder.ry(1.0, q1); + + module = builder.finalize(); + + ASSERT_TRUE(runMergePass(module.get()).succeeded()); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 1); +} + +/** + * @brief Test: Gates with multiple uses should not be merged but pass should + * still succeed + */ +TEST_F(QCOQuaternionMergeTest, nonLinearCodeHandling) { + // QCOProgramBuilder does not allow non-linear circuits, + // so strings are used + const char* mlirCode = R"( + module { + func.func @nonLinearCodeHandling() { + %0 = qco.alloc : !qco.qubit + %cst = arith.constant 1.000000e+00 : f64 + %1 = qco.ry(%cst) %0 : !qco.qubit -> !qco.qubit + + // %1 is used by BOTH operations - violates linearity! + %2 = qco.rz(%cst) %1 : !qco.qubit -> !qco.qubit + %3 = qco.rz(%cst) %1 : !qco.qubit -> !qco.qubit + + qco.dealloc %2 : !qco.qubit + qco.dealloc %3 : !qco.qubit + return + } + } + )"; + + module = mlir::parseSourceString(mlirCode, &context); + ASSERT_TRUE(module); + + ASSERT_TRUE(runMergePass(module.get()).succeeded()); + + // Gates should remain unchanged (not merged) due to multiple uses + EXPECT_EQ(countOps(), 2); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); +} + +/** + * @brief Test: Gates with no final users should still succeed + * but will be removed by dead code removal from + * applyPatternsGreedily + */ +TEST_F(QCOQuaternionMergeTest, noUsedGate) { + const char* mlirCode = R"( + module { + func.func @noUsedGate() { + %0 = qco.alloc : !qco.qubit + %cst = arith.constant 1.000000e+00 : f64 + %1 = qco.ry(%cst) %0 : !qco.qubit -> !qco.qubit + %2 = qco.rz(%cst) %1 : !qco.qubit -> !qco.qubit + return + } + } + )"; + + module = mlir::parseSourceString(mlirCode, &context); + ASSERT_TRUE(module); + + ASSERT_TRUE(runMergePass(module.get()).succeeded()); + + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); +} + +// ################################################## +// # Numerical Correctness +// ################################################## + +/** + * @brief Test: RX(1)->RY(1) should merge into + * U(0.495367289218673, 1.27455578230629, -1.07542903757622) + */ +TEST_F(QCOQuaternionMergeTest, numericalAccuracyRXRY) { + ASSERT_TRUE(testGateMerge({{RXOp::getOperationName(), {1.}}, + {RYOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + + expectUGateParams(0.495367289218673, 1.27455578230629, -1.07542903757622); +} + +/** + * @brief Test: RX(1)->RZ(1) should merge into + * U(1.57079632679490, 1.00000000000000, -0.570796326794897) + */ +TEST_F(QCOQuaternionMergeTest, numericalAccuracyRXRZ) { + ASSERT_TRUE(testGateMerge({{RXOp::getOperationName(), {1.}}, + {RZOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + + expectUGateParams(1.57079632679490, 1.00000000000000, -0.570796326794897); +} + +/** + * @brief Test: RY(1)->RX(1) should merge into + * U(1.07542903757622, 1.27455578230629, -0.495367289218673) + */ +TEST_F(QCOQuaternionMergeTest, numericalAccuracyRYRX) { + ASSERT_TRUE(testGateMerge({{RYOp::getOperationName(), {1.}}, + {RXOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + + expectUGateParams(1.07542903757622, 1.27455578230629, -0.495367289218673); +} + +/** + * @brief Test: RY(1)->RZ(1) should merge into + * U(0, 1.00000000000000, 1.00000000000000) + */ +TEST_F(QCOQuaternionMergeTest, numericalAccuracyRYRZ) { + ASSERT_TRUE(testGateMerge({{RYOp::getOperationName(), {1.}}, + {RZOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + + expectUGateParams(0, 1.00000000000000, 1.00000000000000); +} + +/** + * @brief Test: RZ(1)->RX(1) should merge into + * U(2.57079632679490, 1.00000000000000, -1.57079632679490) + */ +TEST_F(QCOQuaternionMergeTest, numericalAccuracyRZRX) { + ASSERT_TRUE(testGateMerge({{RZOp::getOperationName(), {1.}}, + {RXOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + + expectUGateParams(2.57079632679490, 1.00000000000000, -1.57079632679490); +} + +/** + * @brief Test: RZ(1)->RY(1) should merge into + * U(1.00000000000000, 1.00000000000000, 0) + */ +TEST_F(QCOQuaternionMergeTest, numericalAccuracyRZRY) { + ASSERT_TRUE(testGateMerge({{RZOp::getOperationName(), {1.}}, + {RYOp::getOperationName(), {1.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + EXPECT_EQ(countOps(), 0); + EXPECT_EQ(countOps(), 0); + + expectUGateParams(1.00000000000000, 1.00000000000000, 0); +} + +/** + * @brief Test: U(1,2,3)->U(4,5,6) should merge into + * U(0.154763313125030, 1.00116934013043, -5.77770904175559) + */ +TEST_F(QCOQuaternionMergeTest, numericalAccuracyUU) { + ASSERT_TRUE(testGateMerge({{UOp::getOperationName(), {1., 2., 3.}}, + {UOp::getOperationName(), {4., 5., 6.}}}) + .succeeded()); + EXPECT_EQ(countOps(), 1); + + expectUGateParams(0.154763313125030, 1.00116934013043, -5.77770904175559); +}