diff --git a/frontend/test/pytest/test_specs.py b/frontend/test/pytest/test_specs.py index 61abdcf3d..b01349c44 100644 --- a/frontend/test/pytest/test_specs.py +++ b/frontend/test/pytest/test_specs.py @@ -11,93 +11,408 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""Test for qml.specs() Catalyst integration""" +"""Tests for qml.specs() Catalyst integration""" + +from functools import partial + import pennylane as qml import pytest from jax import numpy as jnp +from pennylane.measurements import Shots +from pennylane.resource import CircuitSpecs, SpecsResources +import catalyst from catalyst import qjit # pylint:disable = protected-access,attribute-defined-outside-init -def check_specs_same(specs1, specs2): +def check_specs_header_same( + actual: CircuitSpecs, expected: CircuitSpecs, skip_level: bool = False +) -> None: + """Check that two specs dictionaries are the same.""" + assert actual["device_name"] == expected["device_name"] + assert actual["num_device_wires"] == expected["num_device_wires"] + if not skip_level: + assert actual["level"] == expected["level"] + assert actual["shots"] == expected["shots"] + + +# TODO: Remove this method once feature parity has been reached, and instead use `==` directly +def check_specs_resources_same( + actual_res: ( + SpecsResources | list[SpecsResources] | dict[any, SpecsResources | list[SpecsResources]] + ), + expected_res: ( + SpecsResources | list[SpecsResources] | dict[any, SpecsResources | list[SpecsResources]] + ), + skip_measurements: bool = False, +) -> None: + assert type(actual_res) == type(expected_res) + + if isinstance(actual_res, list): + assert len(actual_res) == len(expected_res) + + for r1, r2 in zip(actual_res, expected_res): + check_specs_resources_same(r1, r2, skip_measurements=skip_measurements) + + elif isinstance(actual_res, dict): + assert len(actual_res) == len(expected_res) + + for k in actual_res.keys(): + assert k in expected_res + check_specs_resources_same( + actual_res[k], expected_res[k], skip_measurements=skip_measurements + ) + + elif isinstance(actual_res, SpecsResources): + assert actual_res.gate_types == expected_res.gate_types + assert actual_res.gate_sizes == expected_res.gate_sizes + + # TODO: Measurements are not yet supported in Catalyst device-level specs + if not skip_measurements: + assert actual_res.measurements == expected_res.measurements + + assert actual_res.num_allocs == expected_res.num_allocs + assert actual_res.depth == expected_res.depth + assert actual_res.num_gates == expected_res.num_gates + + else: + raise ValueError("Invalid Type") + + +def check_specs_same(actual: CircuitSpecs, expected: CircuitSpecs, skip_measurements: bool = False): """Check that two specs dictionaries are the same.""" - assert specs1["device_name"] == specs2["device_name"] - assert specs1["resources"].num_wires == specs2["resources"].num_wires - assert specs1["resources"].num_gates == specs2["resources"].num_gates - assert specs1["resources"].depth == specs2["resources"].depth + check_specs_header_same(actual, expected) + check_specs_resources_same( + actual["resources"], expected["resources"], skip_measurements=skip_measurements + ) + + +class TestDeviceLevelSpecs: + """Test qml.specs() at device level""" + + def test_simple(self): + """Test a simple case of qml.specs() against PennyLane""" + + dev = qml.device("lightning.qubit", wires=1) + + @qml.qnode(dev) + def circuit(): + qml.Hadamard(wires=0) + return qml.expval(qml.PauliZ(0)) + + pl_specs = qml.specs(circuit, level="device")() + cat_specs = qml.specs(qjit(circuit), level="device")() + + assert cat_specs["device_name"] == "lightning.qubit" + check_specs_same(cat_specs, pl_specs, skip_measurements=True) + + def test_complex(self): + """Test a complex case of qml.specs() against PennyLane""" + + dev = qml.device("lightning.qubit", wires=4) + U = 1 / jnp.sqrt(2) * jnp.array([[1, 1], [1, -1]], dtype=jnp.complex128) + + @qml.qnode(dev) + def circuit(): + qml.PauliX(0) + qml.adjoint(qml.T)(0) + qml.ctrl(op=qml.S, control=[1], control_values=[1])(0) + qml.ctrl(op=qml.S, control=[1, 2], control_values=[1, 0])(0) + qml.ctrl(op=qml.adjoint(qml.Y), control=[2], control_values=[1])(0) + qml.CNOT([0, 1]) + + qml.QubitUnitary(U, wires=0) + qml.ControlledQubitUnitary(U, control_values=[1], wires=[1, 0]) + qml.adjoint(qml.QubitUnitary(U, wires=0)) + qml.adjoint(qml.ControlledQubitUnitary(U, control_values=[1, 1], wires=[1, 2, 0])) + + return qml.probs() + + pl_specs = qml.specs(circuit, level="device")() + cat_specs = qml.specs(qjit(circuit), level="device")() + + assert cat_specs["device_name"] == "lightning.qubit" + + # Catalyst will handle Adjoint(PauliY) == PauliY + assert "CY" in cat_specs["resources"].gate_types + cat_specs["resources"].gate_types["C(Adjoint(PauliY))"] = cat_specs["resources"].gate_types[ + "CY" + ] + del cat_specs["resources"].gate_types["CY"] + + check_specs_same(cat_specs, pl_specs, skip_measurements=True) + + +class TestPassByPassSpecs: + """Test qml.specs() pass-by-pass specs""" + + @pytest.fixture + def simple_circuit(self): + """Fixture for a circuit.""" + + @qml.qnode(qml.device("lightning.qubit", wires=2)) + def circ(): + qml.RX(1.0, 0) + qml.RX(2.0, 0) + qml.RZ(3.0, 1) + qml.RZ(4.0, 1) + qml.Hadamard(0) + qml.Hadamard(0) + qml.CNOT([0, 1]) + qml.CNOT([0, 1]) + return qml.probs() + + return circ + + @pytest.mark.usefixtures("use_capture") + def test_basic_passes_multi_level(self, simple_circuit): + """Test that when passes are applied, the circuit resources are updated accordingly.""" + + simple_circuit = qml.transforms.cancel_inverses(simple_circuit) + simple_circuit = qml.transforms.merge_rotations(simple_circuit) + + simple_circuit = qjit(simple_circuit) + + expected = CircuitSpecs( + device_name="lightning.qubit", + num_device_wires=2, + shots=Shots(None), + level=[ + "Before transforms", + "Before MLIR Passes (MLIR-0)", + "cancel-inverses (MLIR-1)", + "merge-rotations (MLIR-2)", + ], + resources={ + "Before transforms": SpecsResources( + gate_types={"RX": 2, "RZ": 2, "Hadamard": 2, "CNOT": 2}, + gate_sizes={1: 6, 2: 2}, + measurements={"probs(all wires)": 1}, + num_allocs=2, + ), + "Before MLIR Passes (MLIR-0)": SpecsResources( + gate_types={"RX": 2, "RZ": 2, "Hadamard": 2, "CNOT": 2}, + gate_sizes={1: 6, 2: 2}, + measurements={"probs(all wires)": 1}, + num_allocs=2, + ), + "cancel-inverses (MLIR-1)": SpecsResources( + gate_types={"RX": 2, "RZ": 2}, + gate_sizes={1: 4}, + measurements={"probs(all wires)": 1}, + num_allocs=2, + ), + "merge-rotations (MLIR-2)": SpecsResources( + gate_types={"RX": 1, "RZ": 1}, + gate_sizes={1: 2}, + measurements={"probs(all wires)": 1}, + num_allocs=2, + ), + }, + ) + + actual = qml.specs(simple_circuit, level="all")() + + check_specs_same(actual, expected) + + # Test resources at each level match individual specs calls + for i, res in enumerate(actual["resources"].values()): + single_level_specs = qml.specs(simple_circuit, level=i)() + check_specs_header_same(actual, single_level_specs, skip_level=True) + check_specs_resources_same(res, single_level_specs["resources"]) + + def test_marker(self, simple_circuit): + """Test that qml.marker can be used appropriately.""" + + simple_circuit = partial(qml.marker, level="m0")(simple_circuit) + simple_circuit = qml.transforms.cancel_inverses(simple_circuit) + simple_circuit = partial(qml.marker, level="m1")(simple_circuit) + simple_circuit = qml.transforms.merge_rotations(simple_circuit) + simple_circuit = partial(qml.marker, level="m2")(simple_circuit) + + simple_circuit = qjit(simple_circuit) + + expected = CircuitSpecs( + device_name="lightning.qubit", + num_device_wires=2, + shots=Shots(None), + level=["m0", "m1", "m2"], + resources={ + "m0": SpecsResources( + gate_types={"RX": 2, "RZ": 2, "Hadamard": 2, "CNOT": 2}, + gate_sizes={1: 6, 2: 2}, + measurements={"probs(all wires)": 1}, + num_allocs=2, + ), + "m1": SpecsResources( + gate_types={"RX": 2, "RZ": 2}, + gate_sizes={1: 4}, + measurements={"probs(all wires)": 1}, + num_allocs=2, + ), + "m2": SpecsResources( + gate_types={"RX": 1, "RZ": 1}, + gate_sizes={1: 2}, + measurements={"probs(all wires)": 1}, + num_allocs=2, + ), + }, + ) + + actual = qml.specs(simple_circuit, level=["m0", "m1", "m2"])() + + check_specs_same(actual, expected) + + @pytest.mark.usefixtures("use_both_frontend") + def test_reprs_match(self): + """Test that when no transforms are applied to a typical circuit, the "Before Transform" + and "Before MLIR Passes" representations match.""" + + dev = qml.device("lightning.qubit", wires=7) + + @qml.qnode(dev) + def circuit(): + qml.StatePrep(jnp.array([0, 1]), wires=0) - assert len(specs1["resources"].gate_types) == len(specs2["resources"].gate_types) - for gate, count in specs1["resources"].gate_types.items(): - assert gate in specs2["resources"].gate_types - assert count == specs2["resources"].gate_types[gate] + qml.Hadamard(wires=0) + qml.CNOT(wires=[0, 1]) - assert len(specs1["resources"].gate_sizes) == len(specs2["resources"].gate_sizes) - for gate, count in specs1["resources"].gate_sizes.items(): - assert gate in specs2["resources"].gate_sizes - assert count == specs2["resources"].gate_sizes[gate] + qml.GlobalPhase(jnp.pi / 4) + qml.MultiRZ(jnp.pi / 2, wires=[1, 2, 3]) + qml.ctrl(qml.T, control=0)(wires=3) + qml.ctrl(op=qml.IsingXX(0.5, wires=[5, 6]), control=range(5), control_values=[1] * 5) + qml.QubitUnitary(jnp.array([[1, 0], [0, 1j]]), wires=2) -@pytest.mark.parametrize("level", ["device"]) -def test_simple(level): - """Test a simple case of qml.specs() against PennyLane""" + coeffs = [0.2, -0.543] + obs = [qml.X(0) @ qml.Z(1), qml.Z(0) @ qml.Hadamard(2)] + ham = qml.ops.LinearCombination(coeffs, obs) - dev = qml.device("lightning.qubit", wires=1) + return ( + qml.expval(qml.PauliZ(0)), + qml.expval(ham), + qml.probs(wires=[0, 1]), + qml.state(), + ) - @qml.qnode(dev) - def circuit(): - qml.Hadamard(wires=0) - return qml.expval(qml.PauliZ(0)) + specs_device = qml.specs(circuit, level=0, compute_depth=False)() + specs_all = qml.specs(qjit(circuit), level="all", compute_depth=False)() - pl_specs = qml.specs(circuit, level=level)() - cat_specs = qml.specs(qjit(circuit), level=level)() + regular_pl = specs_device["resources"] + before_transforms = specs_all["resources"]["Before transforms"] + before_mlir = specs_all["resources"]["Before MLIR Passes (MLIR-0)"] - assert cat_specs["device_name"] == "lightning.qubit" - check_specs_same(pl_specs, cat_specs) + check_specs_resources_same(regular_pl, before_transforms) + check_specs_resources_same(before_transforms, before_mlir) + def test_split_non_commuting(self): + """Test that qml.transforms.split_non_commuting works as expected""" -@pytest.mark.parametrize("level", ["device"]) -def test_complex(level): - """Test a complex case of qml.specs() against PennyLane""" + @qml.transforms.cancel_inverses + @qml.transforms.split_non_commuting + @qml.qnode(qml.device("null.qubit", wires=3)) + def circuit(): + qml.H(0) + qml.X(0) + qml.X(0) + return qml.expval(qml.X(0)), qml.expval(qml.Y(0)), qml.expval(qml.Z(0)) - dev = qml.device("lightning.qubit", wires=4) - U = 1 / jnp.sqrt(2) * jnp.array([[1, 1], [1, -1]], dtype=jnp.complex128) + actual = qml.specs(qjit(circuit), level=range(3))() + expected = CircuitSpecs( + device_name="null.qubit", + num_device_wires=3, + shots=Shots(None), + level=[ + "Before transforms", + "split_non_commuting", + "cancel_inverses", + ], + resources={ + "Before transforms": SpecsResources( + gate_types={"Hadamard": 1, "PauliX": 2}, + gate_sizes={1: 3}, + measurements={"expval(PauliX)": 1, "expval(PauliY)": 1, "expval(PauliZ)": 1}, + num_allocs=1, + ), + "split_non_commuting": [ + SpecsResources( + gate_types={"Hadamard": 1, "PauliX": 2}, + gate_sizes={1: 3}, + measurements={"expval(PauliX)": 1}, + num_allocs=1, + ), + SpecsResources( + gate_types={"Hadamard": 1, "PauliX": 2}, + gate_sizes={1: 3}, + measurements={"expval(PauliY)": 1}, + num_allocs=1, + ), + SpecsResources( + gate_types={"Hadamard": 1, "PauliX": 2}, + gate_sizes={1: 3}, + measurements={"expval(PauliZ)": 1}, + num_allocs=1, + ), + ], + "cancel_inverses": [ + SpecsResources( + gate_types={"Hadamard": 1}, + gate_sizes={1: 1}, + measurements={"expval(PauliX)": 1}, + num_allocs=1, + ), + SpecsResources( + gate_types={"Hadamard": 1}, + gate_sizes={1: 1}, + measurements={"expval(PauliY)": 1}, + num_allocs=1, + ), + SpecsResources( + gate_types={"Hadamard": 1}, + gate_sizes={1: 1}, + measurements={"expval(PauliZ)": 1}, + num_allocs=1, + ), + ], + }, + ) - @qml.qnode(dev) - def circuit(): - qml.PauliX(0) - qml.adjoint(qml.T)(0) - qml.ctrl(op=qml.S, control=[1], control_values=[1])(0) - qml.ctrl(op=qml.S, control=[1, 2], control_values=[1, 0])(0) - qml.ctrl(op=qml.adjoint(qml.Y), control=[2], control_values=[1])(0) - qml.CNOT([0, 1]) + check_specs_same(actual, expected) - qml.QubitUnitary(U, wires=0) - qml.ControlledQubitUnitary(U, control_values=[1], wires=[1, 0]) - qml.adjoint(qml.QubitUnitary(U, wires=0)) - qml.adjoint(qml.ControlledQubitUnitary(U, control_values=[1, 1], wires=[1, 2, 0])) + @pytest.mark.usefixtures("use_capture") + def test_subroutine(self): + dev = qml.device("lightning.qubit", wires=3) - return qml.probs() + @catalyst.jax_primitives.subroutine + def subroutine(): + qml.Hadamard(wires=0) - pl_specs = qml.specs(circuit, level=level)() - cat_specs = qml.specs(qjit(circuit), level=level)() + @qml.qjit(autograph=True) + @qml.qnode(dev) + def circuit(): - assert cat_specs["device_name"] == "lightning.qubit" + for _ in range(3): + subroutine() - # Catalyst level specs should report the number of controls for multi-controlled gates - assert "2C(S)" in cat_specs["resources"].gate_types - cat_specs["resources"].gate_types["C(S)"] += cat_specs["resources"].gate_types["2C(S)"] - del cat_specs["resources"].gate_types["2C(S)"] + return qml.probs() - # Catalyst will handle Adjoint(PauliY) == PauliY - assert "CY" in cat_specs["resources"].gate_types - cat_specs["resources"].gate_types["C(Adjoint(PauliY))"] += cat_specs["resources"].gate_types[ - "CY" - ] - del cat_specs["resources"].gate_types["CY"] + actual = qml.specs(circuit, level=1)() + expected = CircuitSpecs( + device_name="lightning.qubit", + num_device_wires=3, + shots=Shots(None), + level=1, + resources=SpecsResources( + gate_types={"Hadamard": 3}, + gate_sizes={1: 3}, + measurements={"probs(all wires)": 1}, + num_allocs=3, + ), + ) - check_specs_same(pl_specs, cat_specs) + check_specs_same(actual, expected) if __name__ == "__main__":