From 46010b9c065c0fcfe7b6759785a4d4674f81bd0a Mon Sep 17 00:00:00 2001 From: Jake Zaia Date: Mon, 1 Dec 2025 20:23:39 +0000 Subject: [PATCH 1/8] Update tests to match new specs output types --- frontend/test/pytest/test_specs.py | 27 ++++++++++++++------------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/frontend/test/pytest/test_specs.py b/frontend/test/pytest/test_specs.py index 61abdcf3da..eb811cba38 100644 --- a/frontend/test/pytest/test_specs.py +++ b/frontend/test/pytest/test_specs.py @@ -11,7 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Test for qml.specs() Catalyst integration""" +"""Tests for qml.specs() Catalyst integration""" + import pennylane as qml import pytest from jax import numpy as jnp @@ -21,22 +22,22 @@ # pylint:disable = protected-access,attribute-defined-outside-init +# TODO: Remove this method once feature pairty has been reached, and instead use `==` directly def check_specs_same(specs1, specs2): """Check that two specs dictionaries are the same.""" assert specs1["device_name"] == specs2["device_name"] - assert specs1["resources"].num_wires == specs2["resources"].num_wires - assert specs1["resources"].num_gates == specs2["resources"].num_gates - assert specs1["resources"].depth == specs2["resources"].depth + assert specs1["num_device_wires"] == specs2["num_device_wires"] + assert specs1["shots"] == specs2["shots"] - assert len(specs1["resources"].gate_types) == len(specs2["resources"].gate_types) - for gate, count in specs1["resources"].gate_types.items(): - assert gate in specs2["resources"].gate_types - assert count == specs2["resources"].gate_types[gate] + assert specs1["resources"].gate_types == specs2["resources"].gate_types + assert specs1["resources"].gate_sizes == specs2["resources"].gate_sizes - assert len(specs1["resources"].gate_sizes) == len(specs2["resources"].gate_sizes) - for gate, count in specs1["resources"].gate_sizes.items(): - assert gate in specs2["resources"].gate_sizes - assert count == specs2["resources"].gate_sizes[gate] + # Measurements are not yet supported in Catalyst device-level specs + # assert specs1["resources"].measurements == specs2["resources"].measurements + + assert specs1["resources"].num_allocs == specs2["resources"].num_allocs + assert specs1["resources"].depth == specs2["resources"].depth + assert specs1["resources"].num_gates == specs2["resources"].num_gates @pytest.mark.parametrize("level", ["device"]) @@ -92,7 +93,7 @@ def circuit(): # Catalyst will handle Adjoint(PauliY) == PauliY assert "CY" in cat_specs["resources"].gate_types - cat_specs["resources"].gate_types["C(Adjoint(PauliY))"] += cat_specs["resources"].gate_types[ + cat_specs["resources"].gate_types["C(Adjoint(PauliY))"] = cat_specs["resources"].gate_types[ "CY" ] del cat_specs["resources"].gate_types["CY"] From a8bfd06e9d969db2e7661e6b7394923632c89fcd Mon Sep 17 00:00:00 2001 From: Jake Zaia Date: Wed, 3 Dec 2025 21:20:27 +0000 Subject: [PATCH 2/8] Slight change to accomodate new control qubit counting --- frontend/test/pytest/test_specs.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/frontend/test/pytest/test_specs.py b/frontend/test/pytest/test_specs.py index eb811cba38..5b81b2084b 100644 --- a/frontend/test/pytest/test_specs.py +++ b/frontend/test/pytest/test_specs.py @@ -86,11 +86,6 @@ def circuit(): assert cat_specs["device_name"] == "lightning.qubit" - # Catalyst level specs should report the number of controls for multi-controlled gates - assert "2C(S)" in cat_specs["resources"].gate_types - cat_specs["resources"].gate_types["C(S)"] += cat_specs["resources"].gate_types["2C(S)"] - del cat_specs["resources"].gate_types["2C(S)"] - # Catalyst will handle Adjoint(PauliY) == PauliY assert "CY" in cat_specs["resources"].gate_types cat_specs["resources"].gate_types["C(Adjoint(PauliY))"] = cat_specs["resources"].gate_types[ From 9c6c193dd1984e51aef9d8e6de49ca86c7be025c Mon Sep 17 00:00:00 2001 From: Jake Zaia Date: Thu, 4 Dec 2025 16:46:35 +0000 Subject: [PATCH 3/8] Testing --- frontend/test/pytest/test_specs.py | 112 ++++++++++++++++------------- 1 file changed, 64 insertions(+), 48 deletions(-) diff --git a/frontend/test/pytest/test_specs.py b/frontend/test/pytest/test_specs.py index 5b81b2084b..93d6555b0a 100644 --- a/frontend/test/pytest/test_specs.py +++ b/frontend/test/pytest/test_specs.py @@ -23,78 +23,94 @@ # TODO: Remove this method once feature pairty has been reached, and instead use `==` directly -def check_specs_same(specs1, specs2): +def check_specs_same(specs1, specs2, skip_measurements=False): """Check that two specs dictionaries are the same.""" assert specs1["device_name"] == specs2["device_name"] assert specs1["num_device_wires"] == specs2["num_device_wires"] assert specs1["shots"] == specs2["shots"] - assert specs1["resources"].gate_types == specs2["resources"].gate_types - assert specs1["resources"].gate_sizes == specs2["resources"].gate_sizes + assert type(specs1["resources"]) == type(specs2["resources"]) - # Measurements are not yet supported in Catalyst device-level specs - # assert specs1["resources"].measurements == specs2["resources"].measurements + if not isinstance(specs1["resources"], dict): + all_res1 = {None: specs1["resources"]} + all_res2 = {None: specs2["resources"]} - assert specs1["resources"].num_allocs == specs2["resources"].num_allocs - assert specs1["resources"].depth == specs2["resources"].depth - assert specs1["resources"].num_gates == specs2["resources"].num_gates + else: + all_res1 = specs1["resources"] + all_res2 = specs2["resources"] + for res1, res2 in zip(all_res1.values(), all_res2.values()): + assert res1.gate_types == res2.gate_types + assert res1.gate_sizes == res2.gate_sizes -@pytest.mark.parametrize("level", ["device"]) -def test_simple(level): - """Test a simple case of qml.specs() against PennyLane""" + # TODO: Measurements are not yet supported in Catalyst device-level specs + if not skip_measurements: + assert res1.measurements == res2.measurements - dev = qml.device("lightning.qubit", wires=1) + assert res1.num_allocs == res2.num_allocs + assert res1.depth == res2.depth + assert res1.num_gates == res2.num_gates - @qml.qnode(dev) - def circuit(): - qml.Hadamard(wires=0) - return qml.expval(qml.PauliZ(0)) +class TestDeviceLevelSpecs: + """Test qml.specs() at device level""" - pl_specs = qml.specs(circuit, level=level)() - cat_specs = qml.specs(qjit(circuit), level=level)() + def test_simple(self): + """Test a simple case of qml.specs() against PennyLane""" - assert cat_specs["device_name"] == "lightning.qubit" - check_specs_same(pl_specs, cat_specs) + dev = qml.device("lightning.qubit", wires=1) + @qml.qnode(dev) + def circuit(): + qml.Hadamard(wires=0) + return qml.expval(qml.PauliZ(0)) -@pytest.mark.parametrize("level", ["device"]) -def test_complex(level): - """Test a complex case of qml.specs() against PennyLane""" + pl_specs = qml.specs(circuit, level="device")() + cat_specs = qml.specs(qjit(circuit), level="device")() - dev = qml.device("lightning.qubit", wires=4) - U = 1 / jnp.sqrt(2) * jnp.array([[1, 1], [1, -1]], dtype=jnp.complex128) + assert cat_specs["device_name"] == "lightning.qubit" + check_specs_same(pl_specs, cat_specs, skip_measurements=True) - @qml.qnode(dev) - def circuit(): - qml.PauliX(0) - qml.adjoint(qml.T)(0) - qml.ctrl(op=qml.S, control=[1], control_values=[1])(0) - qml.ctrl(op=qml.S, control=[1, 2], control_values=[1, 0])(0) - qml.ctrl(op=qml.adjoint(qml.Y), control=[2], control_values=[1])(0) - qml.CNOT([0, 1]) - qml.QubitUnitary(U, wires=0) - qml.ControlledQubitUnitary(U, control_values=[1], wires=[1, 0]) - qml.adjoint(qml.QubitUnitary(U, wires=0)) - qml.adjoint(qml.ControlledQubitUnitary(U, control_values=[1, 1], wires=[1, 2, 0])) + def test_complex(self): + """Test a complex case of qml.specs() against PennyLane""" - return qml.probs() + dev = qml.device("lightning.qubit", wires=4) + U = 1 / jnp.sqrt(2) * jnp.array([[1, 1], [1, -1]], dtype=jnp.complex128) - pl_specs = qml.specs(circuit, level=level)() - cat_specs = qml.specs(qjit(circuit), level=level)() + @qml.qnode(dev) + def circuit(): + qml.PauliX(0) + qml.adjoint(qml.T)(0) + qml.ctrl(op=qml.S, control=[1], control_values=[1])(0) + qml.ctrl(op=qml.S, control=[1, 2], control_values=[1, 0])(0) + qml.ctrl(op=qml.adjoint(qml.Y), control=[2], control_values=[1])(0) + qml.CNOT([0, 1]) - assert cat_specs["device_name"] == "lightning.qubit" + qml.QubitUnitary(U, wires=0) + qml.ControlledQubitUnitary(U, control_values=[1], wires=[1, 0]) + qml.adjoint(qml.QubitUnitary(U, wires=0)) + qml.adjoint(qml.ControlledQubitUnitary(U, control_values=[1, 1], wires=[1, 2, 0])) - # Catalyst will handle Adjoint(PauliY) == PauliY - assert "CY" in cat_specs["resources"].gate_types - cat_specs["resources"].gate_types["C(Adjoint(PauliY))"] = cat_specs["resources"].gate_types[ - "CY" - ] - del cat_specs["resources"].gate_types["CY"] + return qml.probs() - check_specs_same(pl_specs, cat_specs) + pl_specs = qml.specs(circuit, level="device")() + cat_specs = qml.specs(qjit(circuit), level="device")() + assert cat_specs["device_name"] == "lightning.qubit" + + # Catalyst will handle Adjoint(PauliY) == PauliY + assert "CY" in cat_specs["resources"].gate_types + cat_specs["resources"].gate_types["C(Adjoint(PauliY))"] = cat_specs["resources"].gate_types[ + "CY" + ] + del cat_specs["resources"].gate_types["CY"] + + check_specs_same(pl_specs, cat_specs, skip_measurements=True) + +class TestPassByPassSpecs: + """Test qml.specs() pass-by-pass specs""" + + pass if __name__ == "__main__": pytest.main(["-x", __file__]) From 621ca8628299c9c0bb2ed9dec8ec995d57f20f1f Mon Sep 17 00:00:00 2001 From: Jake Zaia Date: Thu, 4 Dec 2025 23:36:16 +0000 Subject: [PATCH 4/8] Add integration tests for specs --- frontend/test/pytest/test_specs.py | 179 ++++++++++++++++++++++++++--- 1 file changed, 160 insertions(+), 19 deletions(-) diff --git a/frontend/test/pytest/test_specs.py b/frontend/test/pytest/test_specs.py index 93d6555b0a..83f851c0a0 100644 --- a/frontend/test/pytest/test_specs.py +++ b/frontend/test/pytest/test_specs.py @@ -13,33 +13,43 @@ # limitations under the License. """Tests for qml.specs() Catalyst integration""" +from functools import partial + import pennylane as qml import pytest from jax import numpy as jnp +from pennylane.measurements import Shots +from pennylane.resource import CircuitSpecs, SpecsResources from catalyst import qjit # pylint:disable = protected-access,attribute-defined-outside-init -# TODO: Remove this method once feature pairty has been reached, and instead use `==` directly -def check_specs_same(specs1, specs2, skip_measurements=False): +def check_specs_header_same( + actual: CircuitSpecs, expected: CircuitSpecs, skip_level: bool = False +) -> None: """Check that two specs dictionaries are the same.""" - assert specs1["device_name"] == specs2["device_name"] - assert specs1["num_device_wires"] == specs2["num_device_wires"] - assert specs1["shots"] == specs2["shots"] - - assert type(specs1["resources"]) == type(specs2["resources"]) - - if not isinstance(specs1["resources"], dict): - all_res1 = {None: specs1["resources"]} - all_res2 = {None: specs2["resources"]} + assert actual["device_name"] == expected["device_name"] + assert actual["num_device_wires"] == expected["num_device_wires"] + if not skip_level: + assert actual["level"] == expected["level"] + assert actual["shots"] == expected["shots"] - else: - all_res1 = specs1["resources"] - all_res2 = specs2["resources"] - for res1, res2 in zip(all_res1.values(), all_res2.values()): +# TODO: Remove this method once feature pairty has been reached, and instead use `==` directly +def check_specs_resources_same( + actual_res: SpecsResources | dict[any, SpecsResources], + expected_res: SpecsResources | dict[any, SpecsResources], + skip_measurements: bool = False, +) -> None: + assert type(actual_res) == type(expected_res) + + if not isinstance(actual_res, dict): + actual_res = {None: actual_res} + expected_res = {None: expected_res} + + for res1, res2 in zip(actual_res.values(), expected_res.values()): assert res1.gate_types == res2.gate_types assert res1.gate_sizes == res2.gate_sizes @@ -51,6 +61,15 @@ def check_specs_same(specs1, specs2, skip_measurements=False): assert res1.depth == res2.depth assert res1.num_gates == res2.num_gates + +def check_specs_same(actual: CircuitSpecs, expected: CircuitSpecs, skip_measurements: bool = False): + """Check that two specs dictionaries are the same.""" + check_specs_header_same(actual, expected) + check_specs_resources_same( + actual["resources"], expected["resources"], skip_measurements=skip_measurements + ) + + class TestDeviceLevelSpecs: """Test qml.specs() at device level""" @@ -68,8 +87,7 @@ def circuit(): cat_specs = qml.specs(qjit(circuit), level="device")() assert cat_specs["device_name"] == "lightning.qubit" - check_specs_same(pl_specs, cat_specs, skip_measurements=True) - + check_specs_same(cat_specs, pl_specs, skip_measurements=True) def test_complex(self): """Test a complex case of qml.specs() against PennyLane""" @@ -105,12 +123,135 @@ def circuit(): ] del cat_specs["resources"].gate_types["CY"] - check_specs_same(pl_specs, cat_specs, skip_measurements=True) + check_specs_same(cat_specs, pl_specs, skip_measurements=True) + +@pytest.mark.usefixtures("use_both_frontend") class TestPassByPassSpecs: """Test qml.specs() pass-by-pass specs""" - pass + @pytest.fixture + def simple_circuit(self): + """Fixture for a circuit.""" + + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def circ(): + qml.RX(1.0, 0) + qml.RX(2.0, 0) + qml.RZ(3.0, 1) + qml.RZ(4.0, 1) + qml.Hadamard(0) + qml.Hadamard(0) + qml.CNOT([0, 1]) + qml.CNOT([0, 1]) + return qml.probs() + + return circ + + def test_basic_passes_multi_level(self, simple_circuit): + """Test that when passes are applied, the circuit resources are updated accordingly.""" + + if not qml.capture.enabled(): + pytest.xfail("Catalyst transforms display twice when capture not enabled") + + simple_circuit = qml.transforms.cancel_inverses(simple_circuit) + simple_circuit = qml.transforms.merge_rotations(simple_circuit) + + simple_circuit = qjit(simple_circuit) + + expected = CircuitSpecs( + device_name="lightning.qubit", + num_device_wires=2, + shots=Shots(None), + level=[ + "Before transforms", + "Before MLIR Passes (MLIR-0)", + "cancel-inverses (MLIR-1)", + "merge-rotations (MLIR-2)", + ], + resources={ + "Before transforms": SpecsResources( + gate_types={"RX": 2, "RZ": 2, "Hadamard": 2, "CNOT": 2}, + gate_sizes={1: 6, 2: 2}, + measurements={"probs(all wires)": 1}, + num_allocs=2, + ), + "Before MLIR Passes (MLIR-0)": SpecsResources( + gate_types={"RX": 2, "RZ": 2, "Hadamard": 2, "CNOT": 2}, + gate_sizes={1: 6, 2: 2}, + measurements={"probs(all wires)": 1}, + num_allocs=2, + ), + "cancel-inverses (MLIR-1)": SpecsResources( + gate_types={"RX": 2, "RZ": 2}, + gate_sizes={1: 4}, + measurements={"probs(all wires)": 1}, + num_allocs=2, + ), + "merge-rotations (MLIR-2)": SpecsResources( + gate_types={"RX": 1, "RZ": 1}, + gate_sizes={1: 2}, + measurements={"probs(all wires)": 1}, + num_allocs=2, + ), + }, + ) + + actual = qml.specs(simple_circuit, level="all")() + + check_specs_same(actual, expected) + + # Test resources at each level match individual specs calls + for i, res in enumerate(actual["resources"].values()): + single_level_specs = qml.specs(simple_circuit, level=i)() + check_specs_header_same(actual, single_level_specs, skip_level=True) + check_specs_resources_same(res, single_level_specs["resources"]) + + def test_marker(self, simple_circuit): + """Test that qml.marker can be used appropriately.""" + + if qml.capture.enabled(): + pytest.xfail("qml.marker is not currently compatible with program capture") + + simple_circuit = partial(qml.marker, level="m0")(simple_circuit) + simple_circuit = qml.transforms.cancel_inverses(simple_circuit) + simple_circuit = partial(qml.marker, level="m1")(simple_circuit) + simple_circuit = qml.transforms.merge_rotations(simple_circuit) + simple_circuit = partial(qml.marker, level="m2")(simple_circuit) + + simple_circuit = qjit(simple_circuit) + + expected = CircuitSpecs( + device_name="lightning.qubit", + num_device_wires=2, + shots=Shots(None), + level=["m0", "m1", "m2"], + resources={ + "m0": SpecsResources( + gate_types={"RX": 2, "RZ": 2, "Hadamard": 2, "CNOT": 2}, + gate_sizes={1: 6, 2: 2}, + measurements={"probs(all wires)": 1}, + num_allocs=2, + ), + "m1": SpecsResources( + gate_types={"RX": 2, "RZ": 2}, + gate_sizes={1: 4}, + measurements={"probs(all wires)": 1}, + num_allocs=2, + ), + "m2": SpecsResources( + gate_types={"RX": 1, "RZ": 1}, + gate_sizes={1: 2}, + measurements={"probs(all wires)": 1}, + num_allocs=2, + ), + }, + ) + + actual = qml.specs(simple_circuit, level=["m0", "m1", "m2"])() + + check_specs_same(actual, expected) + if __name__ == "__main__": pytest.main(["-x", __file__]) From 227b0d568c83b5812d6fbd2fc60a3725c939c4e0 Mon Sep 17 00:00:00 2001 From: Jake Zaia Date: Fri, 5 Dec 2025 16:44:27 +0000 Subject: [PATCH 5/8] Check format matches --- frontend/test/pytest/test_specs.py | 39 ++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/frontend/test/pytest/test_specs.py b/frontend/test/pytest/test_specs.py index 83f851c0a0..e60b0ecebb 100644 --- a/frontend/test/pytest/test_specs.py +++ b/frontend/test/pytest/test_specs.py @@ -252,6 +252,45 @@ def test_marker(self, simple_circuit): check_specs_same(actual, expected) + def test_reprs_match(self): + """Test that when no transforms are applied to a typical circuit, the "Before Transform" + and "Before MLIR Passes" representations match.""" + + dev = qml.device("lightning.qubit", wires=4) + + @qml.qnode(dev) + def circuit(): + qml.StatePrep(jnp.array([0, 1]), wires=0) + + qml.Hadamard(wires=0) + qml.CNOT(wires=[0, 1]) + + qml.GlobalPhase(jnp.pi / 4) + qml.MultiRZ(jnp.pi / 2, wires=[1, 2, 3]) + qml.ctrl(qml.T, control=0)(wires=3) + + qml.QubitUnitary(jnp.array([[1, 0], [0, 1j]]), wires=2) + + coeffs = [0.2, -0.543] + obs = [qml.X(0) @ qml.Z(1), qml.Z(0) @ qml.Hadamard(2)] + ham = qml.ops.LinearCombination(coeffs, obs) + + return ( + qml.expval(qml.PauliZ(0)), + qml.expval(ham), + qml.probs(wires=[0, 1]), + qml.state(), + ) + + circuit = qjit(circuit) + + specs_all = qml.specs(circuit, level="all")() + + before_transforms = specs_all["resources"]["Before transforms"] + before_mlir = specs_all["resources"]["Before MLIR Passes (MLIR-0)"] + + check_specs_resources_same(before_transforms, before_mlir) + if __name__ == "__main__": pytest.main(["-x", __file__]) From 490772413055e56c779867db6e9aada2a89c02f6 Mon Sep 17 00:00:00 2001 From: Jake Zaia Date: Fri, 5 Dec 2025 16:46:45 +0000 Subject: [PATCH 6/8] Add extra test for regular PL output format --- frontend/test/pytest/test_specs.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/frontend/test/pytest/test_specs.py b/frontend/test/pytest/test_specs.py index e60b0ecebb..48206fe74e 100644 --- a/frontend/test/pytest/test_specs.py +++ b/frontend/test/pytest/test_specs.py @@ -282,13 +282,14 @@ def circuit(): qml.state(), ) - circuit = qjit(circuit) - - specs_all = qml.specs(circuit, level="all")() + specs_device = qml.specs(circuit, level=0, compute_depth=False)() + specs_all = qml.specs(qjit(circuit), level="all")() + regular_pl = specs_device["resources"] before_transforms = specs_all["resources"]["Before transforms"] before_mlir = specs_all["resources"]["Before MLIR Passes (MLIR-0)"] + check_specs_resources_same(regular_pl, before_transforms) check_specs_resources_same(before_transforms, before_mlir) From e5211c285bc6b67abc98b1b2fd1293325293e456 Mon Sep 17 00:00:00 2001 From: Jake Zaia Date: Fri, 5 Dec 2025 22:01:01 +0000 Subject: [PATCH 7/8] Add tests for split-non-commuting --- frontend/test/pytest/test_specs.py | 126 +++++++++++++++++++++++++---- 1 file changed, 111 insertions(+), 15 deletions(-) diff --git a/frontend/test/pytest/test_specs.py b/frontend/test/pytest/test_specs.py index 48206fe74e..08153bdc6a 100644 --- a/frontend/test/pytest/test_specs.py +++ b/frontend/test/pytest/test_specs.py @@ -37,29 +37,47 @@ def check_specs_header_same( assert actual["shots"] == expected["shots"] -# TODO: Remove this method once feature pairty has been reached, and instead use `==` directly +# TODO: Remove this method once feature parity has been reached, and instead use `==` directly def check_specs_resources_same( - actual_res: SpecsResources | dict[any, SpecsResources], - expected_res: SpecsResources | dict[any, SpecsResources], + actual_res: ( + SpecsResources | list[SpecsResources] | dict[any, SpecsResources | list[SpecsResources]] + ), + expected_res: ( + SpecsResources | list[SpecsResources] | dict[any, SpecsResources | list[SpecsResources]] + ), skip_measurements: bool = False, ) -> None: assert type(actual_res) == type(expected_res) - if not isinstance(actual_res, dict): - actual_res = {None: actual_res} - expected_res = {None: expected_res} + if isinstance(actual_res, list): + assert len(actual_res) == len(expected_res) - for res1, res2 in zip(actual_res.values(), expected_res.values()): - assert res1.gate_types == res2.gate_types - assert res1.gate_sizes == res2.gate_sizes + for r1, r2 in zip(actual_res, expected_res): + check_specs_resources_same(r1, r2, skip_measurements=skip_measurements) + + elif isinstance(actual_res, dict): + assert len(actual_res) == len(expected_res) + + for k in actual_res.keys(): + assert k in expected_res + check_specs_resources_same( + actual_res[k], expected_res[k], skip_measurements=skip_measurements + ) + + elif isinstance(actual_res, SpecsResources): + assert actual_res.gate_types == expected_res.gate_types + assert actual_res.gate_sizes == expected_res.gate_sizes # TODO: Measurements are not yet supported in Catalyst device-level specs if not skip_measurements: - assert res1.measurements == res2.measurements + assert actual_res.measurements == expected_res.measurements + + assert actual_res.num_allocs == expected_res.num_allocs + assert actual_res.depth == expected_res.depth + assert actual_res.num_gates == expected_res.num_gates - assert res1.num_allocs == res2.num_allocs - assert res1.depth == res2.depth - assert res1.num_gates == res2.num_gates + else: + raise ValueError("Invalid Type") def check_specs_same(actual: CircuitSpecs, expected: CircuitSpecs, skip_measurements: bool = False): @@ -256,7 +274,7 @@ def test_reprs_match(self): """Test that when no transforms are applied to a typical circuit, the "Before Transform" and "Before MLIR Passes" representations match.""" - dev = qml.device("lightning.qubit", wires=4) + dev = qml.device("lightning.qubit", wires=7) @qml.qnode(dev) def circuit(): @@ -268,6 +286,7 @@ def circuit(): qml.GlobalPhase(jnp.pi / 4) qml.MultiRZ(jnp.pi / 2, wires=[1, 2, 3]) qml.ctrl(qml.T, control=0)(wires=3) + qml.ctrl(op=qml.IsingXX(0.5, wires=[5, 6]), control=range(5), control_values=[1] * 5) qml.QubitUnitary(jnp.array([[1, 0], [0, 1j]]), wires=2) @@ -283,7 +302,7 @@ def circuit(): ) specs_device = qml.specs(circuit, level=0, compute_depth=False)() - specs_all = qml.specs(qjit(circuit), level="all")() + specs_all = qml.specs(qjit(circuit), level="all", compute_depth=False)() regular_pl = specs_device["resources"] before_transforms = specs_all["resources"]["Before transforms"] @@ -292,6 +311,83 @@ def circuit(): check_specs_resources_same(regular_pl, before_transforms) check_specs_resources_same(before_transforms, before_mlir) + def test_split_non_commuting(self): + """Test that qml.transforms.split_non_commuting works as expected""" + + if qml.capture.enabled(): + pytest.xfail("split-non-commuting is not currently compatible with program capture") + + @qml.transforms.cancel_inverses + @qml.transforms.split_non_commuting + @qml.qnode(qml.device("null.qubit", wires=3)) + def circuit(): + qml.H(0) + qml.X(0) + qml.X(0) + return qml.expval(qml.X(0)), qml.expval(qml.Y(0)), qml.expval(qml.Z(0)) + + actual = qml.specs(qjit(circuit), level=range(3))() + expected = CircuitSpecs( + device_name="null.qubit", + num_device_wires=3, + shots=Shots(None), + level=[ + "Before transforms", + "split_non_commuting", + "cancel_inverses", + ], + resources={ + "Before transforms": SpecsResources( + gate_types={"Hadamard": 1, "PauliX": 2}, + gate_sizes={1: 3}, + measurements={"expval(PauliX)": 1, "expval(PauliY)": 1, "expval(PauliZ)": 1}, + num_allocs=1, + ), + "split_non_commuting": [ + SpecsResources( + gate_types={"Hadamard": 1, "PauliX": 2}, + gate_sizes={1: 3}, + measurements={"expval(PauliX)": 1}, + num_allocs=1, + ), + SpecsResources( + gate_types={"Hadamard": 1, "PauliX": 2}, + gate_sizes={1: 3}, + measurements={"expval(PauliY)": 1}, + num_allocs=1, + ), + SpecsResources( + gate_types={"Hadamard": 1, "PauliX": 2}, + gate_sizes={1: 3}, + measurements={"expval(PauliZ)": 1}, + num_allocs=1, + ), + ], + "cancel_inverses": [ + SpecsResources( + gate_types={"Hadamard": 1}, + gate_sizes={1: 1}, + measurements={"expval(PauliX)": 1}, + num_allocs=1, + ), + SpecsResources( + gate_types={"Hadamard": 1}, + gate_sizes={1: 1}, + measurements={"expval(PauliY)": 1}, + num_allocs=1, + ), + SpecsResources( + gate_types={"Hadamard": 1}, + gate_sizes={1: 1}, + measurements={"expval(PauliZ)": 1}, + num_allocs=1, + ), + ], + }, + ) + + check_specs_same(actual, expected) + if __name__ == "__main__": pytest.main(["-x", __file__]) From c0521ba1a2b1d61c58680de17e7e5241a156bd86 Mon Sep 17 00:00:00 2001 From: Jake Zaia Date: Fri, 5 Dec 2025 23:00:34 +0000 Subject: [PATCH 8/8] Add test for subroutine --- frontend/test/pytest/test_specs.py | 46 +++++++++++++++++++++++------- 1 file changed, 36 insertions(+), 10 deletions(-) diff --git a/frontend/test/pytest/test_specs.py b/frontend/test/pytest/test_specs.py index 08153bdc6a..b01349c445 100644 --- a/frontend/test/pytest/test_specs.py +++ b/frontend/test/pytest/test_specs.py @@ -21,6 +21,7 @@ from pennylane.measurements import Shots from pennylane.resource import CircuitSpecs, SpecsResources +import catalyst from catalyst import qjit # pylint:disable = protected-access,attribute-defined-outside-init @@ -144,7 +145,6 @@ def circuit(): check_specs_same(cat_specs, pl_specs, skip_measurements=True) -@pytest.mark.usefixtures("use_both_frontend") class TestPassByPassSpecs: """Test qml.specs() pass-by-pass specs""" @@ -166,12 +166,10 @@ def circ(): return circ + @pytest.mark.usefixtures("use_capture") def test_basic_passes_multi_level(self, simple_circuit): """Test that when passes are applied, the circuit resources are updated accordingly.""" - if not qml.capture.enabled(): - pytest.xfail("Catalyst transforms display twice when capture not enabled") - simple_circuit = qml.transforms.cancel_inverses(simple_circuit) simple_circuit = qml.transforms.merge_rotations(simple_circuit) @@ -228,9 +226,6 @@ def test_basic_passes_multi_level(self, simple_circuit): def test_marker(self, simple_circuit): """Test that qml.marker can be used appropriately.""" - if qml.capture.enabled(): - pytest.xfail("qml.marker is not currently compatible with program capture") - simple_circuit = partial(qml.marker, level="m0")(simple_circuit) simple_circuit = qml.transforms.cancel_inverses(simple_circuit) simple_circuit = partial(qml.marker, level="m1")(simple_circuit) @@ -270,6 +265,7 @@ def test_marker(self, simple_circuit): check_specs_same(actual, expected) + @pytest.mark.usefixtures("use_both_frontend") def test_reprs_match(self): """Test that when no transforms are applied to a typical circuit, the "Before Transform" and "Before MLIR Passes" representations match.""" @@ -314,9 +310,6 @@ def circuit(): def test_split_non_commuting(self): """Test that qml.transforms.split_non_commuting works as expected""" - if qml.capture.enabled(): - pytest.xfail("split-non-commuting is not currently compatible with program capture") - @qml.transforms.cancel_inverses @qml.transforms.split_non_commuting @qml.qnode(qml.device("null.qubit", wires=3)) @@ -388,6 +381,39 @@ def circuit(): check_specs_same(actual, expected) + @pytest.mark.usefixtures("use_capture") + def test_subroutine(self): + dev = qml.device("lightning.qubit", wires=3) + + @catalyst.jax_primitives.subroutine + def subroutine(): + qml.Hadamard(wires=0) + + @qml.qjit(autograph=True) + @qml.qnode(dev) + def circuit(): + + for _ in range(3): + subroutine() + + return qml.probs() + + actual = qml.specs(circuit, level=1)() + expected = CircuitSpecs( + device_name="lightning.qubit", + num_device_wires=3, + shots=Shots(None), + level=1, + resources=SpecsResources( + gate_types={"Hadamard": 3}, + gate_sizes={1: 3}, + measurements={"probs(all wires)": 1}, + num_allocs=3, + ), + ) + + check_specs_same(actual, expected) + if __name__ == "__main__": pytest.main(["-x", __file__])