Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ This project adheres to [Semantic Versioning], with the exception that minor rel

### Added

- ✨ Add initial infrastructure for new QC and QCO MLIR dialects ([#1264], [#1402], [#1428], [#1430], [#1436]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**])
- ✨ Add initial infrastructure for new QC and QCO MLIR dialects ([#1264], [#1402], [#1428], [#1430], [#1436], [#1443]) ([**@burgholzer**], [**@denialhaag**], [**@taminob**], [**@DRovara**])

### Changed

Expand Down Expand Up @@ -307,6 +307,7 @@ _📚 Refer to the [GitHub Release Notes](https://github.com/munich-quantum-tool

<!-- PR links -->

[#1443]: https://github.com/munich-quantum-toolkit/core/pull/1443
[#1437]: https://github.com/munich-quantum-toolkit/core/pull/1437
[#1436]: https://github.com/munich-quantum-toolkit/core/pull/1436
[#1430]: https://github.com/munich-quantum-toolkit/core/pull/1430
Expand Down
8 changes: 0 additions & 8 deletions mlir/include/mlir/Dialect/QC/IR/QCDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,14 +101,6 @@ template <size_t T, size_t P> class TargetAndParameterArityTrait {
}
return this->getOperation()->getOperand(T + i);
}

[[nodiscard]] static FloatAttr getStaticParameter(Value param) {
auto constantOp = param.getDefiningOp<arith::ConstantOp>();
if (!constantOp) {
return nullptr;
}
return dyn_cast<FloatAttr>(constantOp.getValue());
}
};
};

Expand Down
8 changes: 0 additions & 8 deletions mlir/include/mlir/Dialect/QCO/IR/QCODialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,6 @@ template <size_t T, size_t P> class TargetAndParameterArityTrait {
return this->getOperation()->getOperand(T + i);
}

[[nodiscard]] static FloatAttr getStaticParameter(Value param) {
auto constantOp = param.getDefiningOp<arith::ConstantOp>();
if (!constantOp) {
return nullptr;
}
return dyn_cast<FloatAttr>(constantOp.getValue());
}

Value getInputForOutput(Value output) {
const auto& op = this->getOperation();
for (size_t i = 0; i < T; ++i) {
Expand Down
21 changes: 6 additions & 15 deletions mlir/include/mlir/Dialect/QCO/QCOUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@

#pragma once

#include "mlir/Dialect/Utils/Utils.h"

#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/Utils/Utils.h>
#include <mlir/IR/PatternMatch.h>

namespace mlir::qco {
Expand Down Expand Up @@ -188,13 +189,8 @@ mergeTwoTargetOneParameter(OpType op, mlir::PatternRewriter& rewriter) {
template <typename OpType>
inline mlir::LogicalResult
removeTrivialOneTargetOneParameter(OpType op, mlir::PatternRewriter& rewriter) {
const auto paramAttr = OpType::getStaticParameter(op.getOperand(1));
if (!paramAttr) {
return failure();
}

const auto paramValue = paramAttr.getValueAsDouble();
if (std::abs(paramValue) > utils::TOLERANCE) {
const auto param = utils::valueToDouble(op.getOperand(1));
if (!param || std::abs(*param) > utils::TOLERANCE) {
return failure();
}

Expand All @@ -215,13 +211,8 @@ removeTrivialOneTargetOneParameter(OpType op, mlir::PatternRewriter& rewriter) {
template <typename OpType>
inline mlir::LogicalResult
removeTrivialTwoTargetOneParameter(OpType op, mlir::PatternRewriter& rewriter) {
const auto paramAttr = OpType::getStaticParameter(op.getOperand(2));
if (!paramAttr) {
return failure();
}

const auto paramValue = paramAttr.getValueAsDouble();
if (std::abs(paramValue) > utils::TOLERANCE) {
const auto param = utils::valueToDouble(op.getOperand(2));
if (!param || std::abs(*param) > utils::TOLERANCE) {
return failure();
}

Expand Down
33 changes: 31 additions & 2 deletions mlir/include/mlir/Dialect/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,9 @@ constexpr auto TOLERANCE = 1e-15;
* @param parameter The parameter as a variant (double or Value).
* @return Value The parameter as a Value.
*/
inline Value variantToValue(OpBuilder& builder, const OperationState& state,
const std::variant<double, Value>& parameter) {
[[nodiscard]] inline Value
variantToValue(OpBuilder& builder, const OperationState& state,
const std::variant<double, Value>& parameter) {
Value operand;
if (std::holds_alternative<double>(parameter)) {
operand = builder.create<arith::ConstantOp>(
Expand All @@ -40,4 +41,32 @@ inline Value variantToValue(OpBuilder& builder, const OperationState& state,
return operand;
}

/**
* @brief Try to convert a mlir::Value to a standard C++ double
*
* @details
* Resolving the mlir::Value will only work if it is a static value, so a value
* defined via a "arith.constant" operation. It must also be of type
* float or integer.
*/
[[nodiscard]] inline std::optional<double> valueToDouble(Value value) {
auto constantOp = value.getDefiningOp<arith::ConstantOp>();
if (!constantOp) {
return std::nullopt;
}
auto floatAttr = dyn_cast<FloatAttr>(constantOp.getValue());
if (floatAttr) {
return floatAttr.getValueAsDouble();
}
auto intAttr = dyn_cast<IntegerAttr>(constantOp.getValue());
if (intAttr) {
if (intAttr.getType().isUnsignedInteger()) {
return static_cast<double>(intAttr.getValue().getZExtValue());
}
// interpret both signed+signless as signed integers
return static_cast<double>(intAttr.getValue().getSExtValue());
}
return std::nullopt;
}

} // namespace mlir::utils
10 changes: 2 additions & 8 deletions mlir/lib/Dialect/QCO/IR/Operations/StandardGates/GPhaseOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

#include <cmath>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/IR/PatternMatch.h>
Expand All @@ -34,13 +33,8 @@ struct RemoveTrivialGPhase final : OpRewritePattern<GPhaseOp> {

LogicalResult matchAndRewrite(GPhaseOp op,
PatternRewriter& rewriter) const override {
const auto thetaAttr = GPhaseOp::getStaticParameter(op.getTheta());
if (!thetaAttr) {
return failure();
}

const auto thetaValue = thetaAttr.getValueAsDouble();
if (std::abs(thetaValue) > TOLERANCE) {
const auto theta = valueToDouble(op.getTheta());
if (!theta || std::abs(*theta) > TOLERANCE) {
return failure();
}

Expand Down
19 changes: 4 additions & 15 deletions mlir/lib/Dialect/QCO/IR/Operations/StandardGates/ROp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

#include <cmath>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/IR/PatternMatch.h>
Expand All @@ -35,13 +34,8 @@ struct ReplaceRWithRX final : OpRewritePattern<ROp> {

LogicalResult matchAndRewrite(ROp op,
PatternRewriter& rewriter) const override {
const auto phi = ROp::getStaticParameter(op.getPhi());
if (!phi) {
return failure();
}

const auto phiValue = phi.getValueAsDouble();
if (std::abs(phiValue) > TOLERANCE) {
const auto phi = valueToDouble(op.getPhi());
if (!phi || std::abs(*phi) > TOLERANCE) {
return failure();
}

Expand All @@ -61,13 +55,8 @@ struct ReplaceRWithRY final : OpRewritePattern<ROp> {

LogicalResult matchAndRewrite(ROp op,
PatternRewriter& rewriter) const override {
const auto phi = ROp::getStaticParameter(op.getPhi());
if (!phi) {
return failure();
}

const auto phiValue = phi.getValueAsDouble();
if (std::abs(phiValue - (std::numbers::pi / 2.0)) > TOLERANCE) {
const auto phi = valueToDouble(op.getPhi());
if (!phi || std::abs(*phi - (std::numbers::pi / 2.0)) > TOLERANCE) {
return failure();
}

Expand Down
41 changes: 12 additions & 29 deletions mlir/lib/Dialect/QCO/IR/Operations/StandardGates/U2Op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,16 +34,10 @@ struct ReplaceU2WithH final : OpRewritePattern<U2Op> {

LogicalResult matchAndRewrite(U2Op op,
PatternRewriter& rewriter) const override {
const auto phi = U2Op::getStaticParameter(op.getPhi());
const auto lambda = U2Op::getStaticParameter(op.getLambda());
if (!phi || !lambda) {
return failure();
}

const auto phiValue = phi.getValueAsDouble();
const auto lambdaValue = lambda.getValueAsDouble();
if (std::abs(phiValue) > TOLERANCE ||
std::abs(lambdaValue - std::numbers::pi) > TOLERANCE) {
const auto phi = valueToDouble(op.getPhi());
const auto lambda = valueToDouble(op.getLambda());
if (!phi || std::abs(*phi) > TOLERANCE || !lambda ||
std::abs(*lambda - std::numbers::pi) > TOLERANCE) {
return failure();
}

Expand All @@ -62,16 +56,10 @@ struct ReplaceU2WithRX final : OpRewritePattern<U2Op> {

LogicalResult matchAndRewrite(U2Op op,
PatternRewriter& rewriter) const override {
const auto phi = U2Op::getStaticParameter(op.getPhi());
const auto lambda = U2Op::getStaticParameter(op.getLambda());
if (!phi || !lambda) {
return failure();
}

const auto phiValue = phi.getValueAsDouble();
const auto lambdaValue = lambda.getValueAsDouble();
if (std::abs(phiValue + (std::numbers::pi / 2.0)) > TOLERANCE ||
std::abs(lambdaValue - (std::numbers::pi / 2.0)) > TOLERANCE) {
const auto phi = valueToDouble(op.getPhi());
const auto lambda = valueToDouble(op.getLambda());
if (!phi || std::abs(*phi + (std::numbers::pi / 2.0)) > TOLERANCE ||
!lambda || std::abs(*lambda - (std::numbers::pi / 2.0)) > TOLERANCE) {
return failure();
}

Expand All @@ -91,15 +79,10 @@ struct ReplaceU2WithRY final : OpRewritePattern<U2Op> {

LogicalResult matchAndRewrite(U2Op op,
PatternRewriter& rewriter) const override {
const auto phi = U2Op::getStaticParameter(op.getPhi());
const auto lambda = U2Op::getStaticParameter(op.getLambda());
if (!phi || !lambda) {
return failure();
}

const auto phiValue = phi.getValueAsDouble();
const auto lambdaValue = lambda.getValueAsDouble();
if (std::abs(phiValue) > TOLERANCE || std::abs(lambdaValue) > TOLERANCE) {
const auto phi = valueToDouble(op.getPhi());
const auto lambda = valueToDouble(op.getLambda());
if (!phi || std::abs(*phi) > TOLERANCE || !lambda ||
std::abs(*lambda) > TOLERANCE) {
return failure();
}

Expand Down
41 changes: 12 additions & 29 deletions mlir/lib/Dialect/QCO/IR/Operations/StandardGates/UOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

#include <cmath>
#include <mlir/IR/Builders.h>
#include <mlir/IR/BuiltinAttributes.h>
#include <mlir/IR/MLIRContext.h>
#include <mlir/IR/OperationSupport.h>
#include <mlir/IR/PatternMatch.h>
Expand All @@ -35,15 +34,10 @@ struct ReplaceUWithP final : OpRewritePattern<UOp> {

LogicalResult matchAndRewrite(UOp op,
PatternRewriter& rewriter) const override {
const auto theta = UOp::getStaticParameter(op.getTheta());
const auto phi = UOp::getStaticParameter(op.getPhi());
if (!theta || !phi) {
return failure();
}

const auto thetaValue = theta.getValueAsDouble();
const auto phiValue = phi.getValueAsDouble();
if (std::abs(thetaValue) > TOLERANCE || std::abs(phiValue) > TOLERANCE) {
const auto theta = valueToDouble(op.getTheta());
const auto phi = valueToDouble(op.getPhi());
if (!theta || std::abs(*theta) > TOLERANCE || !phi ||
std::abs(*phi) > TOLERANCE) {
return failure();
}

Expand All @@ -63,16 +57,10 @@ struct ReplaceUWithRX final : OpRewritePattern<UOp> {

LogicalResult matchAndRewrite(UOp op,
PatternRewriter& rewriter) const override {
const auto phi = UOp::getStaticParameter(op.getPhi());
const auto lambda = UOp::getStaticParameter(op.getLambda());
if (!phi || !lambda) {
return failure();
}

const auto phiValue = phi.getValueAsDouble();
const auto lambdaValue = lambda.getValueAsDouble();
if (std::abs(phiValue + (std::numbers::pi / 2.0)) > TOLERANCE ||
std::abs(lambdaValue - (std::numbers::pi / 2.0)) > TOLERANCE) {
const auto phi = valueToDouble(op.getPhi());
const auto lambda = valueToDouble(op.getLambda());
if (!phi || std::abs(*phi + (std::numbers::pi / 2.0)) > TOLERANCE ||
!lambda || std::abs(*lambda - (std::numbers::pi / 2.0)) > TOLERANCE) {
return failure();
}

Expand All @@ -92,15 +80,10 @@ struct ReplaceUWithRY final : OpRewritePattern<UOp> {

LogicalResult matchAndRewrite(UOp op,
PatternRewriter& rewriter) const override {
const auto phi = UOp::getStaticParameter(op.getPhi());
const auto lambda = UOp::getStaticParameter(op.getLambda());
if (!phi || !lambda) {
return failure();
}

const auto phiValue = phi.getValueAsDouble();
const auto lambdaValue = lambda.getValueAsDouble();
if (std::abs(phiValue) > TOLERANCE || std::abs(lambdaValue) > TOLERANCE) {
const auto phi = valueToDouble(op.getPhi());
const auto lambda = valueToDouble(op.getLambda());
if (!phi || std::abs(*phi) > TOLERANCE || !lambda ||
std::abs(*lambda) > TOLERANCE) {
return failure();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@ struct MergeSubsequentXXMinusYY final : OpRewritePattern<XXMinusYYOp> {
}

// Confirm betas are equal
auto beta = XXMinusYYOp::getStaticParameter(op.getBeta());
auto prevBeta = XXMinusYYOp::getStaticParameter(prevOp.getBeta());
auto beta = valueToDouble(op.getBeta());
auto prevBeta = valueToDouble(prevOp.getBeta());
if (beta && prevBeta) {
if (std::abs(beta.getValueAsDouble() - prevBeta.getValueAsDouble()) >
TOLERANCE) {
if (std::abs(*beta - *prevBeta) > TOLERANCE) {
return failure();
}
} else if (op.getBeta() != prevOp.getBeta()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,10 @@ struct MergeSubsequentXXPlusYY final : OpRewritePattern<XXPlusYYOp> {
}

// Confirm betas are equal
auto beta = XXPlusYYOp::getStaticParameter(op.getBeta());
auto prevBeta = XXPlusYYOp::getStaticParameter(prevOp.getBeta());
auto beta = valueToDouble(op.getBeta());
auto prevBeta = valueToDouble(prevOp.getBeta());
if (beta && prevBeta) {
if (std::abs(beta.getValueAsDouble() - prevBeta.getValueAsDouble()) >
TOLERANCE) {
if (std::abs(*beta - *prevBeta) > TOLERANCE) {
return failure();
}
} else if (op.getBeta() != prevOp.getBeta()) {
Expand Down
2 changes: 1 addition & 1 deletion mlir/unittests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ add_subdirectory(Dialect)
add_custom_target(mqt-core-mlir-unittests)

add_dependencies(mqt-core-mlir-unittests mqt-core-mlir-compiler-pipeline-test
mqt-core-mlir-dialect-qco-ir-modifiers-test)
mqt-core-mlir-dialect-qco-ir-modifiers-test mqt-core-mlir-dialect-utils-test)
1 change: 1 addition & 0 deletions mlir/unittests/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
# Licensed under the MIT License

add_subdirectory(QCO)
add_subdirectory(Utils)
14 changes: 14 additions & 0 deletions mlir/unittests/Dialect/Utils/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright (c) 2023 - 2026 Chair for Design Automation, TUM
# Copyright (c) 2025 - 2026 Munich Quantum Software Company GmbH
# All rights reserved.
#
# SPDX-License-Identifier: MIT
#
# Licensed under the MIT License

add_executable(mqt-core-mlir-dialect-utils-test test_utils.cpp)

target_link_libraries(mqt-core-mlir-dialect-utils-test PRIVATE GTest::gtest_main MLIRArithDialect
MLIRIR MLIRSupport LLVMSupport)

gtest_discover_tests(mqt-core-mlir-dialect-utils-test)
Loading
Loading