Skip to content

Commit 0f32a73

Browse files
rturradodime10
andauthored
Issue 1220 add merge rotation patterns for qml rot and qml crot (#1955)
**Context:** This task has to do with Catalyst's Merge Rotation pass. When this pass was originally implemented, arbitary rotations were part of the list of rotations that could be merged. Only to find later on that, for two arbitrary rotations, the merge was not as easy as for a fixed rotation, e.g., `RX`, where the rotation angles are simply added. At that point, arbitary rotations were removed from the set of mergeable rotations. For an arbitary rotation, more complex mathematical computations have to be performed in order to obtain the angles of the merged gate. These mathematical formulas are well described [here](https://docs.pennylane.ai/en/stable/code/api/pennylane.transforms.single_qubit_fusion.html#derivation). In order to add arbitrary rotations back, issue #1220 was created. This issue is also very well documented. This PR tries to address that issue. **Description of the Change:** 1. Add arbitrary rotations, `Rot` and `CRot` to the list of rotations in MergeRotationsPatterns.cpp. 2. Change the logic of `MergeRotationsRewritePattern::matchAndRewrite` so that it calls `matchAndRewriteFixedRotationOrPhaseShift` or `matchAndRewriteArbitraryRotation`. 3. Implement `matchAndRewriteArbitraryRotation` following the mathematical formulas mentioned before. 4. Add tests for the general cases of `Rot` and `CRot`, and for the special cases of merging arbitrary rotations. Notice that, in order for the new code to compile, I have needed to include a dependency to the MLIR Math dialect, as well as include guards pointing to the Math header file. **Benefits:** Arbitary rotations can now be merged as part of Catalyst's Merge Rotations pass. **Possible Drawbacks:** There are still a few TODOs in the code. One of them, possibly the main one, has to do with the implementation of the problematic scenarios for differentiability, which I haven't addressed yet. **Related GitHub Issues:** #1220 --------- Co-authored-by: David Ittah <[email protected]>
1 parent 53aae7e commit 0f32a73

File tree

6 files changed

+510
-45
lines changed

6 files changed

+510
-45
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@
128128
* `catalyst.accelerate`, `catalyst.debug.callback`, and `catalyst.pure_callback`, `catalyst.debug.print`, and `catalyst.debug.print_memref` now work when capture is enabled.
129129
[(#1902)](https://github.com/PennyLaneAI/catalyst/pull/1902)
130130

131+
* The merge rotation pass in Catalyst (:func:`~.passes.merge_rotations`) now also considers
132+
`qml.Rot` and `qml.CRot`.
133+
[(#1955)](https://github.com/PennyLaneAI/catalyst/pull/1955)
134+
131135
<h3>Documentation 📝</h3>
132136

133137
<h3>Contributors ✍️</h3>

frontend/catalyst/passes/builtin_passes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,8 @@ def merge_rotations(qnode):
343343
:class:`qml.CRZ <pennylane.CRZ>`,
344344
:class:`qml.PhaseShift <pennylane.PhaseShift>`,
345345
:class:`qml.ControlledPhaseShift <pennylane.ControlledPhaseShift>`,
346+
:class:`qml.Rot <pennylane.Rot>`,
347+
:class:`qml.CRot <pennylane.CRot>`,
346348
:class:`qml.MultiRZ <pennylane.MultiRZ>`.
347349
348350

mlir/include/Quantum/Transforms/Passes.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ def RemoveChainedSelfInversePass : Pass<"remove-chained-self-inverse"> {
9595
def MergeRotationsPass : Pass<"merge-rotations"> {
9696
let summary = "Perform merging of chained rotation gates about the same axis.";
9797

98+
let dependentDialects = ["math::MathDialect"];
99+
98100
let constructor = "catalyst::createMergeRotationsPass()";
99101
}
100102

mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp

Lines changed: 257 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -14,25 +14,31 @@
1414

1515
#define DEBUG_TYPE "merge-rotations"
1616

17-
#include "Quantum/IR/QuantumOps.h"
18-
#include "Quantum/Transforms/Patterns.h"
19-
#include "VerifyParentGateAnalysis.hpp"
17+
#include <array>
18+
#include <cassert> // assert
19+
2020
#include "mlir/Dialect/Arith/IR/Arith.h"
21+
#include "mlir/Dialect/Math/IR/Math.h"
2122
#include "llvm/ADT/StringSet.h"
2223
#include "llvm/Support/Debug.h"
2324
#include "llvm/Support/Errc.h"
2425

26+
#include "Quantum/IR/QuantumOps.h"
27+
#include "Quantum/Transforms/Patterns.h"
28+
#include "VerifyParentGateAnalysis.hpp"
29+
2530
using llvm::dbgs;
2631
using namespace mlir;
2732
using namespace catalyst::quantum;
2833

29-
static const mlir::StringSet<> rotationsSet = {"RX", "RY", "RZ", "PhaseShift",
30-
"CRX", "CRY", "CRZ", "ControlledPhaseShift"};
34+
static const mlir::StringSet<> fixedRotationsAndPhaseShiftsSet = {
35+
"RX", "RY", "RZ", "PhaseShift", "CRX", "CRY", "CRZ", "ControlledPhaseShift"};
36+
static const mlir::StringSet<> arbitraryRotationsSet = {"Rot", "CRot"};
3137

3238
namespace {
3339

3440
// convertOpParamsToValues: helper function for extracting CustomOp parameters as mlir::Values
35-
SmallVector<mlir::Value> convertOpParamsToValues(CustomOp &op, mlir::PatternRewriter &rewriter)
41+
SmallVector<mlir::Value> convertOpParamsToValues(CustomOp &op, PatternRewriter &rewriter)
3642
{
3743
SmallVector<mlir::Value> values;
3844
auto params = op.getParams();
@@ -42,59 +48,282 @@ SmallVector<mlir::Value> convertOpParamsToValues(CustomOp &op, mlir::PatternRewr
4248
return values;
4349
}
4450

51+
// getStaticValuesOrNothing: helper function for extracting Rot or CRot parameters as:
52+
// - doubles, in case they are constant
53+
// - std::nullopt, otherwise
54+
std::array<std::optional<double>, 3> getStaticValuesOrNothing(const SmallVector<mlir::Value> values)
55+
{
56+
assert(values.size() == 3 && "found Rot or CRot operation should have exactly 3 parameters");
57+
auto staticValues = std::array<std::optional<double>, 3>{};
58+
for (auto [index, value] : llvm::enumerate(values)) {
59+
if (auto constOp = value.getDefiningOp();
60+
constOp && constOp->hasTrait<OpTrait::ConstantLike>()) {
61+
if (auto floatAttr = constOp->getAttrOfType<FloatAttr>("value")) {
62+
staticValues[index] = floatAttr.getValueAsDouble();
63+
}
64+
}
65+
}
66+
return staticValues;
67+
}
68+
4569
template <typename ParentOpType, typename OpType>
46-
struct MergeRotationsRewritePattern : public mlir::OpRewritePattern<OpType> {
70+
struct MergeRotationsRewritePattern : public OpRewritePattern<OpType> {
4771
// Merge rotation patterns where at least one operand is non-static.
4872
// The result is a non-static CustomOp, as at least one operand is not known at compile time.
49-
using mlir::OpRewritePattern<OpType>::OpRewritePattern;
73+
using OpRewritePattern<OpType>::OpRewritePattern;
5074

51-
mlir::LogicalResult matchAndRewrite(OpType op, mlir::PatternRewriter &rewriter) const override
75+
// Fixed single rotations and phase shifts can be merged just by adding the angle parameters
76+
LogicalResult matchAndRewriteFixedRotationOrPhaseShift(OpType op,
77+
PatternRewriter &rewriter) const
5278
{
53-
LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n");
54-
auto loc = op.getLoc();
55-
StringRef opGateName = op.getGateName();
56-
if (!rotationsSet.contains(opGateName))
57-
return failure();
5879
ValueRange inQubits = op.getInQubits();
59-
auto parentOp = dyn_cast_or_null<ParentOpType>(inQubits[0].getDefiningOp());
60-
61-
VerifyHeterogeneousParentGateAndNameAnalysis<OpType, ParentOpType> vpga(op);
62-
if (!vpga.getVerifierResult()) {
63-
return failure();
64-
}
80+
auto parentOp = llvm::cast<ParentOpType>(inQubits[0].getDefiningOp());
6581

6682
TypeRange outQubitsTypes = op.getOutQubits().getTypes();
6783
TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes();
6884
ValueRange parentInQubits = parentOp.getInQubits();
6985
ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits();
7086
ValueRange parentInCtrlValues = parentOp.getInCtrlValues();
7187

72-
// extract parameters of the op and its parent,
88+
// Extract parameters of the op and its parent,
7389
// promoting the parameters to mlir::Values if necessary
7490
auto parentParams = convertOpParamsToValues(parentOp, rewriter);
7591
auto params = convertOpParamsToValues(op, rewriter);
92+
93+
auto loc = op.getLoc();
7694
SmallVector<mlir::Value> sumParams;
7795
for (auto [param, parentParam] : llvm::zip(params, parentParams)) {
7896
mlir::Value sumParam =
7997
rewriter.create<arith::AddFOp>(loc, parentParam, param).getResult();
8098
sumParams.push_back(sumParam);
81-
};
99+
}
100+
auto mergeOp = rewriter.create<CustomOp>(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams,
101+
parentInQubits, op.getGateName(), false,
102+
parentInCtrlQubits, parentInCtrlValues);
103+
104+
rewriter.replaceOp(op, mergeOp);
105+
rewriter.eraseOp(parentOp);
106+
107+
return success();
108+
}
109+
110+
// Arbitrary single rotations require more complex maths to be merged
111+
LogicalResult matchAndRewriteArbitraryRotation(OpType op, PatternRewriter &rewriter) const
112+
{
113+
ValueRange inQubits = op.getInQubits();
114+
auto parentOp = llvm::cast<ParentOpType>(inQubits[0].getDefiningOp());
115+
116+
TypeRange outQubitsTypes = op.getOutQubits().getTypes();
117+
TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes();
118+
ValueRange parentInQubits = parentOp.getInQubits();
119+
ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits();
120+
ValueRange parentInCtrlValues = parentOp.getInCtrlValues();
121+
122+
// Extract parameters of the op and its parent,
123+
// promoting the parameters to mlir::Values if necessary
124+
auto parentParams = convertOpParamsToValues(parentOp, rewriter);
125+
auto params = convertOpParamsToValues(op, rewriter);
126+
127+
// Parent params are ϕ1, θ1, and ω1
128+
// Params are ϕ2, θ2, and ω2
129+
mlir::Value phi1 = parentParams[0];
130+
mlir::Value theta1 = parentParams[1];
131+
mlir::Value omega1 = parentParams[2];
132+
mlir::Value phi2 = params[0];
133+
mlir::Value theta2 = params[1];
134+
mlir::Value omega2 = params[2];
135+
136+
auto [phi1Opt, theta1Opt, omega1Opt] = getStaticValuesOrNothing(parentParams);
137+
auto [phi2Opt, theta2Opt, omega2Opt] = getStaticValuesOrNothing(params);
138+
139+
mlir::Value phiF;
140+
mlir::Value thetaF;
141+
mlir::Value omegaF;
142+
143+
// TODO: should we use an epsilon for comparing doubles here?
144+
bool omega1IsZero = omega1Opt.has_value() && omega1Opt.value() == 0.0;
145+
bool phi2IsZero = phi2Opt.has_value() && phi2Opt.value() == 0.0;
146+
bool theta1IsZero = theta1Opt.has_value() && theta1Opt.value() == 0.0;
147+
bool theta2IsZero = theta2Opt.has_value() && theta2Opt.value() == 0.0;
148+
149+
auto loc = op.getLoc();
150+
151+
// Special cases:
152+
//
153+
// 1. if (ω1 == 0 && ϕ2 == 0) { ϕF = ϕ1; θF = θ1 + θ2; ωF = ω2; }
154+
// 2a. if (θ1 == 0 && θ2 == 0) { ϕF = ϕ1 + ϕ2 + ω1 + ω2; θF = 0; ωF = 0; }
155+
// 2b. if (θ1 == 0) { ϕF = ϕ1 + ϕ2 + ω1; θF = θ2; ωF = ω2; }
156+
// 2c. if (θ2 == 0) { ϕF = ϕ1; θF = θ1; ωF = ω1 + ω2 + ϕ2; }
157+
auto zeroConst = rewriter.create<arith::ConstantOp>(loc, rewriter.getF64FloatAttr(0.0));
158+
if (omega1IsZero && phi2IsZero) {
159+
phiF = phi1;
160+
thetaF = rewriter.create<arith::AddFOp>(loc, theta1, theta2);
161+
omegaF = omega2;
162+
}
163+
else if (theta1IsZero && theta2IsZero) {
164+
phiF =
165+
rewriter.create<arith::AddFOp>(loc, rewriter.create<arith::AddFOp>(loc, phi1, phi2),
166+
rewriter.create<arith::AddFOp>(loc, omega1, omega2));
167+
thetaF = zeroConst;
168+
omegaF = zeroConst;
169+
}
170+
else if (theta1IsZero) {
171+
phiF = rewriter.create<arith::AddFOp>(
172+
loc, rewriter.create<arith::AddFOp>(loc, phi1, phi2), omega1);
173+
thetaF = theta2;
174+
omegaF = omega2;
175+
}
176+
else if (theta2IsZero) {
177+
phiF = phi1;
178+
thetaF = theta1;
179+
omegaF = rewriter.create<arith::AddFOp>(
180+
loc, rewriter.create<arith::AddFOp>(loc, omega1, omega2), phi2);
181+
}
182+
else {
183+
auto halfConst = rewriter.create<arith::ConstantOp>(loc, rewriter.getF64FloatAttr(0.5));
184+
auto twoConst = rewriter.create<arith::ConstantOp>(loc, rewriter.getF64FloatAttr(2.0));
185+
186+
// α1 = (ϕ1 + ω1)/2, α2 = (ϕ2 + ω2)/2
187+
// β1 = (ϕ1 - ω1)/2, β2 = (ϕ2 - ω2)/2
188+
auto alpha1 = rewriter.create<arith::MulFOp>(
189+
loc, rewriter.create<arith::AddFOp>(loc, phi1, omega1), halfConst);
190+
auto alpha2 = rewriter.create<arith::MulFOp>(
191+
loc, rewriter.create<arith::AddFOp>(loc, phi2, omega2), halfConst);
192+
auto beta1 = rewriter.create<arith::MulFOp>(
193+
loc, rewriter.create<arith::SubFOp>(loc, phi1, omega1), halfConst);
194+
auto beta2 = rewriter.create<arith::MulFOp>(
195+
loc, rewriter.create<arith::SubFOp>(loc, phi2, omega2), halfConst);
196+
197+
// c1 = cos(θ1/2), c2 = cos(θ2/2)
198+
// s1 = sin(θ1/2), s2 = sin(θ2/2)
199+
auto theta1Half = rewriter.create<arith::MulFOp>(loc, theta1, halfConst);
200+
auto c1 = rewriter.create<math::CosOp>(loc, theta1Half);
201+
auto s1 = rewriter.create<math::SinOp>(loc, theta1Half);
202+
auto theta2Half = rewriter.create<arith::MulFOp>(loc, theta2, halfConst);
203+
auto c2 = rewriter.create<math::CosOp>(loc, theta2Half);
204+
auto s2 = rewriter.create<math::SinOp>(loc, theta2Half);
205+
206+
// cF = sqrt(c1^2 * c2^2 +
207+
// s1^2 * s2^2 -
208+
// 2 * c1 * c2 * s1 * s2 * cos(ω1 + ϕ2))
209+
auto c1TimesC2 = rewriter.create<arith::MulFOp>(loc, c1, c2);
210+
auto s1TimesS2 = rewriter.create<arith::MulFOp>(loc, s1, s2);
211+
auto firstAddend =
212+
rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, c1, c1),
213+
rewriter.create<arith::MulFOp>(loc, c2, c2));
214+
auto secondAddend =
215+
rewriter.create<arith::MulFOp>(loc, rewriter.create<arith::MulFOp>(loc, s1, s1),
216+
rewriter.create<arith::MulFOp>(loc, s2, s2));
217+
auto thirdAddend = rewriter.create<arith::NegFOp>(
218+
loc, rewriter.create<arith::MulFOp>(
219+
loc, twoConst,
220+
rewriter.create<arith::MulFOp>(
221+
loc, c1TimesC2,
222+
rewriter.create<arith::MulFOp>(
223+
loc, s1TimesS2,
224+
rewriter.create<math::CosOp>(
225+
loc, rewriter.create<arith::AddFOp>(loc, omega1, phi2))))));
226+
auto cF = rewriter.create<math::SqrtOp>(
227+
loc, rewriter.create<arith::AddFOp>(
228+
loc, firstAddend,
229+
rewriter.create<arith::AddFOp>(loc, secondAddend, thirdAddend)));
230+
231+
// TODO: can we check these problematic scenarios for differentiability by code?
232+
// Problematic scenarios for differentiability:
233+
//
234+
// 1. if (cF == 0) { /* sqrt not differentiable at 0 */ return failure(); }
235+
// 2. if (cF == 1) { /* acos not differentiable at 1 */ return failure(); }
236+
237+
// θF = 2 * acos(cF)
238+
auto acosCF = rewriter.create<math::AcosOp>(loc, cF);
239+
thetaF = rewriter.create<arith::MulFOp>(loc, twoConst, acosCF);
240+
241+
// αF = - atan((- c1 * c2 * sin(α1 + α2) - s1 * s2 * sin(β2 - β1)) /
242+
// ( c1 * c2 * cos(α1 + α2) - s1 * s2 * cos(β2 - β1)))
243+
auto alpha1PlusAlpha2 = rewriter.create<arith::AddFOp>(loc, alpha1, alpha2);
244+
auto beta2MinusBeta1 = rewriter.create<arith::SubFOp>(loc, beta2, beta1);
245+
auto term1 = rewriter.create<arith::NegFOp>(
246+
loc, rewriter.create<arith::MulFOp>(
247+
loc, c1TimesC2, rewriter.create<math::SinOp>(loc, alpha1PlusAlpha2)));
248+
auto term2 = rewriter.create<arith::NegFOp>(
249+
loc, rewriter.create<arith::MulFOp>(
250+
loc, s1TimesS2, rewriter.create<math::SinOp>(loc, beta2MinusBeta1)));
251+
auto term3 = rewriter.create<arith::MulFOp>(
252+
loc, c1TimesC2, rewriter.create<math::CosOp>(loc, alpha1PlusAlpha2));
253+
auto term4 = rewriter.create<arith::NegFOp>(
254+
loc, rewriter.create<arith::MulFOp>(
255+
loc, s1TimesS2, rewriter.create<math::CosOp>(loc, beta2MinusBeta1)));
256+
auto alphaF = rewriter.create<arith::NegFOp>(
257+
loc, rewriter.create<math::AtanOp>(
258+
loc, rewriter.create<arith::DivFOp>(
259+
loc, rewriter.create<arith::AddFOp>(loc, term1, term2),
260+
rewriter.create<arith::AddFOp>(loc, term3, term4))));
261+
262+
// βF = - atan((- c1 * s2 * sin(α1 + β2) + s1 * c2 * sin(α2 - β1)) /
263+
// ( c1 * s2 * cos(α1 + β2) + s1 * c2 * cos(α2 - β1)))
264+
auto c1TimesS2 = rewriter.create<arith::MulFOp>(loc, c1, s2);
265+
auto s1TimesC2 = rewriter.create<arith::MulFOp>(loc, s1, c2);
266+
auto alpha1PlusBeta2 = rewriter.create<arith::AddFOp>(loc, alpha1, beta2);
267+
auto alpha2MinusBeta1 = rewriter.create<arith::SubFOp>(loc, alpha2, beta1);
268+
auto term5 = rewriter.create<arith::NegFOp>(
269+
loc, rewriter.create<arith::MulFOp>(
270+
loc, c1TimesS2, rewriter.create<math::SinOp>(loc, alpha1PlusBeta2)));
271+
auto term6 = rewriter.create<arith::MulFOp>(
272+
loc, s1TimesC2, rewriter.create<math::SinOp>(loc, alpha2MinusBeta1));
273+
auto term7 = rewriter.create<arith::MulFOp>(
274+
loc, c1TimesS2, rewriter.create<math::CosOp>(loc, alpha1PlusBeta2));
275+
auto term8 = rewriter.create<arith::MulFOp>(
276+
loc, s1TimesC2, rewriter.create<math::CosOp>(loc, alpha2MinusBeta1));
277+
auto betaF = rewriter.create<arith::NegFOp>(
278+
loc, rewriter.create<math::AtanOp>(
279+
loc, rewriter.create<arith::DivFOp>(
280+
loc, rewriter.create<arith::AddFOp>(loc, term5, term6),
281+
rewriter.create<arith::AddFOp>(loc, term7, term8))));
282+
283+
// ϕF = αF + βF
284+
phiF = rewriter.create<arith::AddFOp>(loc, alphaF, betaF);
285+
286+
// ωF = αF - βF
287+
omegaF = rewriter.create<arith::SubFOp>(loc, alphaF, betaF);
288+
}
289+
290+
auto sumParams = SmallVector<mlir::Value>{phiF, thetaF, omegaF};
82291
auto mergeOp = rewriter.create<CustomOp>(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams,
83-
parentInQubits, opGateName, false,
292+
parentInQubits, op.getGateName(), false,
84293
parentInCtrlQubits, parentInCtrlValues);
85294

86295
rewriter.replaceOp(op, mergeOp);
87296
rewriter.eraseOp(parentOp);
88297

89298
return success();
90299
}
300+
301+
LogicalResult matchAndRewrite(OpType op, PatternRewriter &rewriter) const override
302+
{
303+
LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n");
304+
305+
StringRef opGateName = op.getGateName();
306+
if (!fixedRotationsAndPhaseShiftsSet.contains(opGateName) &&
307+
!arbitraryRotationsSet.contains(opGateName)) {
308+
return failure();
309+
}
310+
311+
VerifyHeterogeneousParentGateAndNameAnalysis<OpType, ParentOpType> vpga(op);
312+
if (!vpga.getVerifierResult()) {
313+
return failure();
314+
}
315+
316+
if (fixedRotationsAndPhaseShiftsSet.contains(opGateName)) {
317+
return matchAndRewriteFixedRotationOrPhaseShift(op, rewriter);
318+
}
319+
return matchAndRewriteArbitraryRotation(op, rewriter);
320+
}
91321
};
92322

93-
struct MergeMultiRZRewritePattern : public mlir::OpRewritePattern<MultiRZOp> {
94-
using mlir::OpRewritePattern<MultiRZOp>::OpRewritePattern;
323+
struct MergeMultiRZRewritePattern : public OpRewritePattern<MultiRZOp> {
324+
using OpRewritePattern<MultiRZOp>::OpRewritePattern;
95325

96-
mlir::LogicalResult matchAndRewrite(MultiRZOp op,
97-
mlir::PatternRewriter &rewriter) const override
326+
LogicalResult matchAndRewrite(MultiRZOp op, PatternRewriter &rewriter) const override
98327
{
99328
LLVM_DEBUG(dbgs() << "Simplifying the following operation:\n" << op << "\n");
100329
auto loc = op.getLoc();
@@ -105,9 +334,7 @@ struct MergeMultiRZRewritePattern : public mlir::OpRewritePattern<MultiRZOp> {
105334
}
106335

107336
ValueRange inQubits = op.getInQubits();
108-
auto parentOp = dyn_cast_or_null<MultiRZOp>(inQubits[0].getDefiningOp());
109-
if (!parentOp)
110-
return failure();
337+
auto parentOp = llvm::cast<MultiRZOp>(inQubits[0].getDefiningOp());
111338

112339
TypeRange outQubitsTypes = op.getOutQubits().getTypes();
113340
TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits().getTypes();

mlir/lib/Quantum/Transforms/merge_rotation.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "llvm/Support/Debug.h"
1818

1919
#include "mlir/Dialect/Func/IR/FuncOps.h"
20+
#include "mlir/Dialect/Math/IR/Math.h"
2021
#include "mlir/Dialect/SCF/IR/SCF.h"
2122
#include "mlir/Pass/Pass.h"
2223
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

0 commit comments

Comments
 (0)