Skip to content
Open
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -319,13 +319,14 @@
gradient dialect and the `lower-gradients` compilation stage.
[(#2241)](https://github.com/PennyLaneAI/catalyst/pull/2241)

* Added support for PPRs to the :func:`~.passes.merge_rotations` pass to merge PPRs with
equivalent angles, and cancelling of PPRs with opposite angles, or angles
that sum to identity. Also supports conditions on PPRs, merging when conditions are
identical and not merging otherwise.
[(#2224)](https://github.com/PennyLaneAI/catalyst/pull/2224)
* Added support for PPRs and arbitrary angle PPRs to the :func:`~.passes.merge_rotations` pass.
This pass now merges PPRs with equivalent angles, and cancels PPRs with opposite angles, or
angles that sum to identity when the angles are known. The pass also supports conditions on PPRs,
merging when conditions are identical and not merging otherwise.
[(#2224)](https://github.com/PennyLaneAI/catalyst/pull/2224)
[(#2245)](https://github.com/PennyLaneAI/catalyst/pull/2245)
[(#2254)](https://github.com/PennyLaneAI/catalyst/pull/2254)
[(#2258)](https://github.com/PennyLaneAI/catalyst/pull/2258)


* Refactor QEC tablegen files to separate QEC operations into a new `QECOp.td` file
Expand Down
27 changes: 27 additions & 0 deletions frontend/test/pytest/test_peephole_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,33 @@ def circuit():
assert 'qec.ppr ["X", "Y", "Z"](2)' in ir_opt


@pytest.mark.usefixtures("use_capture")
def test_merge_rotation_arbitrary_angle_ppr():
"""Test that the merge_rotation pass correctly merges arbtirary angle PPRs."""

my_pipeline = [("pipe", ["quantum-compilation-stage"])]

@qml.qjit(pipelines=my_pipeline, target="mlir")
def test_merge_rotation_ppr_workflow():
@qml.transforms.merge_rotations # have to use qml to be capture-compatible
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x, y):
qml.PauliRot(x, pauli_word="ZY", wires=[0, 1])
qml.PauliRot(y, pauli_word="ZY", wires=[0, 1])

return circuit(2.6, 0.3)

ir = test_merge_rotation_ppr_workflow.mlir
ir_opt = test_merge_rotation_ppr_workflow.mlir_opt

assert 'transform.apply_registered_pass "merge-rotations"' in ir
assert "qec.ppr.arbitrary" in ir
assert "arith.addf" not in ir

assert "arith.addf" in ir_opt
assert 'qec.ppr.arbitrary ["Z", "Y"]' in ir_opt


def test_clifford_to_ppm():

pipe = [("pipe", ["quantum-compilation-stage"])]
Expand Down
81 changes: 81 additions & 0 deletions mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,86 @@ struct MergePPRRewritePattern : public OpRewritePattern<PPRotationOp> {
}
};

struct MergePPRArbitraryRewritePattern : public OpRewritePattern<PPRotationArbitraryOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(PPRotationArbitraryOp op,
PatternRewriter &rewriter) const override
{
ValueRange opInQubits = op.getInQubits();

Operation *definingOp = opInQubits[0].getDefiningOp();
if (!definingOp) {
return failure();
}

auto parentOp = dyn_cast<PPRotationArbitraryOp>(definingOp);
if (!parentOp) {
return failure();
}

// verify that parentOp is parent of all qubits
for (mlir::Value qubit : opInQubits) {
if (qubit.getDefiningOp() != parentOp) {
return failure();
}
}
ValueRange parentOpOutQubits = parentOp.getOutQubits();
if (parentOpOutQubits.size() != opInQubits.size()) {
return failure();
}

// When two rotations have permuted Pauli strings, we can still merge them, we just need to
// correctly re-map the inputs. This map stores the index of a qubit in parentOp's out
// qubits at the index it appears in op's in qubits.
SmallVector<unsigned> inverse_permutation;
for (auto qubit : opInQubits) {
inverse_permutation.push_back(cast<mlir::OpResult>(qubit).getResultNumber());
}

// check Pauli + qubit pairings
mlir::ArrayAttr opPauliProduct = op.getPauliProduct();
mlir::ArrayAttr parentOpPauliProduct = parentOp.getPauliProduct();
for (size_t i = 0; i < opInQubits.size(); i++) {
if (opPauliProduct[i] != parentOpPauliProduct[inverse_permutation[i]]) {
return failure();
}
}

// check same conditionals
mlir::Value opCondition = op.getCondition();
if (opCondition != parentOp.getCondition()) {
return failure();
}

mlir::Location loc = op.getLoc();

mlir::Value opRotation = op.getArbitraryAngle();
mlir::Value parentOpRotation = parentOp.getArbitraryAngle();
mlir::Value newAngleOp =
rewriter.create<arith::AddFOp>(loc, opRotation, parentOpRotation).getResult();

// We need to construct the Pauli string + inQubits for new op. The simplest way to ensure
// that permuted PPRs can merge correctly is to maintain output qubits order and permute
// input qubits
mlir::ValueRange parentOpInQubits = parentOp.getInQubits();
SmallVector<mlir::Value> newInQubits;
for (size_t i = 0; i < parentOpInQubits.size(); i++) {
newInQubits.push_back(parentOpInQubits[inverse_permutation[i]]);
}

auto mergeOp = rewriter.create<PPRotationArbitraryOp>(loc, parentOpOutQubits.getTypes(),
opPauliProduct, newAngleOp,
newInQubits, opCondition);

// replace and erase old ops
rewriter.replaceOp(op, mergeOp);
rewriter.eraseOp(parentOp);

return success();
}
};

struct MergeMultiRZRewritePattern : public OpRewritePattern<MultiRZOp> {
using OpRewritePattern<MultiRZOp>::OpRewritePattern;

Expand Down Expand Up @@ -442,6 +522,7 @@ void populateMergeRotationsPatterns(RewritePatternSet &patterns)
patterns.add<MergeRotationsRewritePattern<CustomOp, CustomOp>>(patterns.getContext(), 1);
patterns.add<MergeMultiRZRewritePattern>(patterns.getContext(), 1);
patterns.add<MergePPRRewritePattern>(patterns.getContext(), 1);
patterns.add<MergePPRArbitraryRewritePattern>(patterns.getContext(), 1);
}

} // namespace quantum
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Quantum/Transforms/merge_rotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ struct MergeRotationsPass : impl::MergeRotationsPassBase<MergeRotationsPass> {
&getContext());
catalyst::qec::PPRotationOp::getCanonicalizationPatterns(patternsCanonicalization,
&getContext());
catalyst::qec::PPRotationArbitraryOp::getCanonicalizationPatterns(patternsCanonicalization,
&getContext());
if (failed(applyPatternsGreedily(module, std::move(patternsCanonicalization)))) {
return signalPassFailure();
}
Expand Down
218 changes: 218 additions & 0 deletions mlir/test/Quantum/MergeRotationsTest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -941,3 +941,221 @@ func.func public @dont_merge_conditionals(%q1: !quantum.bit, %q2: !quantum.bit,
func.return
}

// -----

// Arbitrary Angle PPR Tests

// simple merge

// CHECK-LABEL: merge_Y
func.func public @merge_Y(%q0: !quantum.bit, %0: f64, %1: f64) {
// CHECK: [[angle:%.+]] = arith.addf
// CHECK: qec.ppr.arbitrary ["Y"]([[angle]])
%2 = qec.ppr.arbitrary ["Y"](%0) %q0: !quantum.bit
%3 = qec.ppr.arbitrary ["Y"](%1) %2: !quantum.bit
func.return
}

// -----

// multiple merges

// CHECK-LABEL: merge_multi_Z
func.func public @merge_multi_Z(%q0: !quantum.bit, %0: f64, %1: f64, %2: f64) {
// CHECK: [[angle:%.+]] = arith.addf
// CHECK: [[angle2:%.+]] = arith.addf
// CHECK: qec.ppr.arbitrary ["Z"]([[angle2]])
// CHECK-NOT: qec.ppr.arbitrary
%3 = qec.ppr.arbitrary ["Z"](%0) %q0: !quantum.bit
%4 = qec.ppr.arbitrary ["Z"](%1) %3: !quantum.bit
%5 = qec.ppr.arbitrary ["Z"](%2) %4: !quantum.bit
func.return
}

// -----

// not merging when incompatible

// CHECK-LABEL: dont_merge
func.func public @dont_merge(%q0: !quantum.bit, %q1: !quantum.bit, %0: f64, %1: f64, %2: f64, %3: f64, %4: f64, %5: f64) {
// CHECK-NOT: arith.addf
// CHECK: qec.ppr.arbitrary ["Z", "X"]
// CHECK: qec.ppr.arbitrary ["Y", "X"]
// CHECK: qec.ppr.arbitrary ["Y", "Z"]
// CHECK: qec.ppr.arbitrary ["X", "Z"]
// CHECK: qec.ppr.arbitrary ["X", "Y"]
// CHECK: qec.ppr.arbitrary ["Z", "Y"]
%6:2 = qec.ppr.arbitrary ["Z", "X"](%0) %q0, %q1: !quantum.bit, !quantum.bit
%7:2 = qec.ppr.arbitrary ["Y", "X"](%1) %6#0, %6#1: !quantum.bit, !quantum.bit
%8:2 = qec.ppr.arbitrary ["Y", "Z"](%2) %7#0, %7#1: !quantum.bit, !quantum.bit
%9:2 = qec.ppr.arbitrary ["X", "Z"](%3) %8#0, %8#1: !quantum.bit, !quantum.bit
%10:2 = qec.ppr.arbitrary ["X", "Y"](%4) %9#0, %9#1: !quantum.bit, !quantum.bit
%11:2 = qec.ppr.arbitrary ["Z", "Y"](%5) %10#0, %10#1: !quantum.bit, !quantum.bit
func.return
}

// -----

// updating references

// CHECK-LABEL: merge_correct_references
func.func public @merge_correct_references(%q0: !quantum.bit, %0: f64, %1: f64, %2: f64, %3: f64) {
// CHECK-DAG: [[angle:%.+]] = arith.addf
// CHECK-DAG: [[in:%.+]] = qec.ppr.arbitrary ["X"]
// CHECK: [[out:%.+]] = qec.ppr.arbitrary ["Z"]([[angle]]) [[in]]
// CHECK: qec.ppr.arbitrary ["Y"]({{%.+}}) [[out]]
%4 = qec.ppr.arbitrary ["X"](%0) %q0: !quantum.bit
%5 = qec.ppr.arbitrary ["Z"](%1) %4: !quantum.bit
%6 = qec.ppr.arbitrary ["Z"](%2) %5: !quantum.bit
%7 = qec.ppr.arbitrary ["Y"](%3) %6: !quantum.bit
func.return
}

// -----

// multi-qubit merge

// CHECK-LABEL: merge_multi_XZY
func.func public @merge_multi_XZY(%q0: !quantum.bit, %q1: !quantum.bit, %q2: !quantum.bit, %0: f64, %1: f64, %2: f64) {
// CHECK: [[angle1:%.+]] = arith.addf
// CHECK: [[angle2:%.+]] = arith.addf
// CHECK: qec.ppr.arbitrary ["X", "Z", "Y"]([[angle2]])
%3:3 = qec.ppr.arbitrary ["X", "Z", "Y"](%0) %q0, %q1, %q2: !quantum.bit, !quantum.bit, !quantum.bit
%4:3 = qec.ppr.arbitrary ["X", "Z", "Y"](%1) %3#0, %3#1, %3#2: !quantum.bit, !quantum.bit, !quantum.bit
%5:3 = qec.ppr.arbitrary ["X", "Z", "Y"](%2) %4#0, %4#1, %4#2: !quantum.bit, !quantum.bit, !quantum.bit
func.return
}

// -----

// merge through other ops

// CHECK-LABEL: merge_through
func.func public @merge_through(%q0: !quantum.bit, %q1: !quantum.bit, %0: f64, %1: f64) -> !quantum.bit {
// CHECK-DAG: [[angle:%.+]] = arith.addf
// CHECK-DAG: quantum.custom
// CHECK-DAG: qec.ppr.arbitrary ["X"]([[angle]])
%2 = qec.ppr.arbitrary ["X"](%0) %q0: !quantum.bit
%3 = quantum.custom "Hadamard"() %q1: !quantum.bit
%4 = qec.ppr.arbitrary ["X"](%1) %2: !quantum.bit
func.return %3: !quantum.bit
}

// -----

// don't merge through other operations

// CHECK-LABEL: mixed_operations
func.func public @mixed_operations(%q0: !quantum.bit, %q1: !quantum.bit, %0: f64, %1: f64) {
// CHECK-NOT: arith.addf
// CHECK: qec.ppr.arbitrary ["Z", "X"]
// CHECK: quantum.custom
// CHECK: qec.ppr.arbitrary ["Z", "X"]
%2:2 = qec.ppr.arbitrary ["Z", "X"](%0) %q0, %q1: !quantum.bit, !quantum.bit
%3 = quantum.custom "Hadamard"() %2#1: !quantum.bit
%5:2 = qec.ppr.arbitrary ["Z", "X"](%1) %2#0, %3: !quantum.bit, !quantum.bit
func.return
}

// -----

// don't merge if only one qubit matches

// CHECK-LABEL: half_compatible_qubits
func.func public @half_compatible_qubits(%q0: !quantum.bit, %q1: !quantum.bit, %q2: !quantum.bit, %0: f64, %1: f64) {
// CHECK: qec.ppr.arbitrary ["X", "Z"]
%2:2 = qec.ppr.arbitrary ["X", "Z"](%0) %q0, %q1: !quantum.bit, !quantum.bit
%3:2 = qec.ppr.arbitrary ["X", "Z"](%1) %q2, %2#1 : !quantum.bit, !quantum.bit
func.return
}

// -----

// re-arranging qubits is ok as long as the pauli words are re-arranged too

// CHECK-LABEL: merge_permutations
func.func public @merge_permutations(%z0: !quantum.bit, %y0: !quantum.bit, %0: f64, %1: f64, %2: f64, %3: f64) {
// CHECK-DAG: [[angle:%.+]] = arith.addf
// CHECK: qec.ppr.arbitrary ["Y", "Z"]([[angle]]) %arg1, %arg0
%z1, %y1 = qec.ppr.arbitrary ["Z", "Y"](%1) %z0, %y0: !quantum.bit, !quantum.bit
%y2, %z2 = qec.ppr.arbitrary ["Y", "Z"](%2) %y1, %z1: !quantum.bit, !quantum.bit
func.return
}

// -----

// check permutations with duplicate Pauli symbols

// CHECK-LABEL: merge_permutations_with_duplicates
func.func public @merge_permutations_with_duplicates(%q0: !quantum.bit, %q1: !quantum.bit, %q2: !quantum.bit, %0: f64, %1: f64, %2:f64) {
// CHECK: [[angle:%.+]] = arith.addf
// CHECK: qec.ppr.arbitrary ["Y", "X", "X"]([[angle]]) %arg1, %arg2, %arg0
%3:3 = qec.ppr.arbitrary ["X", "Y", "X"](%0) %q0, %q1, %q2: !quantum.bit, !quantum.bit, !quantum.bit
%4:3 = qec.ppr.arbitrary ["Y", "X", "X"](%1) %3#1, %3#2, %3#0: !quantum.bit, !quantum.bit, !quantum.bit
func.return
}

// -----

// re-arranging qubits without re-arranging the Pauli word is NOT okay

// CHECK-LABEL: dont_merge_permutations_qubits
func.func public @dont_merge_permutations_qubits(%q0: !quantum.bit, %q1: !quantum.bit, %0: f64, %1: f64) {
// CHECK: qec.ppr.arbitrary ["Y", "X"]
// CHECK: qec.ppr.arbitrary ["Y", "X"]
%2:2 = qec.ppr.arbitrary ["Y", "X"](%0) %q0, %q1: !quantum.bit, !quantum.bit
%3:2 = qec.ppr.arbitrary ["Y", "X"](%1) %2#1, %2#0: !quantum.bit, !quantum.bit
func.return
}

// -----

// re-arranging Pauli word without re-arranging qubits is not okay

// CHECK-LABEL: dont_merge_permutations_pauli
func.func public @dont_merge_permutations_pauli(%q0: !quantum.bit, %q1: !quantum.bit, %0: f64, %1: f64) {
// CHECK: qec.ppr.arbitrary ["Z", "Y"]
// CHECK: qec.ppr.arbitrary ["Y", "Z"]
%2:2 = qec.ppr.arbitrary ["Z", "Y"](%0) %q0, %q1: !quantum.bit, !quantum.bit
%3:2 = qec.ppr.arbitrary ["Y", "Z"](%1) %2#0, %2#1: !quantum.bit, !quantum.bit
return
}

// -----

// check equivalent conditions are merged

// CHECK-LABEL: merge_condition
func.func public @merge_condition(%q0: !quantum.bit, %0: f64, %1: f64, %b0: i1) {
// CHECK: [[angle:%.+]] = arith.addf
// CHECK: qec.ppr.arbitrary ["X"]([[angle]]) {{%.+}} cond({{%.+}})
%2 = qec.ppr.arbitrary ["X"](%0) %q0 cond(%b0): !quantum.bit
%3 = qec.ppr.arbitrary ["X"](%1) %2 cond(%b0): !quantum.bit
func.return
}

// -----

// dont merge different conditions

// CHECK-LABEL: dont_merge_condition
func.func public @dont_merge_condition(%q0: !quantum.bit, %0: f64, %1: f64, %b0: i1, %b1: i1) {
// CHECK: [[in:%.+]] = qec.ppr.arbitrary ["X"]({{%.+}}) {{%.+}} cond({{%.+}})
// CHECK: qec.ppr.arbitrary ["X"]({{%.+}}) [[in]] cond({{%.+}})
%2 = qec.ppr.arbitrary ["X"](%0) %q0 cond(%b0): !quantum.bit
%3 = qec.ppr.arbitrary ["X"](%1) %2 cond(%b1): !quantum.bit
func.return
}

// -----

// don't merge conditions and non-conditions

// CHECK-LABEL: dont_merge_mixed
func.func public @dont_merge_mixed(%q0: !quantum.bit, %0: f64, %1: f64, %b0: i1) {
// CHECK: [[in:%.+]] = qec.ppr.arbitrary ["X"]({{%.+}}) {{%.+}} cond({{%.+}})
// CHECK: qec.ppr.arbitrary ["X"]({{%.+}}) [[in]]
%2 = qec.ppr.arbitrary ["X"](%0) %q0 cond(%b0): !quantum.bit
%3 = qec.ppr.arbitrary ["X"](%1) %2: !quantum.bit
func.return
}