diff --git a/src/bloqade/qasm2/rewrite/native_gates.py b/src/bloqade/qasm2/rewrite/native_gates.py index 7d4840c4..2fb9ee6a 100644 --- a/src/bloqade/qasm2/rewrite/native_gates.py +++ b/src/bloqade/qasm2/rewrite/native_gates.py @@ -279,7 +279,7 @@ def rewrite_cu3(self, node: uop.CU3) -> result.RewriteResult: lam = self._get_const_value(node.lam) phi = self._get_const_value(node.phi) - if not all((theta, phi, lam)): + if theta is None or lam is None or phi is None: return result.RewriteResult() # cirq.ControlledGate(u3(theta, lambda phi)) diff --git a/test/qasm2/test_native.py b/test/qasm2/test_native.py index f8e32101..6f613218 100644 --- a/test/qasm2/test_native.py +++ b/test/qasm2/test_native.py @@ -2,7 +2,6 @@ import textwrap import cirq -import cirq.contrib import cirq.testing import cirq.contrib.qasm_import as qasm_import import cirq.circuits.qasm_output as qasm_output @@ -86,6 +85,18 @@ def generator(n_tests: int): """ ) + yield textwrap.dedent( + """ + OPENQASM 2.0; + include "qelib1.inc"; + + qreg q[2]; + + cu3(0.0, 0.6, 3.141591) q[0],q[1]; + + """ + ) + rgen = np.random.RandomState(128) for num in range(n_tests): # Generate a new instance: @@ -117,3 +128,28 @@ def kernel(): cirq.testing.assert_allclose_up_to_global_phase( cirq.unitary(old_circuit), cirq.unitary(cirq_circuit), atol=1e-8 ) + + +def test_cu3_rewrite(): + prog = textwrap.dedent( + """ + OPENQASM 2.0; + include "qelib1.inc"; + + qreg q[2]; + + cu3(0.0, 0.6, 3.141591) q[0],q[1]; + + """ + ) + + @qasm2.main.add(qasm2.dialects.inline) + def kernel(): + qasm2.inline(prog) + + walk.Walk(RydbergGateSetRewriteRule(kernel.dialects)).rewrite(kernel.code) + + new_qasm2 = qasm2.emit.QASM2().emit_str(kernel) + + # simple-stupid test to see if the rewrite injected a bunch of new lines + assert new_qasm2.count("\n") > prog.count("\n")