diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index fcdb9f2ed5..33fc9a1ec6 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -9,6 +9,7 @@ [(#2246)](https://github.com/PennyLaneAI/catalyst/pull/2246) [(#2231)](https://github.com/PennyLaneAI/catalyst/pull/2231) [(#2234)](https://github.com/PennyLaneAI/catalyst/pull/2234) + [(#2118)](https://github.com/PennyLaneAI/catalyst/pull/2218) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 6a548b5d51..2311b5e01d 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -14,12 +14,18 @@ """Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module.""" -from functools import singledispatchmethod +from functools import singledispatch, singledispatchmethod +from pennylane.measurements import MeasurementProcess +from pennylane.operation import Operator from xdsl.dialects import builtin, func, scf from xdsl.ir import Block, Operation, Region, SSAValue from catalyst.python_interface.dialects import quantum +from catalyst.python_interface.inspection.xdsl_conversion import ( + xdsl_to_qml_measurement, + xdsl_to_qml_op, +) from catalyst.python_interface.visualization.dag_builder import DAGBuilder @@ -85,6 +91,119 @@ def _visit_block(self, block: Block) -> None: for op in block.ops: self._visit_operation(op) + # =================== + # QUANTUM OPERATIONS + # =================== + + @_visit_operation.register + def _unitary( + self, + op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp | quantum.MultiRZOp, + ) -> None: + """Generic handler for unitary gates.""" + + # Create PennyLane instance + qml_op = xdsl_to_qml_op(op) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(qml_op), + cluster_uid=self._cluster_uid_stack[-1], + # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) + shape="record", + ) + self._node_uid_counter += 1 + + @_visit_operation.register + def _projective_measure_op(self, op: quantum.MeasureOp) -> None: + """Handler for the single-qubit projective measurement operation.""" + + # Create PennyLane instance + meas = xdsl_to_qml_measurement(op) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(meas), + cluster_uid=self._cluster_uid_stack[-1], + # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) + shape="record", + ) + self._node_uid_counter += 1 + + # ===================== + # QUANTUM MEASUREMENTS + # ===================== + + @_visit_operation.register + def _state_op(self, op: quantum.StateOp) -> None: + """Handler for the terminal state measurement operation.""" + + # Create PennyLane instance + meas = xdsl_to_qml_measurement(op) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(meas), + cluster_uid=self._cluster_uid_stack[-1], + fillcolor="lightpink", + color="lightpink3", + ) + self._node_uid_counter += 1 + + @_visit_operation.register + def _statistical_measurement_ops( + self, + op: quantum.ExpvalOp | quantum.VarianceOp, + ) -> None: + """Handler for statistical measurement operations.""" + + # Create PennyLane instance + obs_op = op.obs.owner + meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op)) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(meas), + cluster_uid=self._cluster_uid_stack[-1], + fillcolor="lightpink", + color="lightpink3", + ) + self._node_uid_counter += 1 + + @_visit_operation.register + def _visit_sample_and_probs_ops( + self, + op: quantum.SampleOp | quantum.ProbsOp, + ) -> None: + """Handler for sample operations.""" + + # Create PennyLane instance + obs_op = op.obs.owner + + # TODO: This doesn't logically make sense, but quantum.compbasis + # is obs_op and function below just pulls out the static wires + wires = xdsl_to_qml_measurement(obs_op) + meas = xdsl_to_qml_measurement(op, wires=None if wires == [] else wires) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(meas), + cluster_uid=self._cluster_uid_stack[-1], + fillcolor="lightpink", + color="lightpink3", + ) + self._node_uid_counter += 1 + # ============= # CONTROL FLOW # ============= @@ -254,3 +373,26 @@ def _flatten_if_op(op: scf.IfOp) -> list[tuple[SSAValue | None, Region]]: # with no SSAValue flattened_op.extend([(None, else_region)]) return flattened_op + + +@singledispatch +def get_label(op: Operator | MeasurementProcess) -> str: + """Gets the appropriate label for a PennyLane object.""" + return str(op) + + +@get_label.register +def _operator(op: Operator) -> str: + """Returns the appropriate label for an xDSL operation.""" + wires = list(op.wires.labels) + if wires == []: + wires_str = "all" + else: + wires_str = f"[{', '.join(map(str, wires))}]" + return f" {op.name}| {wires_str}" + + +@get_label.register +def _mp(mp: MeasurementProcess) -> str: + """Returns the appropriate label for an xDSL operation.""" + return str(mp) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 82f50140fa..0e90d66fdb 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -26,11 +26,14 @@ from xdsl.dialects.builtin import ModuleOp from xdsl.ir.core import Block, Region +from catalyst import measure from catalyst.python_interface.conversion import xdsl_from_qjit from catalyst.python_interface.visualization.construct_circuit_dag import ( ConstructCircuitDAG, + get_label, ) from catalyst.python_interface.visualization.dag_builder import DAGBuilder +from catalyst.utils.exceptions import CompileError class FakeDAGBuilder(DAGBuilder): @@ -283,10 +286,11 @@ def my_workflow(): # Assert null qubit device node is inside my_qnode2 cluster assert graph_clusters["cluster2"]["cluster_label"] == "my_qnode2" - assert graph_nodes["node1"]["parent_cluster_uid"] == "cluster2" + # NOTE: node1 is the qml.H(0) in my_qnode1 + assert graph_nodes["node2"]["parent_cluster_uid"] == "cluster2" # Assert label is as expected - assert graph_nodes["node1"]["label"] == "LightningSimulator" + assert graph_nodes["node2"]["label"] == "LightningSimulator" class TestForOp: @@ -538,5 +542,291 @@ def my_workflow(x, y): assert clusters["cluster6"]["parent_cluster_uid"] == "cluster4" # Check nested if / else is within the first if cluster - assert clusters["cluster7"]["node_label"] == "else" assert clusters["cluster7"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster7"]["node_label"] == "else" + + +class TestGetLabel: + """Tests the get_label utility.""" + + @pytest.mark.unit + @pytest.mark.parametrize( + "op", [qml.H(0), qml.QubitUnitary([[0, 1], [1, 0]], 0), qml.SWAP([0, 1])] + ) + def test_standard_operator(self, op): + """Tests against an operator instance.""" + wires = list(op.wires.labels) + if wires == []: + wires_str = "all" + else: + wires_str = f"[{', '.join(map(str, wires))}]" + + assert get_label(op) == f" {op.name}| {wires_str}" + + def test_global_phase_operator(self): + """Tests against a GlobalPhase operator instance.""" + assert get_label(qml.GlobalPhase(0.5)) == f" GlobalPhase| all" + + @pytest.mark.unit + @pytest.mark.parametrize( + "meas", + [ + qml.state(), + qml.expval(qml.Z(0)), + qml.var(qml.Z(0)), + qml.probs(), + qml.probs(wires=0), + qml.probs(wires=[0, 1]), + qml.sample(), + qml.sample(wires=0), + qml.sample(wires=[0, 1]), + ], + ) + def test_standard_measurement(self, meas): + """Tests against an operator instance.""" + + assert get_label(meas) == str(meas) + + +class TestCreateStaticOperatorNodes: + """Tests that operators with static parameters can be created and visualized as nodes.""" + + @pytest.mark.unit + @pytest.mark.parametrize("op", [qml.H(0), qml.X(0), qml.SWAP([0, 1])]) + def test_custom_op(self, op): + """Tests that the CustomOp operation node can be created and visualized.""" + + # Build module with only a CustomOp + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.apply(op) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + # Make sure label has relevant info + assert nodes["node1"]["label"] == get_label(op) + + @pytest.mark.unit + @pytest.mark.parametrize( + "op", + [ + qml.GlobalPhase(0.5), + qml.GlobalPhase(0.5, wires=0), + qml.GlobalPhase(0.5, wires=[0, 1]), + ], + ) + def test_global_phase_op(self, op): + """Test that GlobalPhase can be handled.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.apply(op) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + # Compiler throws out the wires and they get converted to wires=[] no matter what + assert nodes["node1"]["label"] == get_label(qml.GlobalPhase(0.5)) + + @pytest.mark.unit + def test_qubit_unitary_op(self): + """Test that QubitUnitary operations can be handled.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.QubitUnitary([[0, 1], [1, 0]], wires=0) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(qml.QubitUnitary([[0, 1], [1, 0]], wires=0)) + + @pytest.mark.unit + def test_multi_rz_op(self): + """Test that MultiRZ operations can be handled.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.MultiRZ(0.5, wires=[0]) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(qml.MultiRZ(0.5, wires=[0])) + + @pytest.mark.unit + def test_projective_measurement_op(self): + """Test that projective measurements can be captured as nodes.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + measure(0) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == f" MidMeasureMP| [0]" + + +class TestCreateStaticMeasurementNodes: + """Tests that measurements with static parameters can be created and visualized as nodes.""" + + @pytest.mark.unit + def test_state_op(self): + """Test that qml.state can be handled.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return qml.state() + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(qml.state()) + + @pytest.mark.unit + @pytest.mark.parametrize("meas_fn", [qml.expval, qml.var]) + def test_expval_var_measurement_op(self, meas_fn): + """Test that statistical measurement operators can be captured as nodes.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return meas_fn(qml.Z(0)) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(meas_fn(qml.Z(0))) + + @pytest.mark.unit + @pytest.mark.parametrize( + "op", + [ + qml.probs(), + qml.probs(wires=0), + qml.probs(wires=[0, 1]), + ], + ) + def test_probs_measurement_op(self, op): + """Tests that the probs measurement function can be captured as a node.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return op + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(op) + + @pytest.mark.unit + @pytest.mark.parametrize( + "op", + [ + qml.sample(), + qml.sample(wires=0), + qml.sample(wires=[0, 1]), + ], + ) + def test_valid_sample_measurement_op(self, op): + """Tests that the sample measurement function can be captured as a node.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.set_shots(10) + @qml.qnode(dev) + def my_circuit(): + return op + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == get_label(op)