diff --git a/CHANGELOG.md b/CHANGELOG.md index fa7dc54f9..71ff911cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel ### Added -- ✨ Add initial infrastructure for new QC and QCO MLIR dialects ([#1264], [#1402], [#1428], [#1430], [#1436]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**]) +- ✨ Add initial infrastructure for new QC and QCO MLIR dialects ([#1264], [#1402], [#1428], [#1430], [#1436], [#1443]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**]) ### Changed @@ -307,6 +307,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool +[#1443]: https://github.com/munich-quantum-toolkit/core/pull/1443 [#1437]: https://github.com/munich-quantum-toolkit/core/pull/1437 [#1436]: https://github.com/munich-quantum-toolkit/core/pull/1436 [#1430]: https://github.com/munich-quantum-toolkit/core/pull/1430 diff --git a/mlir/include/mlir/Dialect/QC/IR/QCDialect.h b/mlir/include/mlir/Dialect/QC/IR/QCDialect.h index f4fc35a5a..98e9907de 100644 --- a/mlir/include/mlir/Dialect/QC/IR/QCDialect.h +++ b/mlir/include/mlir/Dialect/QC/IR/QCDialect.h @@ -101,14 +101,6 @@ template class TargetAndParameterArityTrait { } return this->getOperation()->getOperand(T + i); } - - [[nodiscard]] static FloatAttr getStaticParameter(Value param) { - auto constantOp = param.getDefiningOp(); - if (!constantOp) { - return nullptr; - } - return dyn_cast(constantOp.getValue()); - } }; }; diff --git a/mlir/include/mlir/Dialect/QCO/IR/QCODialect.h b/mlir/include/mlir/Dialect/QCO/IR/QCODialect.h index c429707cc..11fdbb5a6 100644 --- a/mlir/include/mlir/Dialect/QCO/IR/QCODialect.h +++ b/mlir/include/mlir/Dialect/QCO/IR/QCODialect.h @@ -108,14 +108,6 @@ template class TargetAndParameterArityTrait { return this->getOperation()->getOperand(T + i); } - [[nodiscard]] static FloatAttr getStaticParameter(Value param) { - auto constantOp = param.getDefiningOp(); - if (!constantOp) { - return nullptr; - } - return dyn_cast(constantOp.getValue()); - } - Value getInputForOutput(Value output) { const auto& op = this->getOperation(); for (size_t i = 0; i < T; ++i) { diff --git a/mlir/include/mlir/Dialect/QCO/QCOUtils.h b/mlir/include/mlir/Dialect/QCO/QCOUtils.h index 3c2f4564a..009852483 100644 --- a/mlir/include/mlir/Dialect/QCO/QCOUtils.h +++ b/mlir/include/mlir/Dialect/QCO/QCOUtils.h @@ -10,8 +10,9 @@ #pragma once +#include "mlir/Dialect/Utils/Utils.h" + #include -#include #include namespace mlir::qco { @@ -188,13 +189,8 @@ mergeTwoTargetOneParameter(OpType op, mlir::PatternRewriter& rewriter) { template inline mlir::LogicalResult removeTrivialOneTargetOneParameter(OpType op, mlir::PatternRewriter& rewriter) { - const auto paramAttr = OpType::getStaticParameter(op.getOperand(1)); - if (!paramAttr) { - return failure(); - } - - const auto paramValue = paramAttr.getValueAsDouble(); - if (std::abs(paramValue) > utils::TOLERANCE) { + const auto param = utils::valueToDouble(op.getOperand(1)); + if (!param || std::abs(*param) > utils::TOLERANCE) { return failure(); } @@ -215,13 +211,8 @@ removeTrivialOneTargetOneParameter(OpType op, mlir::PatternRewriter& rewriter) { template inline mlir::LogicalResult removeTrivialTwoTargetOneParameter(OpType op, mlir::PatternRewriter& rewriter) { - const auto paramAttr = OpType::getStaticParameter(op.getOperand(2)); - if (!paramAttr) { - return failure(); - } - - const auto paramValue = paramAttr.getValueAsDouble(); - if (std::abs(paramValue) > utils::TOLERANCE) { + const auto param = utils::valueToDouble(op.getOperand(2)); + if (!param || std::abs(*param) > utils::TOLERANCE) { return failure(); } diff --git a/mlir/include/mlir/Dialect/Utils/Utils.h b/mlir/include/mlir/Dialect/Utils/Utils.h index 08f6412f7..1674c5779 100644 --- a/mlir/include/mlir/Dialect/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/Utils/Utils.h @@ -28,8 +28,9 @@ constexpr auto TOLERANCE = 1e-15; * @param parameter The parameter as a variant (double or Value). * @return Value The parameter as a Value. */ -inline Value variantToValue(OpBuilder& builder, const OperationState& state, - const std::variant& parameter) { +[[nodiscard]] inline Value +variantToValue(OpBuilder& builder, const OperationState& state, + const std::variant& parameter) { Value operand; if (std::holds_alternative(parameter)) { operand = builder.create( @@ -40,4 +41,32 @@ inline Value variantToValue(OpBuilder& builder, const OperationState& state, return operand; } +/** + * @brief Try to convert a mlir::Value to a standard C++ double + * + * @details + * Resolving the mlir::Value will only work if it is a static value, so a value + * defined via a "arith.constant" operation. It must also be of type + * float or integer. + */ +[[nodiscard]] inline std::optional valueToDouble(Value value) { + auto constantOp = value.getDefiningOp(); + if (!constantOp) { + return std::nullopt; + } + auto floatAttr = dyn_cast(constantOp.getValue()); + if (floatAttr) { + return floatAttr.getValueAsDouble(); + } + auto intAttr = dyn_cast(constantOp.getValue()); + if (intAttr) { + if (intAttr.getType().isUnsignedInteger()) { + return static_cast(intAttr.getValue().getZExtValue()); + } + // interpret both signed+signless as signed integers + return static_cast(intAttr.getValue().getSExtValue()); + } + return std::nullopt; +} + } // namespace mlir::utils diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/GPhaseOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/GPhaseOp.cpp index 988535161..a9f978199 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/GPhaseOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/GPhaseOp.cpp @@ -13,7 +13,6 @@ #include #include -#include #include #include #include @@ -34,13 +33,8 @@ struct RemoveTrivialGPhase final : OpRewritePattern { LogicalResult matchAndRewrite(GPhaseOp op, PatternRewriter& rewriter) const override { - const auto thetaAttr = GPhaseOp::getStaticParameter(op.getTheta()); - if (!thetaAttr) { - return failure(); - } - - const auto thetaValue = thetaAttr.getValueAsDouble(); - if (std::abs(thetaValue) > TOLERANCE) { + const auto theta = valueToDouble(op.getTheta()); + if (!theta || std::abs(*theta) > TOLERANCE) { return failure(); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/ROp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/ROp.cpp index 28c845f26..2f798d1e0 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/ROp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/ROp.cpp @@ -13,7 +13,6 @@ #include #include -#include #include #include #include @@ -35,13 +34,8 @@ struct ReplaceRWithRX final : OpRewritePattern { LogicalResult matchAndRewrite(ROp op, PatternRewriter& rewriter) const override { - const auto phi = ROp::getStaticParameter(op.getPhi()); - if (!phi) { - return failure(); - } - - const auto phiValue = phi.getValueAsDouble(); - if (std::abs(phiValue) > TOLERANCE) { + const auto phi = valueToDouble(op.getPhi()); + if (!phi || std::abs(*phi) > TOLERANCE) { return failure(); } @@ -61,13 +55,8 @@ struct ReplaceRWithRY final : OpRewritePattern { LogicalResult matchAndRewrite(ROp op, PatternRewriter& rewriter) const override { - const auto phi = ROp::getStaticParameter(op.getPhi()); - if (!phi) { - return failure(); - } - - const auto phiValue = phi.getValueAsDouble(); - if (std::abs(phiValue - (std::numbers::pi / 2.0)) > TOLERANCE) { + const auto phi = valueToDouble(op.getPhi()); + if (!phi || std::abs(*phi - (std::numbers::pi / 2.0)) > TOLERANCE) { return failure(); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/U2Op.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/U2Op.cpp index 5be74a39e..ced629792 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/U2Op.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/U2Op.cpp @@ -34,16 +34,10 @@ struct ReplaceU2WithH final : OpRewritePattern { LogicalResult matchAndRewrite(U2Op op, PatternRewriter& rewriter) const override { - const auto phi = U2Op::getStaticParameter(op.getPhi()); - const auto lambda = U2Op::getStaticParameter(op.getLambda()); - if (!phi || !lambda) { - return failure(); - } - - const auto phiValue = phi.getValueAsDouble(); - const auto lambdaValue = lambda.getValueAsDouble(); - if (std::abs(phiValue) > TOLERANCE || - std::abs(lambdaValue - std::numbers::pi) > TOLERANCE) { + const auto phi = valueToDouble(op.getPhi()); + const auto lambda = valueToDouble(op.getLambda()); + if (!phi || std::abs(*phi) > TOLERANCE || !lambda || + std::abs(*lambda - std::numbers::pi) > TOLERANCE) { return failure(); } @@ -62,16 +56,10 @@ struct ReplaceU2WithRX final : OpRewritePattern { LogicalResult matchAndRewrite(U2Op op, PatternRewriter& rewriter) const override { - const auto phi = U2Op::getStaticParameter(op.getPhi()); - const auto lambda = U2Op::getStaticParameter(op.getLambda()); - if (!phi || !lambda) { - return failure(); - } - - const auto phiValue = phi.getValueAsDouble(); - const auto lambdaValue = lambda.getValueAsDouble(); - if (std::abs(phiValue + (std::numbers::pi / 2.0)) > TOLERANCE || - std::abs(lambdaValue - (std::numbers::pi / 2.0)) > TOLERANCE) { + const auto phi = valueToDouble(op.getPhi()); + const auto lambda = valueToDouble(op.getLambda()); + if (!phi || std::abs(*phi + (std::numbers::pi / 2.0)) > TOLERANCE || + !lambda || std::abs(*lambda - (std::numbers::pi / 2.0)) > TOLERANCE) { return failure(); } @@ -91,15 +79,10 @@ struct ReplaceU2WithRY final : OpRewritePattern { LogicalResult matchAndRewrite(U2Op op, PatternRewriter& rewriter) const override { - const auto phi = U2Op::getStaticParameter(op.getPhi()); - const auto lambda = U2Op::getStaticParameter(op.getLambda()); - if (!phi || !lambda) { - return failure(); - } - - const auto phiValue = phi.getValueAsDouble(); - const auto lambdaValue = lambda.getValueAsDouble(); - if (std::abs(phiValue) > TOLERANCE || std::abs(lambdaValue) > TOLERANCE) { + const auto phi = valueToDouble(op.getPhi()); + const auto lambda = valueToDouble(op.getLambda()); + if (!phi || std::abs(*phi) > TOLERANCE || !lambda || + std::abs(*lambda) > TOLERANCE) { return failure(); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/UOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/UOp.cpp index 8a5e02104..fcf69025a 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/UOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/UOp.cpp @@ -13,7 +13,6 @@ #include #include -#include #include #include #include @@ -35,15 +34,10 @@ struct ReplaceUWithP final : OpRewritePattern { LogicalResult matchAndRewrite(UOp op, PatternRewriter& rewriter) const override { - const auto theta = UOp::getStaticParameter(op.getTheta()); - const auto phi = UOp::getStaticParameter(op.getPhi()); - if (!theta || !phi) { - return failure(); - } - - const auto thetaValue = theta.getValueAsDouble(); - const auto phiValue = phi.getValueAsDouble(); - if (std::abs(thetaValue) > TOLERANCE || std::abs(phiValue) > TOLERANCE) { + const auto theta = valueToDouble(op.getTheta()); + const auto phi = valueToDouble(op.getPhi()); + if (!theta || std::abs(*theta) > TOLERANCE || !phi || + std::abs(*phi) > TOLERANCE) { return failure(); } @@ -63,16 +57,10 @@ struct ReplaceUWithRX final : OpRewritePattern { LogicalResult matchAndRewrite(UOp op, PatternRewriter& rewriter) const override { - const auto phi = UOp::getStaticParameter(op.getPhi()); - const auto lambda = UOp::getStaticParameter(op.getLambda()); - if (!phi || !lambda) { - return failure(); - } - - const auto phiValue = phi.getValueAsDouble(); - const auto lambdaValue = lambda.getValueAsDouble(); - if (std::abs(phiValue + (std::numbers::pi / 2.0)) > TOLERANCE || - std::abs(lambdaValue - (std::numbers::pi / 2.0)) > TOLERANCE) { + const auto phi = valueToDouble(op.getPhi()); + const auto lambda = valueToDouble(op.getLambda()); + if (!phi || std::abs(*phi + (std::numbers::pi / 2.0)) > TOLERANCE || + !lambda || std::abs(*lambda - (std::numbers::pi / 2.0)) > TOLERANCE) { return failure(); } @@ -92,15 +80,10 @@ struct ReplaceUWithRY final : OpRewritePattern { LogicalResult matchAndRewrite(UOp op, PatternRewriter& rewriter) const override { - const auto phi = UOp::getStaticParameter(op.getPhi()); - const auto lambda = UOp::getStaticParameter(op.getLambda()); - if (!phi || !lambda) { - return failure(); - } - - const auto phiValue = phi.getValueAsDouble(); - const auto lambdaValue = lambda.getValueAsDouble(); - if (std::abs(phiValue) > TOLERANCE || std::abs(lambdaValue) > TOLERANCE) { + const auto phi = valueToDouble(op.getPhi()); + const auto lambda = valueToDouble(op.getLambda()); + if (!phi || std::abs(*phi) > TOLERANCE || !lambda || + std::abs(*lambda) > TOLERANCE) { return failure(); } diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp index 366b9b40c..4a45be0bc 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXMinusYYOp.cpp @@ -47,11 +47,10 @@ struct MergeSubsequentXXMinusYY final : OpRewritePattern { } // Confirm betas are equal - auto beta = XXMinusYYOp::getStaticParameter(op.getBeta()); - auto prevBeta = XXMinusYYOp::getStaticParameter(prevOp.getBeta()); + auto beta = valueToDouble(op.getBeta()); + auto prevBeta = valueToDouble(prevOp.getBeta()); if (beta && prevBeta) { - if (std::abs(beta.getValueAsDouble() - prevBeta.getValueAsDouble()) > - TOLERANCE) { + if (std::abs(*beta - *prevBeta) > TOLERANCE) { return failure(); } } else if (op.getBeta() != prevOp.getBeta()) { diff --git a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp index 1d9cf1274..3574f41e9 100644 --- a/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp +++ b/mlir/lib/Dialect/QCO/IR/Operations/StandardGates/XXPlusYYOp.cpp @@ -46,11 +46,10 @@ struct MergeSubsequentXXPlusYY final : OpRewritePattern { } // Confirm betas are equal - auto beta = XXPlusYYOp::getStaticParameter(op.getBeta()); - auto prevBeta = XXPlusYYOp::getStaticParameter(prevOp.getBeta()); + auto beta = valueToDouble(op.getBeta()); + auto prevBeta = valueToDouble(prevOp.getBeta()); if (beta && prevBeta) { - if (std::abs(beta.getValueAsDouble() - prevBeta.getValueAsDouble()) > - TOLERANCE) { + if (std::abs(*beta - *prevBeta) > TOLERANCE) { return failure(); } } else if (op.getBeta() != prevOp.getBeta()) { diff --git a/mlir/unittests/CMakeLists.txt b/mlir/unittests/CMakeLists.txt index bf3838d9e..c7a5e5012 100644 --- a/mlir/unittests/CMakeLists.txt +++ b/mlir/unittests/CMakeLists.txt @@ -12,4 +12,4 @@ add_subdirectory(Dialect) add_custom_target(mqt-core-mlir-unittests) add_dependencies(mqt-core-mlir-unittests mqt-core-mlir-compiler-pipeline-test - mqt-core-mlir-dialect-qco-ir-modifiers-test) + mqt-core-mlir-dialect-qco-ir-modifiers-test mqt-core-mlir-dialect-utils-test) diff --git a/mlir/unittests/Dialect/CMakeLists.txt b/mlir/unittests/Dialect/CMakeLists.txt index 6ced278d0..cc5b8feb9 100644 --- a/mlir/unittests/Dialect/CMakeLists.txt +++ b/mlir/unittests/Dialect/CMakeLists.txt @@ -7,3 +7,4 @@ # Licensed under the MIT License add_subdirectory(QCO) +add_subdirectory(Utils) diff --git a/mlir/unittests/Dialect/Utils/CMakeLists.txt b/mlir/unittests/Dialect/Utils/CMakeLists.txt new file mode 100644 index 000000000..bdc1b435c --- /dev/null +++ b/mlir/unittests/Dialect/Utils/CMakeLists.txt @@ -0,0 +1,14 @@ +# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM +# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH +# All rights reserved. +# +# SPDX-License-Identifier: MIT +# +# Licensed under the MIT License + +add_executable(mqt-core-mlir-dialect-utils-test test_utils.cpp) + +target_link_libraries(mqt-core-mlir-dialect-utils-test PRIVATE GTest::gtest_main MLIRArithDialect + MLIRIR MLIRSupport LLVMSupport) + +gtest_discover_tests(mqt-core-mlir-dialect-utils-test) diff --git a/mlir/unittests/Dialect/Utils/test_utils.cpp b/mlir/unittests/Dialect/Utils/test_utils.cpp new file mode 100644 index 000000000..6a6c75563 --- /dev/null +++ b/mlir/unittests/Dialect/Utils/test_utils.cpp @@ -0,0 +1,132 @@ +/* + * Copyright (c) 2023 - 2026 Chair for Design Automation, TUM + * Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH + * All rights reserved. + * + * SPDX-License-Identifier: MIT + * + * Licensed under the MIT License + */ + +#include "mlir/Dialect/Utils/Utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace mlir; + +class UtilsTest : public ::testing::Test { +protected: + MLIRContext context; + std::unique_ptr builder; + std::unique_ptr loc; + + void SetUp() override { + context.loadDialect(); + + builder = std::make_unique(&context); + loc = std::make_unique(builder->getUnknownLoc()); + } + + arith::AddFOp createAddition(double a, double b) { + auto firstOperand = + builder->create(*loc, builder->getF64FloatAttr(a)); + auto secondOperand = + builder->create(*loc, builder->getF64FloatAttr(b)); + return builder->create(*loc, firstOperand, secondOperand); + } +}; + +TEST_F(UtilsTest, valueToDouble) { + constexpr double expectedValue = 1.234; + auto op = builder->create( + *loc, builder->getF64FloatAttr(expectedValue)); + ASSERT_TRUE(op); + + auto value = op.getResult(); + auto stdValue = utils::valueToDouble(value); + ASSERT_TRUE(stdValue.has_value()); + EXPECT_DOUBLE_EQ(stdValue.value(), expectedValue); +} + +TEST_F(UtilsTest, valueToDoubleCastFromInteger) { + constexpr int expectedValue = 42; + auto op = builder->create( + *loc, builder->getI32IntegerAttr(expectedValue)); + ASSERT_TRUE(op); + + auto value = op.getResult(); + auto stdValue = utils::valueToDouble(value); + ASSERT_TRUE(stdValue.has_value()); + EXPECT_DOUBLE_EQ(stdValue.value(), expectedValue); +} + +TEST_F(UtilsTest, valueToDoubleCastFromNegativeInteger) { + constexpr int expectedValue = -123; + auto op = builder->create( + *loc, builder->getSI32IntegerAttr(expectedValue)); + ASSERT_TRUE(op); + + auto value = op.getResult(); + auto stdValue = utils::valueToDouble(value); + ASSERT_TRUE(stdValue.has_value()); + EXPECT_DOUBLE_EQ(stdValue.value(), expectedValue); +} + +TEST_F(UtilsTest, valueToDoubleCastFromMaxUnsignedInteger) { + constexpr auto expectedValue = std::numeric_limits::max(); + constexpr auto bitCount = 64; + auto op = builder->create( + *loc, builder->getIntegerAttr(builder->getIntegerType(bitCount, false), + llvm::APInt::getMaxValue(bitCount))); + ASSERT_TRUE(op); + + auto value = op.getResult(); + auto stdValue = utils::valueToDouble(value); + ASSERT_TRUE(stdValue.has_value()); + // cast to double will lose precision, but difference to maximum value of + // int64_t is large enough that the check still makes sense + EXPECT_DOUBLE_EQ(stdValue.value(), static_cast(expectedValue)); +} + +TEST_F(UtilsTest, valueToDoubleWrongType) { + auto op = + builder->create(*loc, builder->getStringAttr("test")); + ASSERT_TRUE(op); + + auto value = op.getResult(); + auto stdValue = utils::valueToDouble(value); + EXPECT_FALSE(stdValue.has_value()); +} + +TEST_F(UtilsTest, valueToDoubleNonStaticValue) { + auto op = createAddition(9.5, 21.5); + ASSERT_TRUE(op); + + auto value = op.getResult(); + auto stdValue = utils::valueToDouble(value); + EXPECT_FALSE(stdValue.has_value()); +} + +TEST_F(UtilsTest, valueToDoubleFoldedConstant) { + auto op = createAddition(1.5, 2.0); + ASSERT_TRUE(op); + + llvm::SmallVector tmp; + llvm::SmallVector newConstants; + ASSERT_TRUE(builder->tryFold(op, tmp, &newConstants).succeeded()); + ASSERT_EQ(newConstants.size(), 1); + auto value = newConstants[0]->getResult(0); + auto stdValue = utils::valueToDouble(value); + ASSERT_TRUE(stdValue.has_value()); + EXPECT_DOUBLE_EQ(stdValue.value(), 3.5); +}