1414
1515#define DEBUG_TYPE " merge-rotations"
1616
17- #include " Quantum/IR/QuantumOps.h "
18- #include " Quantum/Transforms/Patterns.h "
19- # include " VerifyParentGateAnalysis.hpp "
17+ #include < array >
18+ #include < cassert > // assert
19+
2020#include " mlir/Dialect/Arith/IR/Arith.h"
21+ #include " mlir/Dialect/Math/IR/Math.h"
2122#include " llvm/ADT/StringSet.h"
2223#include " llvm/Support/Debug.h"
2324#include " llvm/Support/Errc.h"
2425
26+ #include " Quantum/IR/QuantumOps.h"
27+ #include " Quantum/Transforms/Patterns.h"
28+ #include " VerifyParentGateAnalysis.hpp"
29+
2530using llvm::dbgs;
2631using namespace mlir ;
2732using namespace catalyst ::quantum;
2833
29- static const mlir::StringSet<> rotationsSet = {" RX" , " RY" , " RZ" , " PhaseShift" ,
30- " CRX" , " CRY" , " CRZ" , " ControlledPhaseShift" };
34+ static const mlir::StringSet<> fixedRotationsAndPhaseShiftsSet = {
35+ " RX" , " RY" , " RZ" , " PhaseShift" , " CRX" , " CRY" , " CRZ" , " ControlledPhaseShift" };
36+ static const mlir::StringSet<> arbitraryRotationsSet = {" Rot" , " CRot" };
3137
3238namespace {
3339
3440// convertOpParamsToValues: helper function for extracting CustomOp parameters as mlir::Values
35- SmallVector<mlir::Value> convertOpParamsToValues (CustomOp &op, mlir:: PatternRewriter &rewriter)
41+ SmallVector<mlir::Value> convertOpParamsToValues (CustomOp &op, PatternRewriter &rewriter)
3642{
3743 SmallVector<mlir::Value> values;
3844 auto params = op.getParams ();
@@ -42,59 +48,282 @@ SmallVector<mlir::Value> convertOpParamsToValues(CustomOp &op, mlir::PatternRewr
4248 return values;
4349}
4450
51+ // getStaticValuesOrNothing: helper function for extracting Rot or CRot parameters as:
52+ // - doubles, in case they are constant
53+ // - std::nullopt, otherwise
54+ std::array<std::optional<double >, 3 > getStaticValuesOrNothing (const SmallVector<mlir::Value> values)
55+ {
56+ assert (values.size () == 3 && " found Rot or CRot operation should have exactly 3 parameters" );
57+ auto staticValues = std::array<std::optional<double >, 3 >{};
58+ for (auto [index, value] : llvm::enumerate (values)) {
59+ if (auto constOp = value.getDefiningOp ();
60+ constOp && constOp->hasTrait <OpTrait::ConstantLike>()) {
61+ if (auto floatAttr = constOp->getAttrOfType <FloatAttr>(" value" )) {
62+ staticValues[index] = floatAttr.getValueAsDouble ();
63+ }
64+ }
65+ }
66+ return staticValues;
67+ }
68+
4569template <typename ParentOpType, typename OpType>
46- struct MergeRotationsRewritePattern : public mlir :: OpRewritePattern<OpType> {
70+ struct MergeRotationsRewritePattern : public OpRewritePattern <OpType> {
4771 // Merge rotation patterns where at least one operand is non-static.
4872 // The result is a non-static CustomOp, as at least one operand is not known at compile time.
49- using mlir:: OpRewritePattern<OpType>::OpRewritePattern;
73+ using OpRewritePattern<OpType>::OpRewritePattern;
5074
51- mlir::LogicalResult matchAndRewrite (OpType op, mlir::PatternRewriter &rewriter) const override
75+ // Fixed single rotations and phase shifts can be merged just by adding the angle parameters
76+ LogicalResult matchAndRewriteFixedRotationOrPhaseShift (OpType op,
77+ PatternRewriter &rewriter) const
5278 {
53- LLVM_DEBUG (dbgs () << " Simplifying the following operation:\n " << op << " \n " );
54- auto loc = op.getLoc ();
55- StringRef opGateName = op.getGateName ();
56- if (!rotationsSet.contains (opGateName))
57- return failure ();
5879 ValueRange inQubits = op.getInQubits ();
59- auto parentOp = dyn_cast_or_null<ParentOpType>(inQubits[0 ].getDefiningOp ());
60-
61- VerifyHeterogeneousParentGateAndNameAnalysis<OpType, ParentOpType> vpga (op);
62- if (!vpga.getVerifierResult ()) {
63- return failure ();
64- }
80+ auto parentOp = llvm::cast<ParentOpType>(inQubits[0 ].getDefiningOp ());
6581
6682 TypeRange outQubitsTypes = op.getOutQubits ().getTypes ();
6783 TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits ().getTypes ();
6884 ValueRange parentInQubits = parentOp.getInQubits ();
6985 ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits ();
7086 ValueRange parentInCtrlValues = parentOp.getInCtrlValues ();
7187
72- // extract parameters of the op and its parent,
88+ // Extract parameters of the op and its parent,
7389 // promoting the parameters to mlir::Values if necessary
7490 auto parentParams = convertOpParamsToValues (parentOp, rewriter);
7591 auto params = convertOpParamsToValues (op, rewriter);
92+
93+ auto loc = op.getLoc ();
7694 SmallVector<mlir::Value> sumParams;
7795 for (auto [param, parentParam] : llvm::zip (params, parentParams)) {
7896 mlir::Value sumParam =
7997 rewriter.create <arith::AddFOp>(loc, parentParam, param).getResult ();
8098 sumParams.push_back (sumParam);
81- };
99+ }
100+ auto mergeOp = rewriter.create <CustomOp>(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams,
101+ parentInQubits, op.getGateName (), false ,
102+ parentInCtrlQubits, parentInCtrlValues);
103+
104+ rewriter.replaceOp (op, mergeOp);
105+ rewriter.eraseOp (parentOp);
106+
107+ return success ();
108+ }
109+
110+ // Arbitrary single rotations require more complex maths to be merged
111+ LogicalResult matchAndRewriteArbitraryRotation (OpType op, PatternRewriter &rewriter) const
112+ {
113+ ValueRange inQubits = op.getInQubits ();
114+ auto parentOp = llvm::cast<ParentOpType>(inQubits[0 ].getDefiningOp ());
115+
116+ TypeRange outQubitsTypes = op.getOutQubits ().getTypes ();
117+ TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits ().getTypes ();
118+ ValueRange parentInQubits = parentOp.getInQubits ();
119+ ValueRange parentInCtrlQubits = parentOp.getInCtrlQubits ();
120+ ValueRange parentInCtrlValues = parentOp.getInCtrlValues ();
121+
122+ // Extract parameters of the op and its parent,
123+ // promoting the parameters to mlir::Values if necessary
124+ auto parentParams = convertOpParamsToValues (parentOp, rewriter);
125+ auto params = convertOpParamsToValues (op, rewriter);
126+
127+ // Parent params are ϕ1, θ1, and ω1
128+ // Params are ϕ2, θ2, and ω2
129+ mlir::Value phi1 = parentParams[0 ];
130+ mlir::Value theta1 = parentParams[1 ];
131+ mlir::Value omega1 = parentParams[2 ];
132+ mlir::Value phi2 = params[0 ];
133+ mlir::Value theta2 = params[1 ];
134+ mlir::Value omega2 = params[2 ];
135+
136+ auto [phi1Opt, theta1Opt, omega1Opt] = getStaticValuesOrNothing (parentParams);
137+ auto [phi2Opt, theta2Opt, omega2Opt] = getStaticValuesOrNothing (params);
138+
139+ mlir::Value phiF;
140+ mlir::Value thetaF;
141+ mlir::Value omegaF;
142+
143+ // TODO: should we use an epsilon for comparing doubles here?
144+ bool omega1IsZero = omega1Opt.has_value () && omega1Opt.value () == 0.0 ;
145+ bool phi2IsZero = phi2Opt.has_value () && phi2Opt.value () == 0.0 ;
146+ bool theta1IsZero = theta1Opt.has_value () && theta1Opt.value () == 0.0 ;
147+ bool theta2IsZero = theta2Opt.has_value () && theta2Opt.value () == 0.0 ;
148+
149+ auto loc = op.getLoc ();
150+
151+ // Special cases:
152+ //
153+ // 1. if (ω1 == 0 && ϕ2 == 0) { ϕF = ϕ1; θF = θ1 + θ2; ωF = ω2; }
154+ // 2a. if (θ1 == 0 && θ2 == 0) { ϕF = ϕ1 + ϕ2 + ω1 + ω2; θF = 0; ωF = 0; }
155+ // 2b. if (θ1 == 0) { ϕF = ϕ1 + ϕ2 + ω1; θF = θ2; ωF = ω2; }
156+ // 2c. if (θ2 == 0) { ϕF = ϕ1; θF = θ1; ωF = ω1 + ω2 + ϕ2; }
157+ auto zeroConst = rewriter.create <arith::ConstantOp>(loc, rewriter.getF64FloatAttr (0.0 ));
158+ if (omega1IsZero && phi2IsZero) {
159+ phiF = phi1;
160+ thetaF = rewriter.create <arith::AddFOp>(loc, theta1, theta2);
161+ omegaF = omega2;
162+ }
163+ else if (theta1IsZero && theta2IsZero) {
164+ phiF =
165+ rewriter.create <arith::AddFOp>(loc, rewriter.create <arith::AddFOp>(loc, phi1, phi2),
166+ rewriter.create <arith::AddFOp>(loc, omega1, omega2));
167+ thetaF = zeroConst;
168+ omegaF = zeroConst;
169+ }
170+ else if (theta1IsZero) {
171+ phiF = rewriter.create <arith::AddFOp>(
172+ loc, rewriter.create <arith::AddFOp>(loc, phi1, phi2), omega1);
173+ thetaF = theta2;
174+ omegaF = omega2;
175+ }
176+ else if (theta2IsZero) {
177+ phiF = phi1;
178+ thetaF = theta1;
179+ omegaF = rewriter.create <arith::AddFOp>(
180+ loc, rewriter.create <arith::AddFOp>(loc, omega1, omega2), phi2);
181+ }
182+ else {
183+ auto halfConst = rewriter.create <arith::ConstantOp>(loc, rewriter.getF64FloatAttr (0.5 ));
184+ auto twoConst = rewriter.create <arith::ConstantOp>(loc, rewriter.getF64FloatAttr (2.0 ));
185+
186+ // α1 = (ϕ1 + ω1)/2, α2 = (ϕ2 + ω2)/2
187+ // β1 = (ϕ1 - ω1)/2, β2 = (ϕ2 - ω2)/2
188+ auto alpha1 = rewriter.create <arith::MulFOp>(
189+ loc, rewriter.create <arith::AddFOp>(loc, phi1, omega1), halfConst);
190+ auto alpha2 = rewriter.create <arith::MulFOp>(
191+ loc, rewriter.create <arith::AddFOp>(loc, phi2, omega2), halfConst);
192+ auto beta1 = rewriter.create <arith::MulFOp>(
193+ loc, rewriter.create <arith::SubFOp>(loc, phi1, omega1), halfConst);
194+ auto beta2 = rewriter.create <arith::MulFOp>(
195+ loc, rewriter.create <arith::SubFOp>(loc, phi2, omega2), halfConst);
196+
197+ // c1 = cos(θ1/2), c2 = cos(θ2/2)
198+ // s1 = sin(θ1/2), s2 = sin(θ2/2)
199+ auto theta1Half = rewriter.create <arith::MulFOp>(loc, theta1, halfConst);
200+ auto c1 = rewriter.create <math::CosOp>(loc, theta1Half);
201+ auto s1 = rewriter.create <math::SinOp>(loc, theta1Half);
202+ auto theta2Half = rewriter.create <arith::MulFOp>(loc, theta2, halfConst);
203+ auto c2 = rewriter.create <math::CosOp>(loc, theta2Half);
204+ auto s2 = rewriter.create <math::SinOp>(loc, theta2Half);
205+
206+ // cF = sqrt(c1^2 * c2^2 +
207+ // s1^2 * s2^2 -
208+ // 2 * c1 * c2 * s1 * s2 * cos(ω1 + ϕ2))
209+ auto c1TimesC2 = rewriter.create <arith::MulFOp>(loc, c1, c2);
210+ auto s1TimesS2 = rewriter.create <arith::MulFOp>(loc, s1, s2);
211+ auto firstAddend =
212+ rewriter.create <arith::MulFOp>(loc, rewriter.create <arith::MulFOp>(loc, c1, c1),
213+ rewriter.create <arith::MulFOp>(loc, c2, c2));
214+ auto secondAddend =
215+ rewriter.create <arith::MulFOp>(loc, rewriter.create <arith::MulFOp>(loc, s1, s1),
216+ rewriter.create <arith::MulFOp>(loc, s2, s2));
217+ auto thirdAddend = rewriter.create <arith::NegFOp>(
218+ loc, rewriter.create <arith::MulFOp>(
219+ loc, twoConst,
220+ rewriter.create <arith::MulFOp>(
221+ loc, c1TimesC2,
222+ rewriter.create <arith::MulFOp>(
223+ loc, s1TimesS2,
224+ rewriter.create <math::CosOp>(
225+ loc, rewriter.create <arith::AddFOp>(loc, omega1, phi2))))));
226+ auto cF = rewriter.create <math::SqrtOp>(
227+ loc, rewriter.create <arith::AddFOp>(
228+ loc, firstAddend,
229+ rewriter.create <arith::AddFOp>(loc, secondAddend, thirdAddend)));
230+
231+ // TODO: can we check these problematic scenarios for differentiability by code?
232+ // Problematic scenarios for differentiability:
233+ //
234+ // 1. if (cF == 0) { /* sqrt not differentiable at 0 */ return failure(); }
235+ // 2. if (cF == 1) { /* acos not differentiable at 1 */ return failure(); }
236+
237+ // θF = 2 * acos(cF)
238+ auto acosCF = rewriter.create <math::AcosOp>(loc, cF);
239+ thetaF = rewriter.create <arith::MulFOp>(loc, twoConst, acosCF);
240+
241+ // αF = - atan((- c1 * c2 * sin(α1 + α2) - s1 * s2 * sin(β2 - β1)) /
242+ // ( c1 * c2 * cos(α1 + α2) - s1 * s2 * cos(β2 - β1)))
243+ auto alpha1PlusAlpha2 = rewriter.create <arith::AddFOp>(loc, alpha1, alpha2);
244+ auto beta2MinusBeta1 = rewriter.create <arith::SubFOp>(loc, beta2, beta1);
245+ auto term1 = rewriter.create <arith::NegFOp>(
246+ loc, rewriter.create <arith::MulFOp>(
247+ loc, c1TimesC2, rewriter.create <math::SinOp>(loc, alpha1PlusAlpha2)));
248+ auto term2 = rewriter.create <arith::NegFOp>(
249+ loc, rewriter.create <arith::MulFOp>(
250+ loc, s1TimesS2, rewriter.create <math::SinOp>(loc, beta2MinusBeta1)));
251+ auto term3 = rewriter.create <arith::MulFOp>(
252+ loc, c1TimesC2, rewriter.create <math::CosOp>(loc, alpha1PlusAlpha2));
253+ auto term4 = rewriter.create <arith::NegFOp>(
254+ loc, rewriter.create <arith::MulFOp>(
255+ loc, s1TimesS2, rewriter.create <math::CosOp>(loc, beta2MinusBeta1)));
256+ auto alphaF = rewriter.create <arith::NegFOp>(
257+ loc, rewriter.create <math::AtanOp>(
258+ loc, rewriter.create <arith::DivFOp>(
259+ loc, rewriter.create <arith::AddFOp>(loc, term1, term2),
260+ rewriter.create <arith::AddFOp>(loc, term3, term4))));
261+
262+ // βF = - atan((- c1 * s2 * sin(α1 + β2) + s1 * c2 * sin(α2 - β1)) /
263+ // ( c1 * s2 * cos(α1 + β2) + s1 * c2 * cos(α2 - β1)))
264+ auto c1TimesS2 = rewriter.create <arith::MulFOp>(loc, c1, s2);
265+ auto s1TimesC2 = rewriter.create <arith::MulFOp>(loc, s1, c2);
266+ auto alpha1PlusBeta2 = rewriter.create <arith::AddFOp>(loc, alpha1, beta2);
267+ auto alpha2MinusBeta1 = rewriter.create <arith::SubFOp>(loc, alpha2, beta1);
268+ auto term5 = rewriter.create <arith::NegFOp>(
269+ loc, rewriter.create <arith::MulFOp>(
270+ loc, c1TimesS2, rewriter.create <math::SinOp>(loc, alpha1PlusBeta2)));
271+ auto term6 = rewriter.create <arith::MulFOp>(
272+ loc, s1TimesC2, rewriter.create <math::SinOp>(loc, alpha2MinusBeta1));
273+ auto term7 = rewriter.create <arith::MulFOp>(
274+ loc, c1TimesS2, rewriter.create <math::CosOp>(loc, alpha1PlusBeta2));
275+ auto term8 = rewriter.create <arith::MulFOp>(
276+ loc, s1TimesC2, rewriter.create <math::CosOp>(loc, alpha2MinusBeta1));
277+ auto betaF = rewriter.create <arith::NegFOp>(
278+ loc, rewriter.create <math::AtanOp>(
279+ loc, rewriter.create <arith::DivFOp>(
280+ loc, rewriter.create <arith::AddFOp>(loc, term5, term6),
281+ rewriter.create <arith::AddFOp>(loc, term7, term8))));
282+
283+ // ϕF = αF + βF
284+ phiF = rewriter.create <arith::AddFOp>(loc, alphaF, betaF);
285+
286+ // ωF = αF - βF
287+ omegaF = rewriter.create <arith::SubFOp>(loc, alphaF, betaF);
288+ }
289+
290+ auto sumParams = SmallVector<mlir::Value>{phiF, thetaF, omegaF};
82291 auto mergeOp = rewriter.create <CustomOp>(loc, outQubitsTypes, outQubitsCtrlTypes, sumParams,
83- parentInQubits, opGateName , false ,
292+ parentInQubits, op. getGateName () , false ,
84293 parentInCtrlQubits, parentInCtrlValues);
85294
86295 rewriter.replaceOp (op, mergeOp);
87296 rewriter.eraseOp (parentOp);
88297
89298 return success ();
90299 }
300+
301+ LogicalResult matchAndRewrite (OpType op, PatternRewriter &rewriter) const override
302+ {
303+ LLVM_DEBUG (dbgs () << " Simplifying the following operation:\n " << op << " \n " );
304+
305+ StringRef opGateName = op.getGateName ();
306+ if (!fixedRotationsAndPhaseShiftsSet.contains (opGateName) &&
307+ !arbitraryRotationsSet.contains (opGateName)) {
308+ return failure ();
309+ }
310+
311+ VerifyHeterogeneousParentGateAndNameAnalysis<OpType, ParentOpType> vpga (op);
312+ if (!vpga.getVerifierResult ()) {
313+ return failure ();
314+ }
315+
316+ if (fixedRotationsAndPhaseShiftsSet.contains (opGateName)) {
317+ return matchAndRewriteFixedRotationOrPhaseShift (op, rewriter);
318+ }
319+ return matchAndRewriteArbitraryRotation (op, rewriter);
320+ }
91321};
92322
93- struct MergeMultiRZRewritePattern : public mlir :: OpRewritePattern<MultiRZOp> {
94- using mlir:: OpRewritePattern<MultiRZOp>::OpRewritePattern;
323+ struct MergeMultiRZRewritePattern : public OpRewritePattern <MultiRZOp> {
324+ using OpRewritePattern<MultiRZOp>::OpRewritePattern;
95325
96- mlir::LogicalResult matchAndRewrite (MultiRZOp op,
97- mlir::PatternRewriter &rewriter) const override
326+ LogicalResult matchAndRewrite (MultiRZOp op, PatternRewriter &rewriter) const override
98327 {
99328 LLVM_DEBUG (dbgs () << " Simplifying the following operation:\n " << op << " \n " );
100329 auto loc = op.getLoc ();
@@ -105,9 +334,7 @@ struct MergeMultiRZRewritePattern : public mlir::OpRewritePattern<MultiRZOp> {
105334 }
106335
107336 ValueRange inQubits = op.getInQubits ();
108- auto parentOp = dyn_cast_or_null<MultiRZOp>(inQubits[0 ].getDefiningOp ());
109- if (!parentOp)
110- return failure ();
337+ auto parentOp = llvm::cast<MultiRZOp>(inQubits[0 ].getDefiningOp ());
111338
112339 TypeRange outQubitsTypes = op.getOutQubits ().getTypes ();
113340 TypeRange outQubitsCtrlTypes = op.getOutCtrlQubits ().getTypes ();
0 commit comments