From 25aa04b8491a955ac2e51a10ae01165a1a4372c7 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 2 Dec 2025 16:15:45 -0500 Subject: [PATCH 01/14] cl --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 33fc9a1ec6..13a22d730a 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -10,6 +10,7 @@ [(#2231)](https://github.com/PennyLaneAI/catalyst/pull/2231) [(#2234)](https://github.com/PennyLaneAI/catalyst/pull/2234) [(#2118)](https://github.com/PennyLaneAI/catalyst/pull/2218) + [(#)]() * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) From d613c1b98965959da221405430132ebe8c117eff Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Tue, 2 Dec 2025 16:16:37 -0500 Subject: [PATCH 02/14] Apply suggestion from @andrijapau --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 13a22d730a..4db1a582ff 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -10,7 +10,7 @@ [(#2231)](https://github.com/PennyLaneAI/catalyst/pull/2231) [(#2234)](https://github.com/PennyLaneAI/catalyst/pull/2234) [(#2118)](https://github.com/PennyLaneAI/catalyst/pull/2218) - [(#)]() + [(#2260)](https://github.com/PennyLaneAI/catalyst/pull/2260) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) From 858f40eade2c37b67ccaa6a912205ba83202ebe2 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 2 Dec 2025 17:50:11 -0500 Subject: [PATCH 03/14] add good static connectivity --- .../visualization/construct_circuit_dag.py | 74 +++++++++++++++++-- 1 file changed, 68 insertions(+), 6 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 1f27dae700..03341af694 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -14,17 +14,18 @@ """Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module.""" +from collections import defaultdict from functools import singledispatchmethod from xdsl.dialects import builtin, func, scf from xdsl.ir import Block, Operation, Region, SSAValue from catalyst.python_interface.dialects import catalyst, quantum -from catalyst.python_interface.visualization.dag_builder import DAGBuilder from catalyst.python_interface.inspection.xdsl_conversion import ( xdsl_to_qml_measurement, xdsl_to_qml_op, ) +from catalyst.python_interface.visualization.dag_builder import DAGBuilder class ConstructCircuitDAG: @@ -48,6 +49,11 @@ def __init__(self, dag_builder: DAGBuilder) -> None: # Keep track of nesting clusters using a stack self._cluster_uid_stack: list[str] = [] + # Create a map of wire to node uid + # Keys represent static (int) or dynamic wires (str) + # Values represent the set of all node uids that are on that wire. + self._wire_to_node_uid: dict[str | int, set[str]] = defaultdict(set) + def _reset(self) -> None: """Resets the instance.""" self._cluster_uid_stack: list[str] = [] @@ -94,14 +100,26 @@ def _unitary( ) -> None: """Generic handler for unitary gates.""" + # Create PennyLane instance qml_op = xdsl_to_qml_op(op) - # Build node on graph + + # Add node to current cluster + node_uid = f"node_{id(op)}" self.dag_builder.add_node( - uid=f"node_{id(op)}", + uid=node_uid, label=str(qml_op), cluster_uid=self._cluster_uid_stack[-1], ) + # Search through previous ops found on current wires and connect + prev_ops = set.union(*(self._wire_to_node_uid[wire] for wire in qml_op.wires)) + for prev_op in prev_ops: + self.dag_builder.add_edge(prev_op, node_uid) + + # Update affected wires to source from this node UID + for wire in qml_op.wires: + self._wire_to_node_uid[wire] = {node_uid} + # ===================== # QUANTUM MEASUREMENTS # ===================== @@ -111,13 +129,20 @@ def _state_op(self, op: quantum.StateOp) -> None: """Handler for the terminal state measurement operation.""" meas = xdsl_to_qml_measurement(op) + node_uid = f"node_{id(op)}" # Build node on graph self.dag_builder.add_node( - uid=f"node_{id(op)}", + uid=node_uid, label=str(meas), cluster_uid=self._cluster_uid_stack[-1], + fillcolor="lightpink", + color="lightpink3", ) + for seen_wire, seen_nodes in self._wire_to_node_uid.items(): + for seen_node in seen_nodes: + self.dag_builder.add_edge(seen_node, node_uid) + @_visit_operation.register def _statistical_measurement_ops( self, @@ -127,13 +152,20 @@ def _statistical_measurement_ops( obs_op = op.obs.owner meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op)) + node_uid = f"node_{id(op)}" # Build node on graph self.dag_builder.add_node( - uid=f"node_{id(op)}", + uid=node_uid, label=str(meas), cluster_uid=self._cluster_uid_stack[-1], + fillcolor="lightpink", + color="lightpink3", ) + for wire in meas.wires: + for seen_node in self._wire_to_node_uid[wire]: + self.dag_builder.add_edge(seen_node, node_uid) + @_visit_operation.register def _projective_measure_op(self, op: quantum.MeasureOp) -> None: """Handler for the single-qubit projective measurement operation.""" @@ -200,10 +232,14 @@ def _if_op(self, operation: scf.IfOp): ) self._cluster_uid_stack.append(uid) + # Save wires state before all of the branches + wire_map_before = self._wire_to_node_uid.copy() + region_wire_maps: list[dict[int | str, set[str]]] = [] + # Loop through each branch and visualize as a cluster num_regions = len(flattened_if_op) for i, (condition_ssa, region) in enumerate(flattened_if_op): - + # Visualize with a cluster def _get_conditional_branch_label(i): if i == 0: return "if ..." @@ -223,9 +259,16 @@ def _get_conditional_branch_label(i): ) self._cluster_uid_stack.append(uid) + # Make fresh wire map before going into region + self._wire_to_node_uid = wire_map_before.copy() + # Go recursively into the branch to process internals self._visit_region(region) + # Update branch wire maps + if self._wire_to_node_uid != wire_map_before: + region_wire_maps.append(self._wire_to_node_uid) + # Pop branch cluster after processing to ensure # logical branches are treated as 'parallel' self._cluster_uid_stack.pop() @@ -233,6 +276,25 @@ def _get_conditional_branch_label(i): # Pop IfOp cluster before leaving this handler self._cluster_uid_stack.pop() + # Check what wires were affected + affected_wires = set(wire_map_before.keys()) + for region_wire_map in region_wire_maps: + affected_wires.update(region_wire_map.keys()) + + # Update state to be the union of all branch wire maps + final_wire_map = defaultdict(set) + for wire in affected_wires: + all_nodes: set = set() + for region_wire_map in region_wire_maps: + if not wire in region_wire_map: + # Branch didn't touch this wire, so just use previous node + all_nodes.update(wire_map_before.get(wire, {})) + else: + all_nodes.update(region_wire_map.get(wire, {})) + + final_wire_map[wire] = all_nodes + self._wire_to_node_uid = final_wire_map + # ============ # DEVICE NODE # ============ From 3d0af6ab5553b26265e61044a64e2eb77220da85 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 2 Dec 2025 17:51:20 -0500 Subject: [PATCH 04/14] rename --- .../visualization/construct_circuit_dag.py | 20 +++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 03341af694..19639fcbd3 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -52,7 +52,7 @@ def __init__(self, dag_builder: DAGBuilder) -> None: # Create a map of wire to node uid # Keys represent static (int) or dynamic wires (str) # Values represent the set of all node uids that are on that wire. - self._wire_to_node_uid: dict[str | int, set[str]] = defaultdict(set) + self._wire_to_node_uids: dict[str | int, set[str]] = defaultdict(set) def _reset(self) -> None: """Resets the instance.""" @@ -112,13 +112,13 @@ def _unitary( ) # Search through previous ops found on current wires and connect - prev_ops = set.union(*(self._wire_to_node_uid[wire] for wire in qml_op.wires)) + prev_ops = set.union(*(self._wire_to_node_uids[wire] for wire in qml_op.wires)) for prev_op in prev_ops: self.dag_builder.add_edge(prev_op, node_uid) # Update affected wires to source from this node UID for wire in qml_op.wires: - self._wire_to_node_uid[wire] = {node_uid} + self._wire_to_node_uids[wire] = {node_uid} # ===================== # QUANTUM MEASUREMENTS @@ -139,7 +139,7 @@ def _state_op(self, op: quantum.StateOp) -> None: color="lightpink3", ) - for seen_wire, seen_nodes in self._wire_to_node_uid.items(): + for seen_wire, seen_nodes in self._wire_to_node_uids.items(): for seen_node in seen_nodes: self.dag_builder.add_edge(seen_node, node_uid) @@ -163,7 +163,7 @@ def _statistical_measurement_ops( ) for wire in meas.wires: - for seen_node in self._wire_to_node_uid[wire]: + for seen_node in self._wire_to_node_uids[wire]: self.dag_builder.add_edge(seen_node, node_uid) @_visit_operation.register @@ -233,7 +233,7 @@ def _if_op(self, operation: scf.IfOp): self._cluster_uid_stack.append(uid) # Save wires state before all of the branches - wire_map_before = self._wire_to_node_uid.copy() + wire_map_before = self._wire_to_node_uids.copy() region_wire_maps: list[dict[int | str, set[str]]] = [] # Loop through each branch and visualize as a cluster @@ -260,14 +260,14 @@ def _get_conditional_branch_label(i): self._cluster_uid_stack.append(uid) # Make fresh wire map before going into region - self._wire_to_node_uid = wire_map_before.copy() + self._wire_to_node_uids = wire_map_before.copy() # Go recursively into the branch to process internals self._visit_region(region) # Update branch wire maps - if self._wire_to_node_uid != wire_map_before: - region_wire_maps.append(self._wire_to_node_uid) + if self._wire_to_node_uids != wire_map_before: + region_wire_maps.append(self._wire_to_node_uids) # Pop branch cluster after processing to ensure # logical branches are treated as 'parallel' @@ -293,7 +293,7 @@ def _get_conditional_branch_label(i): all_nodes.update(region_wire_map.get(wire, {})) final_wire_map[wire] = all_nodes - self._wire_to_node_uid = final_wire_map + self._wire_to_node_uids = final_wire_map # ============ # DEVICE NODE From 08082d900abcd1cd2504f8b21cb258d0e2d47a2e Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 2 Dec 2025 17:57:11 -0500 Subject: [PATCH 05/14] clean-up --- .../python_interface/visualization/construct_circuit_dag.py | 1 + 1 file changed, 1 insertion(+) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index cc7db9c90c..8a1b7fd184 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -246,6 +246,7 @@ def _if_op(self, operation: scf.IfOp): # Loop through each branch and visualize as a cluster num_regions = len(flattened_if_op) for i, (condition_ssa, region) in enumerate(flattened_if_op): + # Visualize with a cluster def _get_conditional_branch_label(i): if i == 0: From f23b5503441e0033f04545e630d527761b6a0b35 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 2 Dec 2025 18:07:09 -0500 Subject: [PATCH 06/14] refactor --- .../visualization/construct_circuit_dag.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 8a1b7fd184..b4a64e65e6 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -284,8 +284,8 @@ def _get_conditional_branch_label(i): # Pop IfOp cluster before leaving this handler self._cluster_uid_stack.pop() - # Check what wires were affected - affected_wires = set(wire_map_before.keys()) + # Check what wires were affected + affected_wires: set[str | int] = set(wire_map_before.keys()) for region_wire_map in region_wire_maps: affected_wires.update(region_wire_map.keys()) @@ -295,11 +295,11 @@ def _get_conditional_branch_label(i): all_nodes: set = set() for region_wire_map in region_wire_maps: if not wire in region_wire_map: - # Branch didn't touch this wire, so just use previous node + # IfOp region didn't apply anything on this wire + # so default to node before the IfOp all_nodes.update(wire_map_before.get(wire, {})) else: all_nodes.update(region_wire_map.get(wire, {})) - final_wire_map[wire] = all_nodes self._wire_to_node_uids = final_wire_map From 437c7d8b3e343aab295753ccc656f9171847933e Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 3 Dec 2025 14:05:23 -0500 Subject: [PATCH 07/14] make sure data is not carried over between qnodes --- .../python_interface/visualization/construct_circuit_dag.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index b4a64e65e6..a6553eea68 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -369,6 +369,8 @@ def _func_return(self, operation: func.ReturnOp) -> None: # the FuncOp's scope and so we can pop the ID off the stack. self._cluster_uid_stack.pop() + # Clear seen wires as we are exiting a FuncOp (qnode) + self._wire_to_node_uids = defaultdict(set) def _flatten_if_op(op: scf.IfOp) -> list[tuple[SSAValue | None, Region]]: """Recursively flattens a nested IfOp (if/elif/else chains).""" From 39d6ea6803d843df91f61f54ff70676297d2d9d6 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 3 Dec 2025 14:05:33 -0500 Subject: [PATCH 08/14] format --- .../python_interface/visualization/construct_circuit_dag.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index a6553eea68..659ab18dbf 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -284,7 +284,7 @@ def _get_conditional_branch_label(i): # Pop IfOp cluster before leaving this handler self._cluster_uid_stack.pop() - # Check what wires were affected + # Check what wires were affected affected_wires: set[str | int] = set(wire_map_before.keys()) for region_wire_map in region_wire_maps: affected_wires.update(region_wire_map.keys()) @@ -370,7 +370,8 @@ def _func_return(self, operation: func.ReturnOp) -> None: self._cluster_uid_stack.pop() # Clear seen wires as we are exiting a FuncOp (qnode) - self._wire_to_node_uids = defaultdict(set) + self._wire_to_node_uids = defaultdict(set) + def _flatten_if_op(op: scf.IfOp) -> list[tuple[SSAValue | None, Region]]: """Recursively flattens a nested IfOp (if/elif/else chains).""" From f25dd2ade71c5bb5dc32c594390b469dfbf07389 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Thu, 4 Dec 2025 10:43:52 -0500 Subject: [PATCH 09/14] Apply suggestion from @andrijapau --- .../python_interface/visualization/construct_circuit_dag.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 103a8b0745..ced3c69a6b 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -302,9 +302,9 @@ def _get_conditional_branch_label(i): if not wire in region_wire_map: # IfOp region didn't apply anything on this wire # so default to node before the IfOp - all_nodes.update(wire_map_before.get(wire, {})) + all_nodes.update(wire_map_before.get(wire, set())) else: - all_nodes.update(region_wire_map.get(wire, {})) + all_nodes.update(region_wire_map.get(wire, set())) final_wire_map[wire] = all_nodes self._wire_to_node_uids = final_wire_map From 56c70c9416addfa3850cbb76cfe17d6c60c8fb75 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 4 Dec 2025 12:30:19 -0500 Subject: [PATCH 10/14] update logic --- .../python_interface/visualization/construct_circuit_dag.py | 1 + 1 file changed, 1 insertion(+) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 30a90592d5..353b8abb55 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -63,6 +63,7 @@ def _reset(self) -> None: self._cluster_uid_stack: list[str] = [] self._node_uid_counter: int = 0 self._cluster_uid_counter: int = 0 + self._wire_to_node_uids: dict[str | int, set[str]] = defaultdict(set) def construct(self, module: builtin.ModuleOp) -> None: """Constructs the DAG from the module. From aea60d050ba76fa198018fdc6c6c5ae5cd1acf7c Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 4 Dec 2025 17:02:05 -0500 Subject: [PATCH 11/14] add test skeeltons --- .../test_construct_circuit_dag.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 5438f9604d..a2feca9fbb 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -767,3 +767,28 @@ def my_circuit(): assert len(nodes) == 2 # Device node + operator assert "MidMeasure" in nodes["node1"]["label"] + +class TestOperatorConnectivity: + """Tests that operators are properly connected.""" + + @pytest.mark.unit + def test_static_connection_within_cluster(self): + """Tests that connections can be made within the same cluster.""" + pass + + @pytest.mark.unit + def test_static_connection_through_clusters(self): + """Tests that connections can be made through nested clusters.""" + pass + + @pytest.mark.unit + def test_static_connection_through_conditional(self): + """Tests that connections through conditionals make sense.""" + pass + + +class TestTerminalMeasurementConnectivity: + """Test that terminal measurements connect properly.""" + + pass + From b58602d8fed487e1bea77b2fd516d13554f3540d Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 16:43:13 -0500 Subject: [PATCH 12/14] format --- .../python_interface/visualization/construct_circuit_dag.py | 2 +- .../visualization/test_construct_circuit_dag.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index a22b1bb2a2..73160dd246 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -55,7 +55,7 @@ def __init__(self, dag_builder: DAGBuilder) -> None: # Keys represent static (int) or dynamic wires (str) # Values represent the set of all node uids that are on that wire. self._wire_to_node_uids: dict[str | int, set[str]] = defaultdict(set) - + # Use counter internally for UID self._node_uid_counter: int = 0 self._cluster_uid_counter: int = 0 diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 63ce502184..526984ddbc 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -831,6 +831,7 @@ def my_circuit(): assert "MidMeasure" in nodes["node1"]["label"] + class TestOperatorConnectivity: """Tests that operators are properly connected.""" @@ -854,4 +855,3 @@ class TestTerminalMeasurementConnectivity: """Test that terminal measurements connect properly.""" pass - From 68c2488cb6e6b4af4c40398b72b86ef0f7148f72 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 17:16:03 -0500 Subject: [PATCH 13/14] whoops --- .../visualization/construct_circuit_dag.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 73160dd246..28e15aef78 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -156,7 +156,7 @@ def _state_op(self, op: quantum.StateOp) -> None: for seen_wire, seen_nodes in self._wire_to_node_uids.items(): for seen_node in seen_nodes: - self.dag_builder.add_edge(seen_node, node_uid) + self.dag_builder.add_edge(seen_node, node_uid, color="lightpink3") @_visit_operation.register def _statistical_measurement_ops( @@ -182,7 +182,7 @@ def _statistical_measurement_ops( for wire in meas.wires: for seen_node in self._wire_to_node_uids[wire]: - self.dag_builder.add_edge(seen_node, node_uid) + self.dag_builder.add_edge(seen_node, node_uid, color="lightpink3") @_visit_operation.register def _visit_sample_and_probs_ops( @@ -210,6 +210,10 @@ def _visit_sample_and_probs_ops( ) self._node_uid_counter += 1 + for wire in meas.wires: + for seen_node in self._wire_to_node_uids[wire]: + self.dag_builder.add_edge(seen_node, node_uid, color="lightpink3") + @_visit_operation.register def _projective_measure_op(self, op: quantum.MeasureOp) -> None: """Handler for the single-qubit projective measurement operation.""" From fc3ea2cbf5c8a334be7761b0799dec376acb9ea9 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 18:11:42 -0500 Subject: [PATCH 14/14] add connectivity to MCMs --- .../visualization/construct_circuit_dag.py | 15 ++++++++++++--- .../visualization/test_construct_circuit_dag.py | 3 ++- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 22e5bf0d78..6b97fe7e0f 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -131,25 +131,34 @@ def _unitary( # Update affected wires to source from this node UID for wire in qml_op.wires: self._wire_to_node_uids[wire] = {node_uid} - + @_visit_operation.register def _projective_measure_op(self, op: quantum.MeasureOp) -> None: """Handler for the single-qubit projective measurement operation.""" # Create PennyLane instance - meas = xdsl_to_qml_measurement(op) + qml_op = xdsl_to_qml_measurement(op) # Add node to current cluster node_uid = f"node{self._node_uid_counter}" self.dag_builder.add_node( uid=node_uid, - label=get_label(meas), + label=get_label(qml_op), cluster_uid=self._cluster_uid_stack[-1], # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) shape="record", ) self._node_uid_counter += 1 + # Search through previous ops found on current wires and connect + prev_ops = set.union(*(self._wire_to_node_uids[wire] for wire in qml_op.wires)) + for prev_op in prev_ops: + self.dag_builder.add_edge(prev_op, node_uid) + + # Update affected wires to source from this node UID + for wire in qml_op.wires: + self._wire_to_node_uids[wire] = {node_uid} + # ===================== # QUANTUM MEASUREMENTS # ===================== diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 4c2763fc3d..69a74c4d9e 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -831,6 +831,7 @@ def my_circuit(): assert nodes["node1"]["label"] == get_label(op) + class TestOperatorConnectivity: """Tests that operators are properly connected.""" @@ -853,4 +854,4 @@ def test_static_connection_through_conditional(self): class TestTerminalMeasurementConnectivity: """Test that terminal measurements connect properly.""" - pass \ No newline at end of file + pass