diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 33fc9a1ec6..4db1a582ff 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -10,6 +10,7 @@ [(#2231)](https://github.com/PennyLaneAI/catalyst/pull/2231) [(#2234)](https://github.com/PennyLaneAI/catalyst/pull/2234) [(#2118)](https://github.com/PennyLaneAI/catalyst/pull/2218) + [(#2260)](https://github.com/PennyLaneAI/catalyst/pull/2260) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 2311b5e01d..6b97fe7e0f 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -14,6 +14,7 @@ """Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module.""" +from collections import defaultdict from functools import singledispatch, singledispatchmethod from pennylane.measurements import MeasurementProcess @@ -50,6 +51,11 @@ def __init__(self, dag_builder: DAGBuilder) -> None: # Keep track of nesting clusters using a stack self._cluster_uid_stack: list[str] = [] + # Create a map of wire to node uid + # Keys represent static (int) or dynamic wires (str) + # Values represent the set of all node uids that are on that wire. + self._wire_to_node_uids: dict[str | int, set[str]] = defaultdict(set) + # Use counter internally for UID self._node_uid_counter: int = 0 self._cluster_uid_counter: int = 0 @@ -59,6 +65,7 @@ def _reset(self) -> None: self._cluster_uid_stack: list[str] = [] self._node_uid_counter: int = 0 self._cluster_uid_counter: int = 0 + self._wire_to_node_uids: dict[str | int, set[str]] = defaultdict(set) def construct(self, module: builtin.ModuleOp) -> None: """Constructs the DAG from the module. @@ -116,24 +123,42 @@ def _unitary( ) self._node_uid_counter += 1 + # Search through previous ops found on current wires and connect + prev_ops = set.union(*(self._wire_to_node_uids[wire] for wire in qml_op.wires)) + for prev_op in prev_ops: + self.dag_builder.add_edge(prev_op, node_uid) + + # Update affected wires to source from this node UID + for wire in qml_op.wires: + self._wire_to_node_uids[wire] = {node_uid} + @_visit_operation.register def _projective_measure_op(self, op: quantum.MeasureOp) -> None: """Handler for the single-qubit projective measurement operation.""" # Create PennyLane instance - meas = xdsl_to_qml_measurement(op) + qml_op = xdsl_to_qml_measurement(op) # Add node to current cluster node_uid = f"node{self._node_uid_counter}" self.dag_builder.add_node( uid=node_uid, - label=get_label(meas), + label=get_label(qml_op), cluster_uid=self._cluster_uid_stack[-1], # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) shape="record", ) self._node_uid_counter += 1 + # Search through previous ops found on current wires and connect + prev_ops = set.union(*(self._wire_to_node_uids[wire] for wire in qml_op.wires)) + for prev_op in prev_ops: + self.dag_builder.add_edge(prev_op, node_uid) + + # Update affected wires to source from this node UID + for wire in qml_op.wires: + self._wire_to_node_uids[wire] = {node_uid} + # ===================== # QUANTUM MEASUREMENTS # ===================== @@ -156,6 +181,10 @@ def _state_op(self, op: quantum.StateOp) -> None: ) self._node_uid_counter += 1 + for seen_wire, seen_nodes in self._wire_to_node_uids.items(): + for seen_node in seen_nodes: + self.dag_builder.add_edge(seen_node, node_uid, color="lightpink3") + @_visit_operation.register def _statistical_measurement_ops( self, @@ -178,6 +207,10 @@ def _statistical_measurement_ops( ) self._node_uid_counter += 1 + for wire in meas.wires: + for seen_node in self._wire_to_node_uids[wire]: + self.dag_builder.add_edge(seen_node, node_uid, color="lightpink3") + @_visit_operation.register def _visit_sample_and_probs_ops( self, @@ -204,6 +237,10 @@ def _visit_sample_and_probs_ops( ) self._node_uid_counter += 1 + for wire in meas.wires: + for seen_node in self._wire_to_node_uids[wire]: + self.dag_builder.add_edge(seen_node, node_uid, color="lightpink3") + # ============= # CONTROL FLOW # ============= @@ -261,10 +298,15 @@ def _if_op(self, operation: scf.IfOp): self._cluster_uid_stack.append(uid) self._cluster_uid_counter += 1 + # Save wires state before all of the branches + wire_map_before = self._wire_to_node_uids.copy() + region_wire_maps: list[dict[int | str, set[str]]] = [] + # Loop through each branch and visualize as a cluster num_regions = len(flattened_if_op) for i, (condition_ssa, region) in enumerate(flattened_if_op): + # Visualize with a cluster def _get_conditional_branch_label(i): if i == 0: return "if" @@ -285,9 +327,16 @@ def _get_conditional_branch_label(i): self._cluster_uid_stack.append(uid) self._cluster_uid_counter += 1 + # Make fresh wire map before going into region + self._wire_to_node_uids = wire_map_before.copy() + # Go recursively into the branch to process internals self._visit_region(region) + # Update branch wire maps + if self._wire_to_node_uids != wire_map_before: + region_wire_maps.append(self._wire_to_node_uids) + # Pop branch cluster after processing to ensure # logical branches are treated as 'parallel' self._cluster_uid_stack.pop() @@ -295,6 +344,25 @@ def _get_conditional_branch_label(i): # Pop IfOp cluster before leaving this handler self._cluster_uid_stack.pop() + # Check what wires were affected + affected_wires: set[str | int] = set(wire_map_before.keys()) + for region_wire_map in region_wire_maps: + affected_wires.update(region_wire_map.keys()) + + # Update state to be the union of all branch wire maps + final_wire_map = defaultdict(set) + for wire in affected_wires: + all_nodes: set = set() + for region_wire_map in region_wire_maps: + if not wire in region_wire_map: + # IfOp region didn't apply anything on this wire + # so default to node before the IfOp + all_nodes.update(wire_map_before.get(wire, set())) + else: + all_nodes.update(region_wire_map.get(wire, set())) + final_wire_map[wire] = all_nodes + self._wire_to_node_uids = final_wire_map + # ============ # DEVICE NODE # ============ @@ -349,6 +417,9 @@ def _func_return(self, operation: func.ReturnOp) -> None: # the FuncOp's scope and so we can pop the ID off the stack. self._cluster_uid_stack.pop() + # Clear seen wires as we are exiting a FuncOp (qnode) + self._wire_to_node_uids = defaultdict(set) + def _flatten_if_op(op: scf.IfOp) -> list[tuple[SSAValue | None, Region]]: """Recursively flattens a nested IfOp (if/elif/else chains).""" diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index b6c77fb3d7..69a74c4d9e 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -830,3 +830,28 @@ def my_circuit(): assert len(nodes) == 2 # Device node + operator assert nodes["node1"]["label"] == get_label(op) + + +class TestOperatorConnectivity: + """Tests that operators are properly connected.""" + + @pytest.mark.unit + def test_static_connection_within_cluster(self): + """Tests that connections can be made within the same cluster.""" + pass + + @pytest.mark.unit + def test_static_connection_through_clusters(self): + """Tests that connections can be made through nested clusters.""" + pass + + @pytest.mark.unit + def test_static_connection_through_conditional(self): + """Tests that connections through conditionals make sense.""" + pass + + +class TestTerminalMeasurementConnectivity: + """Test that terminal measurements connect properly.""" + + pass