Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -306,13 +306,14 @@
gradient dialect and the `lower-gradients` compilation stage.
[(#2241)](https://github.com/PennyLaneAI/catalyst/pull/2241)

* Added support for PPRs to the :func:`~.passes.merge_rotations` pass to merge PPRs with
equivalent angles, and cancelling of PPRs with opposite angles, or angles
that sum to identity. Also supports conditions on PPRs, merging when conditions are
identical and not merging otherwise.
* Added support for PPRs and arbitrary angle PPRs to the :func:`~.passes.merge_rotations` pass.
This pass now merges PPRs with equivalent angles, and cancels PPRs with opposite angles, or
angles that sum to identity when the angles are known. The pass also supports conditions on PPRs,
merging when conditions are identical and not merging otherwise.
[(#2224)](https://github.com/PennyLaneAI/catalyst/pull/2224)
[(#2245)](https://github.com/PennyLaneAI/catalyst/pull/2245)
[(#2254)](https://github.com/PennyLaneAI/catalyst/pull/2254)
[(#2258)](https://github.com/PennyLaneAI/catalyst/pull/2258)

<h3>Documentation 📝</h3>

Expand Down
27 changes: 27 additions & 0 deletions frontend/test/pytest/test_peephole_optimizations.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,6 +518,33 @@ def circuit():
assert 'qec.ppr ["X", "Y", "Z"](2)' in ir_opt


@pytest.mark.usefixtures("use_capture")
def test_merge_rotation_arbitrary_angle_ppr():
"""Test that the merge_rotation pass correctly merges arbtirary angle PPRs."""

my_pipeline = [("pipe", ["quantum-compilation-stage"])]

@qml.qjit(pipelines=my_pipeline, target="mlir")
def test_merge_rotation_ppr_workflow():
@qml.transforms.merge_rotations # have to use qml to be capture-compatible
@qml.qnode(qml.device("lightning.qubit", wires=2))
def circuit(x, y):
qml.PauliRot(x, pauli_word="ZY", wires=[0, 1])
qml.PauliRot(y, pauli_word="ZY", wires=[0, 1])

return circuit(2.6, 0.3)

ir = test_merge_rotation_ppr_workflow.mlir
ir_opt = test_merge_rotation_ppr_workflow.mlir_opt

assert 'transform.apply_registered_pass "merge-rotations"' in ir
assert "qec.ppr.arbitrary" in ir
assert "arith.addf" not in ir

assert "arith.addf" in ir_opt
assert 'qec.ppr.arbitrary ["Z", "Y"]' in ir_opt


def test_clifford_to_ppm():

pipe = [("pipe", ["quantum-compilation-stage"])]
Expand Down
79 changes: 79 additions & 0 deletions mlir/lib/Quantum/Transforms/MergeRotationsPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,24 @@ struct MergeRotationsRewritePattern : public OpRewritePattern<OpType> {
}
};

bool matchingQubitsAndPaulis(PPRotationArbitraryOp op, PPRotationArbitraryOp parentOp)
{
// construct map
llvm::DenseMap<mlir::Value, Attribute> qubit_to_pauli;
for (auto [qubit, pauli] : llvm::zip(parentOp.getOutQubits(), parentOp.getPauliProduct())) {
qubit_to_pauli[qubit] = pauli;
}

// check pairings
for (auto [qubit, pauli] : llvm::zip(op.getInQubits(), op.getPauliProduct())) {
if (qubit_to_pauli[qubit] != pauli) {
return false;
}
}

return true;
}

struct MergePPRRewritePattern : public OpRewritePattern<PPRotationOp> {
using OpRewritePattern::OpRewritePattern;

Expand Down Expand Up @@ -396,6 +414,66 @@ struct MergePPRRewritePattern : public OpRewritePattern<PPRotationOp> {
}
};

struct MergePPRArbitraryRewritePattern : public OpRewritePattern<PPRotationArbitraryOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(PPRotationArbitraryOp op,
PatternRewriter &rewriter) const override
{
ValueRange inQubits = op.getInQubits();
auto definingOp = inQubits[0].getDefiningOp();

if (!definingOp) {
return failure();
}

auto prevOp = dyn_cast<PPRotationArbitraryOp>(definingOp);

if (!prevOp) {
return failure();
}

// verify that prevOp agrees on all qubits, not just the first
for (auto qubit : inQubits) {
if (qubit.getDefiningOp() != prevOp) {
return failure();
}
}

// check that the same pauli operators are applied to the same qubits
if (!matchingQubitsAndPaulis(op, prevOp)) {
return failure();
}

// check same conditionals
if (op.getCondition() != prevOp.getCondition()) {
return failure();
}

auto opRotation = op.getArbitraryAngle();
auto prevOpRotation = prevOp.getArbitraryAngle();

if (!opRotation || !prevOpRotation) {
return failure();
}

// create merged op
auto loc = op.getLoc();
mlir::Value newAngleOp =
rewriter.create<arith::AddFOp>(loc, opRotation, prevOpRotation).getResult();

auto mergeOp = rewriter.create<PPRotationArbitraryOp>(
loc, op.getOutQubits().getTypes(), op.getPauliProduct(), newAngleOp,
prevOp.getInQubits(), op.getCondition());

// replace and erase old ops
rewriter.replaceOp(op, mergeOp);
rewriter.eraseOp(prevOp);

return success();
}
};

struct MergeMultiRZRewritePattern : public OpRewritePattern<MultiRZOp> {
using OpRewritePattern<MultiRZOp>::OpRewritePattern;

Expand Down Expand Up @@ -442,6 +520,7 @@ void populateMergeRotationsPatterns(RewritePatternSet &patterns)
patterns.add<MergeRotationsRewritePattern<CustomOp, CustomOp>>(patterns.getContext(), 1);
patterns.add<MergeMultiRZRewritePattern>(patterns.getContext(), 1);
patterns.add<MergePPRRewritePattern>(patterns.getContext(), 1);
patterns.add<MergePPRArbitraryRewritePattern>(patterns.getContext(), 1);
}

} // namespace quantum
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Quantum/Transforms/merge_rotation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ struct MergeRotationsPass : impl::MergeRotationsPassBase<MergeRotationsPass> {
&getContext());
catalyst::qec::PPRotationOp::getCanonicalizationPatterns(patternsCanonicalization,
&getContext());
catalyst::qec::PPRotationArbitraryOp::getCanonicalizationPatterns(patternsCanonicalization,
&getContext());
if (failed(applyPatternsGreedily(module, std::move(patternsCanonicalization)))) {
return signalPassFailure();
}
Expand Down
192 changes: 192 additions & 0 deletions mlir/test/Quantum/MergeRotationsTest.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -941,3 +941,195 @@ func.func public @dont_merge_conditionals(%q1: !quantum.bit, %q2: !quantum.bit,
func.return
}

// -----

// Arbitrary Angle PPR Tests

// simple merge

// CHECK-LABEL: merge_Y
func.func public @merge_Y(%q0: !quantum.bit, %0: f64, %1: f64) {
// CHECK: [[angle:%.+]] = arith.addf
// CHECK: qec.ppr.arbitrary ["Y"]([[angle]])
%2 = qec.ppr.arbitrary ["Y"](%0) %q0: !quantum.bit
%3 = qec.ppr.arbitrary ["Y"](%1) %2: !quantum.bit
func.return
}

// -----

// multiple merges

// CHECK-LABEL: merge_multi_Z
func.func public @merge_multi_Z(%q0: !quantum.bit, %0: f64, %1: f64, %2: f64) {
// CHECK: [[angle:%.+]] = arith.addf
// CHECK: [[angle2:%.+]] = arith.addf
// CHECK: qec.ppr.arbitrary ["Z"]([[angle2]])
// CHECK-NOT: qec.ppr.arbitrary
%3 = qec.ppr.arbitrary ["Z"](%0) %q0: !quantum.bit
%4 = qec.ppr.arbitrary ["Z"](%1) %3: !quantum.bit
%5 = qec.ppr.arbitrary ["Z"](%2) %4: !quantum.bit
func.return
}

// -----

// not merging when incompatible

// CHECK-LABEL: dont_merge
func.func public @dont_merge(%q0: !quantum.bit, %q1: !quantum.bit, %0: f64, %1: f64, %2: f64, %3: f64, %4: f64, %5: f64) {
// CHECK-NOT: arith.addf
// CHECK: qec.ppr.arbitrary ["Z", "X"]
// CHECK: qec.ppr.arbitrary ["Y", "X"]
// CHECK: qec.ppr.arbitrary ["Y", "Z"]
// CHECK: qec.ppr.arbitrary ["X", "Z"]
// CHECK: qec.ppr.arbitrary ["X", "Y"]
// CHECK: qec.ppr.arbitrary ["Z", "Y"]
%6:2 = qec.ppr.arbitrary ["Z", "X"](%0) %q0, %q1: !quantum.bit, !quantum.bit
%7:2 = qec.ppr.arbitrary ["Y", "X"](%1) %6#0, %6#1: !quantum.bit, !quantum.bit
%8:2 = qec.ppr.arbitrary ["Y", "Z"](%2) %7#0, %7#1: !quantum.bit, !quantum.bit
%9:2 = qec.ppr.arbitrary ["X", "Z"](%3) %8#0, %8#1: !quantum.bit, !quantum.bit
%10:2 = qec.ppr.arbitrary ["X", "Y"](%4) %9#0, %9#1: !quantum.bit, !quantum.bit
%11:2 = qec.ppr.arbitrary ["Z", "Y"](%5) %10#0, %10#1: !quantum.bit, !quantum.bit
func.return
}

// -----

// updating references

// CHECK-LABEL: merge_correct_references
func.func public @merge_correct_references(%q0: !quantum.bit, %0: f64, %1: f64, %2: f64, %3: f64) {
// CHECK-DAG: [[angle:%.+]] = arith.addf
// CHECK-DAG: [[in:%.+]] = qec.ppr.arbitrary ["X"]
// CHECK: [[out:%.+]] = qec.ppr.arbitrary ["Z"]([[angle]]) [[in]]
// CHECK: qec.ppr.arbitrary ["Y"]({{%.+}}) [[out]]
%4 = qec.ppr.arbitrary ["X"](%0) %q0: !quantum.bit
%5 = qec.ppr.arbitrary ["Z"](%1) %4: !quantum.bit
%6 = qec.ppr.arbitrary ["Z"](%2) %5: !quantum.bit
%7 = qec.ppr.arbitrary ["Y"](%3) %6: !quantum.bit
func.return
}

// -----

// multi-qubit merge

// CHECK-LABEL: merge_multi_XZY
func.func public @merge_multi_XZY(%q0: !quantum.bit, %q1: !quantum.bit, %q2: !quantum.bit, %0: f64, %1: f64, %2: f64) {
// CHECK: [[angle1:%.+]] = arith.addf
// CHECK: [[angle2:%.+]] = arith.addf
// CHECK: qec.ppr.arbitrary ["X", "Z", "Y"]([[angle2]])
%3:3 = qec.ppr.arbitrary ["X", "Z", "Y"](%0) %q0, %q1, %q2: !quantum.bit, !quantum.bit, !quantum.bit
%4:3 = qec.ppr.arbitrary ["X", "Z", "Y"](%1) %3#0, %3#1, %3#2: !quantum.bit, !quantum.bit, !quantum.bit
%5:3 = qec.ppr.arbitrary ["X", "Z", "Y"](%2) %4#0, %4#1, %4#2: !quantum.bit, !quantum.bit, !quantum.bit
func.return
}

// -----

// merge through other ops

// CHECK-LABEL: merge_through
func.func public @merge_through(%q0: !quantum.bit, %q1: !quantum.bit, %0: f64, %1: f64) -> !quantum.bit {
// CHECK-DAG: [[angle:%.+]] = arith.addf
// CHECK-DAG: quantum.custom
// CHECK-DAG: qec.ppr.arbitrary ["X"]([[angle]])
%2 = qec.ppr.arbitrary ["X"](%0) %q0: !quantum.bit
%3 = quantum.custom "Hadamard"() %q1: !quantum.bit
%4 = qec.ppr.arbitrary ["X"](%1) %2: !quantum.bit
func.return %3: !quantum.bit
}

// -----

// don't merge through other operations

// CHECK-LABEL: mixed_operations
func.func public @mixed_operations(%q0: !quantum.bit, %q1: !quantum.bit, %0: f64, %1: f64) {
// CHECK-NOT: arith.addf
// CHECK: qec.ppr.arbitrary ["Z", "X"]
// CHECK: quantum.custom
// CHECK: qec.ppr.arbitrary ["Z", "X"]
%2:2 = qec.ppr.arbitrary ["Z", "X"](%0) %q0, %q1: !quantum.bit, !quantum.bit
%3 = quantum.custom "Hadamard"() %2#1: !quantum.bit
%5:2 = qec.ppr.arbitrary ["Z", "X"](%1) %2#0, %3: !quantum.bit, !quantum.bit
func.return
}

// -----

// don't merge if only one qubit matches

// CHECK-LABEL: half_compatible_qubits
func.func public @half_compatible_qubits(%q0: !quantum.bit, %q1: !quantum.bit, %q2: !quantum.bit, %0: f64, %1: f64) {
// CHECK: qec.ppr.arbitrary ["X", "Z"]
%2:2 = qec.ppr.arbitrary ["X", "Z"](%0) %q0, %q1: !quantum.bit, !quantum.bit
%3:2 = qec.ppr.arbitrary ["X", "Z"](%1) %q2, %2#1 : !quantum.bit, !quantum.bit
func.return
}

// -----

// re-arranging qubits is ok as long as the pauli words are re-arranged too

// CHECK-LABEL: mix_and_match
func.func public @mix_and_match(%q0: !quantum.bit, %q1: !quantum.bit, %0: f64, %1: f64) {
// CHECK: [[angle:%.+]] = arith.addf
// CHECK: qec.ppr.arbitrary [{{.+}}]([[angle]])
%2:2 = qec.ppr.arbitrary ["Z", "Y"](%0) %q0, %q1: !quantum.bit, !quantum.bit
%3:2 = qec.ppr.arbitrary ["Y", "Z"](%1) %2#1, %2#0: !quantum.bit, !quantum.bit
func.return
}

// -----

// re-arranging qubits without re-arranging the pauli word is NOT okay

// CHECK-LABEL: mix_dont_match
func.func public @mix_dont_match(%q0: !quantum.bit, %q1: !quantum.bit, %0: f64, %1: f64) {
// CHECK: qec.ppr.arbitrary ["Y", "X"]
// CHECK: qec.ppr.arbitrary ["Y", "X"]
%2:2 = qec.ppr.arbitrary ["Y", "X"](%0) %q0, %q1: !quantum.bit, !quantum.bit
%3:2 = qec.ppr.arbitrary ["Y", "X"](%1) %2#1, %2#0: !quantum.bit, !quantum.bit
func.return
}

// -----

// check equivalent conditions are merged

// CHECK-LABEL: merge_condition
func.func public @merge_condition(%q0: !quantum.bit, %0: f64, %1: f64, %b0: i1) {
// CHECK: [[angle:%.+]] = arith.addf
// CHECK: qec.ppr.arbitrary ["X"]([[angle]]) {{%.+}} cond({{%.+}})
%2 = qec.ppr.arbitrary ["X"](%0) %q0 cond(%b0): !quantum.bit
%3 = qec.ppr.arbitrary ["X"](%1) %2 cond(%b0): !quantum.bit
func.return
}

// -----

// dont merge different conditions

// CHECK-LABEL: dont_merge_condition
func.func public @dont_merge_condition(%q0: !quantum.bit, %0: f64, %1: f64, %b0: i1, %b1: i1) {
// CHECK: [[in:%.+]] = qec.ppr.arbitrary ["X"]({{%.+}}) {{%.+}} cond({{%.+}})
// CHECK: qec.ppr.arbitrary ["X"]({{%.+}}) [[in]] cond({{%.+}})
%2 = qec.ppr.arbitrary ["X"](%0) %q0 cond(%b0): !quantum.bit
%3 = qec.ppr.arbitrary ["X"](%1) %2 cond(%b1): !quantum.bit
func.return
}

// -----

// don't merge conditions and non-conditions

// CHECK-LABEL: dont_merge_mixed
func.func public @dont_merge_mixed(%q0: !quantum.bit, %0: f64, %1: f64, %b0: i1) {
// CHECK: [[in:%.+]] = qec.ppr.arbitrary ["X"]({{%.+}}) {{%.+}} cond({{%.+}})
// CHECK: qec.ppr.arbitrary ["X"]({{%.+}}) [[in]]
%2 = qec.ppr.arbitrary ["X"](%0) %q0 cond(%b0): !quantum.bit
%3 = qec.ppr.arbitrary ["X"](%1) %2: !quantum.bit
func.return
}