From 265cd4d56ce067832966cd77fe807cd658ce6da1 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 19 Nov 2025 13:28:34 -0500 Subject: [PATCH 001/111] copy files over --- .../visualization/dag_builder.py | 100 ++++++ .../visualization/pydot_dag_builder.py | 207 ++++++++++++ .../visualization/test_dag_builder.py | 94 ++++++ .../visualization/test_pydot_dag_builder.py | 300 ++++++++++++++++++ 4 files changed, 701 insertions(+) create mode 100644 frontend/catalyst/python_interface/visualization/dag_builder.py create mode 100644 frontend/catalyst/python_interface/visualization/pydot_dag_builder.py create mode 100644 frontend/test/pytest/python_interface/visualization/test_dag_builder.py create mode 100644 frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py new file mode 100644 index 0000000000..7a60a55f1e --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -0,0 +1,100 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""File that defines the DAGBuilder abstract base class.""" + +from abc import ABC, abstractmethod +from typing import Any + + +class DAGBuilder(ABC): + """An abstract base class for building Directed Acyclic Graphs (DAGs). + + This class provides a simple interface with three core methods (`add_node`, `add_edge` and `add_cluster`). + You can override these methods to implement any backend, like `pydot` or `graphviz` or even `matplotlib`. + + Outputting your graph can be done by overriding `to_file` and `to_string`. + """ + + @abstractmethod + def add_node( + self, node_id: str, node_label: str, parent_graph_id: str | None = None, **node_attrs: Any + ) -> None: + """Add a single node to the graph. + + Args: + node_id (str): Unique node ID to identify this node. + node_label (str): The text to display on the node when rendered. + parent_graph_id (str | None): Optional ID of the cluster this node belongs to. + **node_attrs (Any): Any additional styling keyword arguments. + + """ + raise NotImplementedError + + @abstractmethod + def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> None: + """Add a single directed edge between nodes in the graph. + + Args: + from_node_id (str): The unique ID of the source node. + to_node_id (str): The unique ID of the destination node. + **edge_attrs (Any): Any additional styling keyword arguments. + + """ + raise NotImplementedError + + @abstractmethod + def add_cluster( + self, + cluster_id: str, + cluster_label: str, + parent_graph_id: str | None = None, + **cluster_attrs: Any, + ) -> None: + """Add a single cluster to the graph. + + A cluster is a specific type of subgraph where the nodes and edges contained + within it are visually and logically grouped. + + Args: + cluster_id (str): Unique cluster ID to identify this cluster. + cluster_label (str): The text to display on the cluster when rendered. + parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. + **cluster_attrs (Any): Any additional styling keyword arguments. + + """ + raise NotImplementedError + + @abstractmethod + def to_file(self, output_filename: str) -> None: + """Save the graph to a file. + + The implementation should ideally infer the output format + (e.g., 'png', 'svg') from this filename's extension. + + Args: + output_filename (str): Desired filename for the graph. + + """ + raise NotImplementedError + + @abstractmethod + def to_string(self) -> str: + """Return the graph as a string. + + This is typically used to get the graph's representation in a standard string format like DOT. + + Returns: + str: A string representation of the graph. + """ + raise NotImplementedError diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py new file mode 100644 index 0000000000..ad9eef4bcb --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -0,0 +1,207 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""File that defines the PyDotDAGBuilder subclass of DAGBuilder.""" + +import pathlib +from collections import ChainMap +from typing import Any + +from .dag_builder import DAGBuilder + +has_pydot = True +try: + import pydot +except ImportError: + has_pydot = False + + +class PyDotDAGBuilder(DAGBuilder): + """A Directed Acyclic Graph builder for the PyDot backend.""" + + def __init__( + self, + attrs: dict | None = None, + node_attrs: dict | None = None, + edge_attrs: dict | None = None, + cluster_attrs: dict | None = None, + ) -> None: + """Initialize PyDotDAGBuilder instance. + + Args: + attrs (dict | None): User default attributes to be used for all elements (nodes, edges, clusters) in the graph. + node_attrs (dict | None): User default attributes for a node. + edge_attrs (dict | None): User default attributes for an edge. + cluster_attrs (dict | None): User default attributes for a cluster. + + """ + # Initialize the pydot graph: + # - graph_type="digraph": Create a directed graph (edges have arrows). + # - rankdir="TB": Set layout direction from Top to Bottom. + # - compound="true": Allow edges to connect directly to clusters/subgraphs. + # - strict=True": Prevent duplicate edges (e.g., A -> B added twice). + self.graph: pydot.Dot = pydot.Dot( + graph_type="digraph", rankdir="TB", compound="true", strict=True + ) + # Create cache for easy look-up + self._subgraphs: dict[str, pydot.Graph] = {} + self._subgraphs["__base__"] = self.graph + + _default_attrs: dict = {"fontname": "Helvetica", "penwidth": 2} if attrs is None else attrs + self._default_node_attrs: dict = ( + { + **_default_attrs, + "shape": "ellipse", + "style": "filled", + "fillcolor": "lightblue", + "color": "lightblue4", + "penwidth": 3, + } + if node_attrs is None + else node_attrs + ) + self._default_edge_attrs: dict = ( + { + "color": "lightblue4", + "penwidth": 3, + } + if edge_attrs is None + else edge_attrs + ) + self._default_cluster_attrs: dict = ( + { + **_default_attrs, + "shape": "rectangle", + "style": "solid", + } + if cluster_attrs is None + else cluster_attrs + ) + + def add_node( + self, + node_id: str, + node_label: str, + parent_graph_id: str | None = None, + **node_attrs: Any, + ) -> None: + """Add a single node to the graph. + + Args: + node_id (str): Unique node ID to identify this node. + node_label (str): The text to display on the node when rendered. + parent_graph_id (str | None): Optional ID of the cluster this node belongs to. + **node_attrs (Any): Any additional styling keyword arguments. + + """ + # Use ChainMap so you don't need to construct a new dictionary + node_attrs = ChainMap(node_attrs, self._default_node_attrs) + node = pydot.Node(node_id, label=node_label, **node_attrs) + parent_graph_id = "__base__" if parent_graph_id is None else parent_graph_id + + self._subgraphs[parent_graph_id].add_node(node) + + def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> None: + """Add a single directed edge between nodes in the graph. + + Args: + from_node_id (str): The unique ID of the source node. + to_node_id (str): The unique ID of the destination node. + **edge_attrs (Any): Any additional styling keyword arguments. + + """ + # Use ChainMap so you don't need to construct a new dictionary + edge_attrs = ChainMap(edge_attrs, self._default_edge_attrs) + edge = pydot.Edge(from_node_id, to_node_id, **edge_attrs) + self.graph.add_edge(edge) + + def add_cluster( + self, + cluster_id: str, + cluster_label: str, + parent_graph_id: str | None = None, + **cluster_attrs: Any, + ) -> None: + """Add a single cluster to the graph. + + A cluster is a specific type of subgraph where the nodes and edges contained + within it are visually and logically grouped. + + Args: + cluster_id (str): Unique cluster ID to identify this cluster. + cluster_label (str): The text to display on the cluster when rendered. + parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. + **cluster_attrs (Any): Any additional styling keyword arguments. + + """ + # Use ChainMap so you don't need to construct a new dictionary + cluster_attrs = ChainMap(cluster_attrs, self._default_cluster_attrs) + cluster = pydot.Cluster(graph_name=cluster_id, **cluster_attrs) + + # Puts the label in a node within the cluster. + # Ensures that any edges connecting nodes through the cluster + # boundary don't block the label. + # ┌───────────┐ + # │ ┌───────┐ │ + # │ │ label │ │ + # │ └───────┘ │ + # │ │ + # └───────────┘ + if cluster_label: + node_id = f"{cluster_id}_info_node" + rank_subgraph = pydot.Subgraph() + node = pydot.Node( + node_id, + label=cluster_label, + shape="rectangle", + style="dashed", + fontname="Helvetica", + penwidth=2, + ) + rank_subgraph.add_node(node) + cluster.add_subgraph(rank_subgraph) + cluster.add_node(node) + + self._subgraphs[cluster_id] = cluster + + parent_graph_id = "__base__" if parent_graph_id is None else parent_graph_id + self._subgraphs[parent_graph_id].add_subgraph(cluster) + + def to_file(self, output_filename: str) -> None: + """Save the graph to a file. + + This method will infer the file's format (e.g., 'png', 'svg') from this filename's extension. + If no extension is provided, the 'png' format will be the default. + + Args: + output_filename (str): Desired filename for the graph. File extension can be included + and if no file extension is provided, it will default to a `.png` file. + + """ + output_filename_path: pathlib.Path = pathlib.Path(output_filename) + if not output_filename_path.suffix: + output_filename_path = output_filename_path.with_suffix(".png") + + format = output_filename_path.suffix[1:].lower() + + self.graph.write(str(output_filename_path), format=format) + + def to_string(self) -> str: + """Return the graph as a string. + + This is typically used to get the graph's representation in a standard string format like DOT. + + Returns: + str: A string representation of the graph. + """ + return self.graph.to_string() diff --git a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py new file mode 100644 index 0000000000..3f70ae8159 --- /dev/null +++ b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py @@ -0,0 +1,94 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the DAGBuilder abstract base class.""" + +from typing import Any + +import pytest + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +# pylint: disable=wrong-import-position +# This import needs to be after pytest in order to prevent ImportErrors +from catalyst.python_interface.visualization.dag_builder import DAGBuilder + + +def test_concrete_implementation_works(): + """Unit test for concrete implementation of abc.""" + + # pylint: disable=unused-argument + class ConcreteDAGBuilder(DAGBuilder): + """Concrete subclass of an ABC for testing purposes.""" + + def add_node( + self, + node_id: str, + node_label: str, + parent_graph_id: str | None = None, + **node_attrs: Any, + ) -> None: + return + + def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> None: + return + + def add_cluster( + self, + cluster_id: str, + cluster_label: str, + parent_graph_id: str | None = None, + **cluster_attrs: Any, + ) -> None: + return + + def to_file(self, output_filename: str) -> None: + return + + def to_string(self) -> str: + return "test" + + dag_builder = ConcreteDAGBuilder() + # pylint: disable = assignment-from-none + node = dag_builder.add_node("0", "node0") + edge = dag_builder.add_edge("0", "1") + cluster = dag_builder.add_cluster("0", "cluster0") + render = dag_builder.to_file("test.png") + string = dag_builder.to_string() + + assert node is None + assert edge is None + assert cluster is None + assert render is None + assert string == "test" + + +def test_abc_cannot_be_instantiated(): + """Tests that the DAGBuilder ABC cannot be instantiated.""" + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + # pylint: disable=abstract-class-instantiated + DAGBuilder() + + +def test_incomplete_subclass(): + """Tests that an incomplete subclass will fail""" + + # pylint: disable=too-few-public-methods + class IncompleteDAGBuilder(DAGBuilder): + def add_node(self, *args, **kwargs): + pass + + with pytest.raises(TypeError, match="Can't instantiate abstract class"): + # pylint: disable=abstract-class-instantiated + IncompleteDAGBuilder() diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py new file mode 100644 index 0000000000..5c57975a12 --- /dev/null +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -0,0 +1,300 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the PyDotDAGBuilder subclass.""" + +from unittest.mock import MagicMock + +import pytest + +pydot = pytest.importorskip("pydot") +pytestmark = pytest.mark.usefixtures("requires_xdsl") +# pylint: disable=wrong-import-position +from catalyst.python_interface.visualization.pydot_dag_builder import PyDotDAGBuilder + + +@pytest.mark.unit +def test_initialization_defaults(): + """Tests the default graph attributes are as expected.""" + + dag_builder = PyDotDAGBuilder() + + assert isinstance(dag_builder.graph, pydot.Dot) + # Ensure it's a directed graph + assert dag_builder.graph.get_graph_type() == "digraph" + # Ensure that it flows top to bottom + assert dag_builder.graph.get_rankdir() == "TB" + # Ensure edges can be connected directly to clusters / subgraphs + assert dag_builder.graph.get_compound() == "true" + # Ensure duplicated edges cannot be added + assert dag_builder.graph.obj_dict["strict"] is True + + +class TestAddMethods: + """Test that elements can be added to the graph.""" + + @pytest.mark.unit + def test_add_node(self): + """Unit test the `add_node` method.""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0") + node_list = dag_builder.graph.get_node_list() + assert len(node_list) == 1 + assert node_list[0].get_label() == "node0" + + @pytest.mark.unit + def test_add_edge(self): + """Unit test the `add_edge` method.""" + + dag_builder = PyDotDAGBuilder() + dag_builder.add_node("0", "node0") + dag_builder.add_node("1", "node1") + dag_builder.add_edge("0", "1") + + assert len(dag_builder.graph.get_edges()) == 1 + edge = dag_builder.graph.get_edges()[0] + assert edge.get_source() == "0" + assert edge.get_destination() == "1" + + @pytest.mark.unit + def test_add_cluster(self): + """Unit test the 'add_cluster' method.""" + + dag_builder = PyDotDAGBuilder() + dag_builder.add_cluster("0", "my_cluster") + + assert len(dag_builder.graph.get_subgraphs()) == 1 + assert dag_builder.graph.get_subgraphs()[0].get_name() == "cluster_0" + + @pytest.mark.unit + def test_add_node_to_parent_graph(self): + """Tests that you can add a node to a parent graph.""" + dag_builder = PyDotDAGBuilder() + + # Create node + dag_builder.add_node("0", "node0") + + # Create cluster + dag_builder.add_cluster("c0", "cluster0") + + # Create node inside cluster + dag_builder.add_node("1", "node1", parent_graph_id="c0") + + # Verify graph structure + root_graph = dag_builder.graph + + # Make sure the base graph has node0 + assert root_graph.get_node("0"), "Node 0 not found in root graph" + + # Get the cluster and verify it has node1 and not node0 + cluster_list = root_graph.get_subgraph("cluster_c0") + assert cluster_list, "Subgraph 'cluster_c0' not found" + cluster_graph = cluster_list[0] # Get the actual subgraph object + + assert cluster_graph.get_node("1"), "Node 1 not found in cluster 'c0'" + assert not cluster_graph.get_node("0"), ( + "Node 0 was incorrectly added to cluster" + ) + + assert not root_graph.get_node("1"), "Node 1 was incorrectly added to root" + + @pytest.mark.unit + def test_add_cluster_to_parent_graph(self): + """Test that you can add a cluster to a parent graph.""" + dag_builder = PyDotDAGBuilder() + + # Level 0 (Root): Adds cluster on top of base graph + dag_builder.add_node("n_root", "node_root") + dag_builder.add_cluster("c0", "cluster_outer") + + # Level 1 (Inside c0): Add node on outer cluster and create new cluster on top + dag_builder.add_node("n_outer", "node_outer", parent_graph_id="c0") + dag_builder.add_cluster("c1", "cluster_inner", parent_graph_id="c0") + + # Level 2 (Inside c1): Add node on second cluster + dag_builder.add_node("n_inner", "node_inner", parent_graph_id="c1") + + root_graph = dag_builder.graph + + outer_cluster_list = root_graph.get_subgraph("cluster_c0") + assert outer_cluster_list, "Outer cluster 'c0' not found in root" + c0 = outer_cluster_list[0] + + inner_cluster_list = c0.get_subgraph("cluster_c1") + assert inner_cluster_list, "Inner cluster 'c1' not found in 'c0'" + c1 = inner_cluster_list[0] + + # Check Level 0 (Root) + assert root_graph.get_node("n_root"), "n_root not found in root" + assert root_graph.get_subgraph("cluster_c0"), "c0 not found in root" + assert not root_graph.get_node("n_outer"), "n_outer incorrectly found in root" + assert not root_graph.get_node("n_inner"), "n_inner incorrectly found in root" + assert not root_graph.get_subgraph("cluster_c1"), "c1 incorrectly found in root" + + # Check Level 1 (c0) + assert c0.get_node("n_outer"), "n_outer not found in c0" + assert c0.get_subgraph("cluster_c1"), "c1 not found in c0" + assert not c0.get_node("n_root"), "n_root incorrectly found in c0" + assert not c0.get_node("n_inner"), "n_inner incorrectly found in c0" + + # Check Level 2 (c1) + assert c1.get_node("n_inner"), "n_inner not found in c1" + assert not c1.get_node("n_root"), "n_root incorrectly found in c1" + assert not c1.get_node("n_outer"), "n_outer incorrectly found in c1" + + +class TestAttributes: + """Tests that the attributes for elements in the graph are overridden correctly.""" + + @pytest.mark.unit + def test_default_graph_attrs(self): + """Test that default graph attributes can be set.""" + + dag_builder = PyDotDAGBuilder(attrs={"fontname": "Times"}) + + dag_builder.add_node("0", "node0") + node0 = dag_builder.graph.get_node("0")[0] + assert node0.get("fontname") == "Times" + + dag_builder.add_cluster("1", "cluster0") + cluster = dag_builder.graph.get_subgraphs()[0] + assert cluster.get("fontname") == "Times" + + @pytest.mark.unit + def test_add_node_with_attrs(self): + """Tests that default attributes are applied and can be overridden.""" + dag_builder = PyDotDAGBuilder( + node_attrs={"fillcolor": "lightblue", "penwidth": 3} + ) + + # Defaults + dag_builder.add_node("0", "node0") + node0 = dag_builder.graph.get_node("0")[0] + assert node0.get("fillcolor") == "lightblue" + assert node0.get("penwidth") == 3 + + # Make sure we can override + dag_builder.add_node("1", "node1", fillcolor="red", penwidth=4) + node1 = dag_builder.graph.get_node("1")[0] + assert node1.get("fillcolor") == "red" + assert node1.get("penwidth") == 4 + + @pytest.mark.unit + def test_add_edge_with_attrs(self): + """Tests that default attributes are applied and can be overridden.""" + dag_builder = PyDotDAGBuilder(edge_attrs={"color": "lightblue4", "penwidth": 3}) + + dag_builder.add_node("0", "node0") + dag_builder.add_node("1", "node1") + dag_builder.add_edge("0", "1") + edge = dag_builder.graph.get_edges()[0] + # Defaults defined earlier + assert edge.get("color") == "lightblue4" + assert edge.get("penwidth") == 3 + + # Make sure we can override + dag_builder.add_edge("0", "1", color="red", penwidth=4) + edge = dag_builder.graph.get_edges()[1] + assert edge.get("color") == "red" + assert edge.get("penwidth") == 4 + + @pytest.mark.unit + def test_add_cluster_with_attrs(self): + """Tests that default cluster attributes are applied and can be overridden.""" + dag_builder = PyDotDAGBuilder( + cluster_attrs={ + "style": "solid", + "fillcolor": None, + "penwidth": 2, + "fontname": "Helvetica", + } + ) + + dag_builder.add_cluster("0", "cluster0") + cluster1 = dag_builder.graph.get_subgraph("cluster_0")[0] + + # Defaults + assert cluster1.get("style") == "solid" + assert cluster1.get("fillcolor") is None + assert cluster1.get("penwidth") == 2 + assert cluster1.get("fontname") == "Helvetica" + + dag_builder.add_cluster( + "1", "cluster1", style="filled", penwidth=10, fillcolor="red" + ) + cluster2 = dag_builder.graph.get_subgraph("cluster_1")[0] + + # Make sure we can override + assert cluster2.get("style") == "filled" + assert cluster2.get("penwidth") == 10 + assert cluster2.get("fillcolor") == "red" + + # Check that other defaults are still present + assert cluster2.get("fontname") == "Helvetica" + + +class TestOutput: + """Test that the graph can be outputted correctly.""" + + @pytest.mark.unit + @pytest.mark.parametrize( + "filename, format", + [("my_graph", None), ("my_graph", "png"), ("prototype.trial1", "png")], + ) + def test_to_file(self, monkeypatch, filename, format): + """Tests that the `to_file` method works correctly.""" + dag_builder = PyDotDAGBuilder() + + # mock out the graph writing functionality + mock_write = MagicMock() + monkeypatch.setattr(dag_builder.graph, "write", mock_write) + dag_builder.to_file(filename + "." + (format or "png")) + + # make sure the function handles extensions correctly + mock_write.assert_called_once_with( + filename + "." + (format or "png"), format=format or "png" + ) + + @pytest.mark.unit + @pytest.mark.parametrize("format", ["pdf", "svg", "jpeg"]) + def test_other_supported_formats(self, monkeypatch, format): + """Tests that the `to_file` method works with other formats.""" + dag_builder = PyDotDAGBuilder() + + # mock out the graph writing functionality + mock_write = MagicMock() + monkeypatch.setattr(dag_builder.graph, "write", mock_write) + dag_builder.to_file(f"my_graph.{format}") + + # make sure the function handles extensions correctly + mock_write.assert_called_once_with(f"my_graph.{format}", format=format) + + @pytest.mark.unit + def test_to_string(self): + """Tests that the `to_string` method works correclty.""" + + dag_builder = PyDotDAGBuilder() + dag_builder.add_node("n0", "node0") + dag_builder.add_node("n1", "node1") + dag_builder.add_edge("n0", "n1") + + string = dag_builder.to_string() + assert isinstance(string, str) + + # make sure important things show up in the string + assert "digraph" in string + assert "n0" in string + assert "n1" in string + assert "n0 -> n1" in string From e9ceaa7764b6f26befcea6e2340647e76ac077df Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 19 Nov 2025 13:33:20 -0500 Subject: [PATCH 002/111] add pydot to requirements.txt --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 08f57ee9c0..6b904318c6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -36,3 +36,4 @@ pennylane-lightning-kokkos amazon-braket-pennylane-plugin>1.27.1 xdsl xdsl-jax +pydot From ae50ca166c47fa9e3343c665da8144593c375734 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 19 Nov 2025 15:40:07 -0500 Subject: [PATCH 003/111] add files --- .../python_interface/visualization/mlir_dag_analysis_pass.py | 0 .../python_interface/visualization/test_mlir_dag_analysis_pass.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 frontend/catalyst/python_interface/visualization/mlir_dag_analysis_pass.py create mode 100644 frontend/test/pytest/python_interface/visualization/test_mlir_dag_analysis_pass.py diff --git a/frontend/catalyst/python_interface/visualization/mlir_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/mlir_dag_analysis_pass.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/frontend/test/pytest/python_interface/visualization/test_mlir_dag_analysis_pass.py b/frontend/test/pytest/python_interface/visualization/test_mlir_dag_analysis_pass.py new file mode 100644 index 0000000000..e69de29bb2 From a0e37b054ed3da037f049f41fdb8f9e436251aed Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 19 Nov 2025 15:41:57 -0500 Subject: [PATCH 004/111] rename --- .../{mlir_dag_analysis_pass.py => circuit_dag_analysis_pass.py} | 0 ...lir_dag_analysis_pass.py => test_circuit_dag_analysis_pass.py} | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename frontend/catalyst/python_interface/visualization/{mlir_dag_analysis_pass.py => circuit_dag_analysis_pass.py} (100%) rename frontend/test/pytest/python_interface/visualization/{test_mlir_dag_analysis_pass.py => test_circuit_dag_analysis_pass.py} (100%) diff --git a/frontend/catalyst/python_interface/visualization/mlir_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py similarity index 100% rename from frontend/catalyst/python_interface/visualization/mlir_dag_analysis_pass.py rename to frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py diff --git a/frontend/test/pytest/python_interface/visualization/test_mlir_dag_analysis_pass.py b/frontend/test/pytest/python_interface/visualization/test_circuit_dag_analysis_pass.py similarity index 100% rename from frontend/test/pytest/python_interface/visualization/test_mlir_dag_analysis_pass.py rename to frontend/test/pytest/python_interface/visualization/test_circuit_dag_analysis_pass.py From 76b55653c656e1f4b6cef10714adc680079c51ad Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 19 Nov 2025 16:08:03 -0500 Subject: [PATCH 005/111] add base class --- .../circuit_dag_analysis_pass.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py index e69de29bb2..8baf52dd23 100644 --- a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py +++ b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py @@ -0,0 +1,61 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains the CircuitDAGAnalysisPass for generating a DAG from an xDSL module.""" + +from functools import singledispatchmethod +from typing import TYPE_CHECKING, Any + +import xdsl +from xdsl.dialects import builtin, func, scf +from xdsl.ir import Block, Region + +from catalyst.python_interface.dialects import quantum + +if TYPE_CHECKING: + from catalyst.python_interface.visualization.dag_builder import DAGBuilder + + +class CircuitDAGAnalysisPass: + def __init__(self, dag_builder: DAGBuilder) -> None: + """Initialize the analysis pass.""" + self.dag_builder: DAGBuilder = dag_builder + + @singledispatchmethod + def visit_op(self, op: Any) -> None: + """Default handler for unknown operation types. + + This method is dispatched based on the type of 'op'. + + Args: + op (Any): An xDSL operation. + """ + + @visit_op.register + def visit_region(self, region: Region) -> None: + """Visit an xDSL Region operation.""" + for block in region.blocks: + self.visit_block(block) + + @visit_op.register + def visit_block(self, block: Block) -> None: + """Visit an xDSL Block operation.""" + for op in block.ops: + self.visit_op(op) + + def run(self, module: builtin.ModuleOp) -> None: + """Apply the analysis pass on the module.""" + + for op in module.ops: + self.visit_op(op) From 659f480894841b3069bb21947a6c8acb517ab498 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 19 Nov 2025 16:26:36 -0500 Subject: [PATCH 006/111] quick changes --- .../visualization/circuit_dag_analysis_pass.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py index 8baf52dd23..d439a422ca 100644 --- a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py +++ b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py @@ -22,9 +22,7 @@ from xdsl.ir import Block, Region from catalyst.python_interface.dialects import quantum - -if TYPE_CHECKING: - from catalyst.python_interface.visualization.dag_builder import DAGBuilder +from catalyst.python_interface.visualization.dag_builder import DAGBuilder class CircuitDAGAnalysisPass: From 8699557dc2d25b37b98477947db5f4fef0cf911a Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 19 Nov 2025 16:42:33 -0500 Subject: [PATCH 007/111] add control flow support --- .../visualization/circuit_dag_analysis_pass.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py index d439a422ca..f5b63537cb 100644 --- a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py +++ b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py @@ -40,6 +40,23 @@ def visit_op(self, op: Any) -> None: op (Any): An xDSL operation. """ + # ╔═══════════════════════════════════════════════════════════════╗ + # ║ CONTROL FLOW HANDLERS: Specialized dispatch for xDSL control  ║ + # ║                         flow operations (scf dialect).        ║ + # ╚═══════════════════════════════════════════════════════════════╝ + + @visit_op.register + def _visit_for_op(self, op: scf.ForOp) -> None: + """Handle an xDSL ForOp operation.""" + + @visit_op.register + def _visit_while_op(self, op: scf.WhileOp) -> None: + """Handle an xDSL WhileOp operation.""" + + @visit_op.register + def _visit_if_op(self, op: scf.IfOp) -> None: + """Handle an xDSL WhileOp operation.""" + @visit_op.register def visit_region(self, region: Region) -> None: """Visit an xDSL Region operation.""" From db6edc9f0397f220bdbbd15016cf707749cfe692 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 19 Nov 2025 17:04:15 -0500 Subject: [PATCH 008/111] clean-up --- .../circuit_dag_analysis_pass.py | 111 +++++++++++++----- 1 file changed, 83 insertions(+), 28 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py index f5b63537cb..bbe8fc8294 100644 --- a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py +++ b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py @@ -26,51 +26,106 @@ class CircuitDAGAnalysisPass: - def __init__(self, dag_builder: DAGBuilder) -> None: - """Initialize the analysis pass.""" - self.dag_builder: DAGBuilder = dag_builder - - @singledispatchmethod - def visit_op(self, op: Any) -> None: - """Default handler for unknown operation types. + """A Pass that analyzes an xDSL module and constructs a Directed Acyclic Graph (DAG) + using an injected DAGBuilder instance. This is a non-mutating Analysis Pass.""" - This method is dispatched based on the type of 'op'. + def __init__(self, dag_builder: DAGBuilder) -> None: + """Initialize the analysis pass by injecting the DAG builder dependency. Args: - op (Any): An xDSL operation. + dag_builder (DAGBuilder): The concrete builder instance used for graph construction. """ + self.dag_builder: DAGBuilder = dag_builder - # ╔═══════════════════════════════════════════════════════════════╗ - # ║ CONTROL FLOW HANDLERS: Specialized dispatch for xDSL control  ║ - # ║                         flow operations (scf dialect).        ║ - # ╚═══════════════════════════════════════════════════════════════╝ + # ================================= + # 1. CORE DISPATCH AND ENTRY POINT + # ================================= - @visit_op.register - def _visit_for_op(self, op: scf.ForOp) -> None: - """Handle an xDSL ForOp operation.""" + @singledispatchmethod + def visit_op(self, op: Any) -> None: + """Central dispatch method (Visitor Pattern). Routes the operation 'op' + to the specialized handler registered for its type.""" + pass - @visit_op.register - def _visit_while_op(self, op: scf.WhileOp) -> None: - """Handle an xDSL WhileOp operation.""" + def run(self, module: builtin.ModuleOp) -> None: + """Applies the analysis pass on the module.""" + for op in module.ops: + self.visit_op(op) - @visit_op.register - def _visit_if_op(self, op: scf.IfOp) -> None: - """Handle an xDSL WhileOp operation.""" + # ======================= + # 2. HIERARCHY TRAVERSAL + # ======================= + # These methods navigate the recursive IR hierarchy (Op -> Region -> Block -> Op). @visit_op.register def visit_region(self, region: Region) -> None: - """Visit an xDSL Region operation.""" + """Visit an xDSL Region operation, delegating traversal to its Blocks.""" for block in region.blocks: self.visit_block(block) @visit_op.register def visit_block(self, block: Block) -> None: - """Visit an xDSL Block operation.""" + """Visit an xDSL Block operation, dispatching handling for each contained Operation.""" for op in block.ops: self.visit_op(op) - def run(self, module: builtin.ModuleOp) -> None: - """Apply the analysis pass on the module.""" + # ====================================== + # 3. QUANTUM GATE & STATE PREP HANDLERS + # ====================================== + # Handlers for operations that apply unitary transformations or set-up the quantum state. - for op in module.ops: - self.visit_op(op) + @visit_op.register + def _visit_unitary_and_state_prep( + self, + op: ( + quantum.CustomOp + | quantum.GlobalPhaseOp + | quantum.QubitUnitaryOp + | quantum.MultiRZOp + | quantum.SetStateOp + | quantum.SetBasisStateOp + ), + ) -> None: + """Generic handler for unitary gates and quantum state preparation operations.""" + pass + + # ============================================= + # 4. QUANTUM MEASUREMENT & OBSERVABLE HANDLERS + # ============================================= + + @visit_op.register + def _visit_terminal_state_op(self, op: quantum.StateOp) -> None: + """Handler for the terminal StateOp, which retrieves the final state vector.""" + pass + + @visit_op.register + def _visit_statistical_measurement_ops( + self, + op: quantum.ExpvalOp | quantum.VarianceOp | quantum.ProbsOp | quantum.SampleOp, + ) -> None: + """Handler for statistical measurement operations (e.g., Expval, Sample).""" + pass + + @visit_op.register + def _visit_projective_measure_op(self, op: quantum.MeasureOp) -> None: + """Handler for the single-qubit projective MeasureOp.""" + pass + + # ========================= + # 5. CONTROL FLOW HANDLERS + # ========================= + + @visit_op.register + def _visit_for_op(self, op: scf.ForOp) -> None: + """Handle an xDSL ForOp operation (Loop cluster creation).""" + pass + + @visit_op.register + def _visit_while_op(self, op: scf.WhileOp) -> None: + """Handle an xDSL WhileOp operation (Loop cluster creation).""" + pass + + @visit_op.register + def _visit_if_op(self, op: scf.IfOp) -> None: + """Handle an xDSL IfOp operation (Conditional cluster creation).""" + pass From 2be63e496425f9ab3d84cc4e5ed46a04492ebf6a Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 19 Nov 2025 19:20:20 -0500 Subject: [PATCH 009/111] Apply suggestion from @andrijapau --- .../python_interface/visualization/circuit_dag_analysis_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py index bbe8fc8294..bb93239dae 100644 --- a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py +++ b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py @@ -94,7 +94,7 @@ def _visit_unitary_and_state_prep( # ============================================= @visit_op.register - def _visit_terminal_state_op(self, op: quantum.StateOp) -> None: + def _visit_state_op(self, op: quantum.StateOp) -> None: """Handler for the terminal StateOp, which retrieves the final state vector.""" pass From 12fa2afd08ee2f0e735c263bb864943d668259ec Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 19 Nov 2025 19:20:33 -0500 Subject: [PATCH 010/111] Apply suggestion from @andrijapau --- .../python_interface/visualization/circuit_dag_analysis_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py index bb93239dae..6e80816d76 100644 --- a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py +++ b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py @@ -95,7 +95,7 @@ def _visit_unitary_and_state_prep( @visit_op.register def _visit_state_op(self, op: quantum.StateOp) -> None: - """Handler for the terminal StateOp, which retrieves the final state vector.""" + """Handler for the terminal StateOp.""" pass @visit_op.register From 2c933309a70ebd7be53060d42ff85660c156908a Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 19 Nov 2025 19:21:04 -0500 Subject: [PATCH 011/111] Apply suggestion from @andrijapau --- .../python_interface/visualization/circuit_dag_analysis_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py index 6e80816d76..891abf909b 100644 --- a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py +++ b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py @@ -103,7 +103,7 @@ def _visit_statistical_measurement_ops( self, op: quantum.ExpvalOp | quantum.VarianceOp | quantum.ProbsOp | quantum.SampleOp, ) -> None: - """Handler for statistical measurement operations (e.g., Expval, Sample).""" + """Handler for statistical measurement operations.""" pass @visit_op.register From e68c3f434b7e0bf2be6f2607fad93a0c4718ce01 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 19 Nov 2025 19:21:34 -0500 Subject: [PATCH 012/111] Apply suggestion from @andrijapau --- .../visualization/circuit_dag_analysis_pass.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py index 891abf909b..d4369ff617 100644 --- a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py +++ b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py @@ -117,15 +117,15 @@ def _visit_projective_measure_op(self, op: quantum.MeasureOp) -> None: @visit_op.register def _visit_for_op(self, op: scf.ForOp) -> None: - """Handle an xDSL ForOp operation (Loop cluster creation).""" + """Handle an xDSL ForOp operation.""" pass @visit_op.register def _visit_while_op(self, op: scf.WhileOp) -> None: - """Handle an xDSL WhileOp operation (Loop cluster creation).""" + """Handle an xDSL WhileOp operation.""" pass @visit_op.register def _visit_if_op(self, op: scf.IfOp) -> None: - """Handle an xDSL IfOp operation (Conditional cluster creation).""" + """Handle an xDSL IfOp operation.""" pass From 4aad864ec6dfc912ba2b2716be6304575fd7334b Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 19 Nov 2025 19:22:01 -0500 Subject: [PATCH 013/111] Apply suggestion from @andrijapau --- .../python_interface/visualization/circuit_dag_analysis_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py index d4369ff617..528ac54d95 100644 --- a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py +++ b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py @@ -59,7 +59,7 @@ def run(self, module: builtin.ModuleOp) -> None: @visit_op.register def visit_region(self, region: Region) -> None: - """Visit an xDSL Region operation, delegating traversal to its Blocks.""" + """Visit an xDSL Region operation.""" for block in region.blocks: self.visit_block(block) From 3be39dcbfa0fcf99fbb90761844673d5a5d4fec0 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 11:23:29 -0500 Subject: [PATCH 014/111] update cluster label logic --- .../python_interface/visualization/dag_builder.py | 4 ++-- .../python_interface/visualization/pydot_dag_builder.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py index 7a60a55f1e..5526ee5bf7 100644 --- a/frontend/catalyst/python_interface/visualization/dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -57,7 +57,7 @@ def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> Non def add_cluster( self, cluster_id: str, - cluster_label: str, + node_label: str | None = None, parent_graph_id: str | None = None, **cluster_attrs: Any, ) -> None: @@ -68,7 +68,7 @@ def add_cluster( Args: cluster_id (str): Unique cluster ID to identify this cluster. - cluster_label (str): The text to display on the cluster when rendered. + node_label (str): The text to display on an information node within the cluster when rendered. parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. **cluster_attrs (Any): Any additional styling keyword arguments. diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index ad9eef4bcb..f52e0fa834 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -128,7 +128,7 @@ def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> Non def add_cluster( self, cluster_id: str, - cluster_label: str, + node_label: str | None = None, parent_graph_id: str | None = None, **cluster_attrs: Any, ) -> None: @@ -139,7 +139,7 @@ def add_cluster( Args: cluster_id (str): Unique cluster ID to identify this cluster. - cluster_label (str): The text to display on the cluster when rendered. + node_label (str): The text to display on the information node on the cluster when rendered. parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. **cluster_attrs (Any): Any additional styling keyword arguments. @@ -157,12 +157,12 @@ def add_cluster( # │ └───────┘ │ # │ │ # └───────────┘ - if cluster_label: + if node_label: node_id = f"{cluster_id}_info_node" rank_subgraph = pydot.Subgraph() node = pydot.Node( node_id, - label=cluster_label, + label=node_label, shape="rectangle", style="dashed", fontname="Helvetica", From a383ba8fb26f655d8b86ccb4149651dda5627a0e Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 11:25:20 -0500 Subject: [PATCH 015/111] fix dag builders test --- .../python_interface/visualization/test_dag_builder.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py index 3f70ae8159..2e935bae1b 100644 --- a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py @@ -40,13 +40,15 @@ def add_node( ) -> None: return - def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> None: + def add_edge( + self, from_node_id: str, to_node_id: str, **edge_attrs: Any + ) -> None: return def add_cluster( self, cluster_id: str, - cluster_label: str, + node_label: str | None = None, parent_graph_id: str | None = None, **cluster_attrs: Any, ) -> None: @@ -62,7 +64,7 @@ def to_string(self) -> str: # pylint: disable = assignment-from-none node = dag_builder.add_node("0", "node0") edge = dag_builder.add_edge("0", "1") - cluster = dag_builder.add_cluster("0", "cluster0") + cluster = dag_builder.add_cluster("0") render = dag_builder.to_file("test.png") string = dag_builder.to_string() From 8ddaa5d36a7945e18233c18179474db6c2d63d7d Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 11:26:07 -0500 Subject: [PATCH 016/111] fix pydot dag builders test --- .../visualization/test_pydot_dag_builder.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 5c57975a12..11e5e4ac47 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -73,7 +73,7 @@ def test_add_cluster(self): """Unit test the 'add_cluster' method.""" dag_builder = PyDotDAGBuilder() - dag_builder.add_cluster("0", "my_cluster") + dag_builder.add_cluster("0") assert len(dag_builder.graph.get_subgraphs()) == 1 assert dag_builder.graph.get_subgraphs()[0].get_name() == "cluster_0" @@ -87,7 +87,7 @@ def test_add_node_to_parent_graph(self): dag_builder.add_node("0", "node0") # Create cluster - dag_builder.add_cluster("c0", "cluster0") + dag_builder.add_cluster("c0") # Create node inside cluster dag_builder.add_node("1", "node1", parent_graph_id="c0") @@ -117,11 +117,11 @@ def test_add_cluster_to_parent_graph(self): # Level 0 (Root): Adds cluster on top of base graph dag_builder.add_node("n_root", "node_root") - dag_builder.add_cluster("c0", "cluster_outer") + dag_builder.add_cluster("c0") # Level 1 (Inside c0): Add node on outer cluster and create new cluster on top dag_builder.add_node("n_outer", "node_outer", parent_graph_id="c0") - dag_builder.add_cluster("c1", "cluster_inner", parent_graph_id="c0") + dag_builder.add_cluster("c1", parent_graph_id="c0") # Level 2 (Inside c1): Add node on second cluster dag_builder.add_node("n_inner", "node_inner", parent_graph_id="c1") @@ -168,7 +168,7 @@ def test_default_graph_attrs(self): node0 = dag_builder.graph.get_node("0")[0] assert node0.get("fontname") == "Times" - dag_builder.add_cluster("1", "cluster0") + dag_builder.add_cluster("1") cluster = dag_builder.graph.get_subgraphs()[0] assert cluster.get("fontname") == "Times" @@ -222,7 +222,7 @@ def test_add_cluster_with_attrs(self): } ) - dag_builder.add_cluster("0", "cluster0") + dag_builder.add_cluster("0") cluster1 = dag_builder.graph.get_subgraph("cluster_0")[0] # Defaults @@ -232,7 +232,7 @@ def test_add_cluster_with_attrs(self): assert cluster1.get("fontname") == "Helvetica" dag_builder.add_cluster( - "1", "cluster1", style="filled", penwidth=10, fillcolor="red" + "1", style="filled", penwidth=10, fillcolor="red" ) cluster2 = dag_builder.graph.get_subgraph("cluster_1")[0] From 1166f40c8d4bb87e5e61c306b19fd5a6eb94aed9 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 11:27:41 -0500 Subject: [PATCH 017/111] update doc --- frontend/catalyst/python_interface/visualization/dag_builder.py | 2 +- .../python_interface/visualization/pydot_dag_builder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py index 5526ee5bf7..aeb8b217b6 100644 --- a/frontend/catalyst/python_interface/visualization/dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -68,7 +68,7 @@ def add_cluster( Args: cluster_id (str): Unique cluster ID to identify this cluster. - node_label (str): The text to display on an information node within the cluster when rendered. + node_label (str | None): The text to display on an information node within the cluster when rendered. parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. **cluster_attrs (Any): Any additional styling keyword arguments. diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index f52e0fa834..cff97b0148 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -139,7 +139,7 @@ def add_cluster( Args: cluster_id (str): Unique cluster ID to identify this cluster. - node_label (str): The text to display on the information node on the cluster when rendered. + node_label (str | None): The text to display on the information node on the cluster when rendered. parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. **cluster_attrs (Any): Any additional styling keyword arguments. From df9bf45e6648a0127a2d4cb356a2c4f10e36a1a9 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 12:33:00 -0500 Subject: [PATCH 018/111] basic cl --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 831453990c..abb91a6bdd 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -3,7 +3,7 @@

New features since last release

* Compiled programs can be visualized. - [(#)]() + [(#2213)](https://github.com/PennyLaneAI/catalyst/pull/2213) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) From 8ca5b1fe925e03dbb405d1b0b657c3e87f773e84 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 12:33:35 -0500 Subject: [PATCH 019/111] basic cl --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index abb91a6bdd..3e29074a66 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -4,6 +4,7 @@ * Compiled programs can be visualized. [(#2213)](https://github.com/PennyLaneAI/catalyst/pull/2213) + [(#2214)](https://github.com/PennyLaneAI/catalyst/pull/2214) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) From b579a3fe57d9f20022a2a4387432b8a8a6802111 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 12:58:22 -0500 Subject: [PATCH 020/111] Trigger CI From a94d8432054cdab46bfb94702675f49f88b755b9 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 12:58:28 -0500 Subject: [PATCH 021/111] Trigger CI From 331e3491404c75d67a8cec30a287f4bb7f2b8618 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Thu, 20 Nov 2025 13:09:45 -0500 Subject: [PATCH 022/111] Apply suggestion from @andrijapau --- .../python_interface/visualization/circuit_dag_analysis_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py index 528ac54d95..df1a316e0a 100644 --- a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py +++ b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py @@ -90,7 +90,7 @@ def _visit_unitary_and_state_prep( pass # ============================================= - # 4. QUANTUM MEASUREMENT & OBSERVABLE HANDLERS + # 4. QUANTUM MEASUREMENT HANDLERS # ============================================= @visit_op.register From a87d11744f3d8c307573e8f0b65f54784ae57cc4 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 13:12:41 -0500 Subject: [PATCH 023/111] just do customop --- .../visualization/circuit_dag_analysis_pass.py | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py index df1a316e0a..7fa5d71065 100644 --- a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py +++ b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py @@ -77,14 +77,7 @@ def visit_block(self, block: Block) -> None: @visit_op.register def _visit_unitary_and_state_prep( self, - op: ( - quantum.CustomOp - | quantum.GlobalPhaseOp - | quantum.QubitUnitaryOp - | quantum.MultiRZOp - | quantum.SetStateOp - | quantum.SetBasisStateOp - ), + op: quantum.CustomOp, ) -> None: """Generic handler for unitary gates and quantum state preparation operations.""" pass From 9713aa3fa1c2150c88673c48dd5872f8afa4e906 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 13:15:22 -0500 Subject: [PATCH 024/111] fix wording --- .../visualization/circuit_dag_analysis_pass.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py index 7fa5d71065..bcc229798b 100644 --- a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py +++ b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py @@ -45,7 +45,7 @@ def __init__(self, dag_builder: DAGBuilder) -> None: def visit_op(self, op: Any) -> None: """Central dispatch method (Visitor Pattern). Routes the operation 'op' to the specialized handler registered for its type.""" - pass + raise NotImplementedError(f"Dispatch not registered for operator of type {type(op)}") def run(self, module: builtin.ModuleOp) -> None: """Applies the analysis pass on the module.""" @@ -88,7 +88,7 @@ def _visit_unitary_and_state_prep( @visit_op.register def _visit_state_op(self, op: quantum.StateOp) -> None: - """Handler for the terminal StateOp.""" + """Handler for the terminal state measurement operation.""" pass @visit_op.register @@ -101,7 +101,7 @@ def _visit_statistical_measurement_ops( @visit_op.register def _visit_projective_measure_op(self, op: quantum.MeasureOp) -> None: - """Handler for the single-qubit projective MeasureOp.""" + """Handler for the single-qubit projective measurement operation.""" pass # ========================= From 7bbc22410d50c89b83d6e81aaa4c92ac340ec924 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 14:13:24 -0500 Subject: [PATCH 025/111] rename --- .../circuit_dag_analysis_pass.py | 124 ------------------ .../test_circuit_dag_analysis_pass.py | 0 2 files changed, 124 deletions(-) delete mode 100644 frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py delete mode 100644 frontend/test/pytest/python_interface/visualization/test_circuit_dag_analysis_pass.py diff --git a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py b/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py deleted file mode 100644 index bcc229798b..0000000000 --- a/frontend/catalyst/python_interface/visualization/circuit_dag_analysis_pass.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright 2025 Xanadu Quantum Technologies Inc. - -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at - -# http://www.apache.org/licenses/LICENSE-2.0 - -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Contains the CircuitDAGAnalysisPass for generating a DAG from an xDSL module.""" - -from functools import singledispatchmethod -from typing import TYPE_CHECKING, Any - -import xdsl -from xdsl.dialects import builtin, func, scf -from xdsl.ir import Block, Region - -from catalyst.python_interface.dialects import quantum -from catalyst.python_interface.visualization.dag_builder import DAGBuilder - - -class CircuitDAGAnalysisPass: - """A Pass that analyzes an xDSL module and constructs a Directed Acyclic Graph (DAG) - using an injected DAGBuilder instance. This is a non-mutating Analysis Pass.""" - - def __init__(self, dag_builder: DAGBuilder) -> None: - """Initialize the analysis pass by injecting the DAG builder dependency. - - Args: - dag_builder (DAGBuilder): The concrete builder instance used for graph construction. - """ - self.dag_builder: DAGBuilder = dag_builder - - # ================================= - # 1. CORE DISPATCH AND ENTRY POINT - # ================================= - - @singledispatchmethod - def visit_op(self, op: Any) -> None: - """Central dispatch method (Visitor Pattern). Routes the operation 'op' - to the specialized handler registered for its type.""" - raise NotImplementedError(f"Dispatch not registered for operator of type {type(op)}") - - def run(self, module: builtin.ModuleOp) -> None: - """Applies the analysis pass on the module.""" - for op in module.ops: - self.visit_op(op) - - # ======================= - # 2. HIERARCHY TRAVERSAL - # ======================= - # These methods navigate the recursive IR hierarchy (Op -> Region -> Block -> Op). - - @visit_op.register - def visit_region(self, region: Region) -> None: - """Visit an xDSL Region operation.""" - for block in region.blocks: - self.visit_block(block) - - @visit_op.register - def visit_block(self, block: Block) -> None: - """Visit an xDSL Block operation, dispatching handling for each contained Operation.""" - for op in block.ops: - self.visit_op(op) - - # ====================================== - # 3. QUANTUM GATE & STATE PREP HANDLERS - # ====================================== - # Handlers for operations that apply unitary transformations or set-up the quantum state. - - @visit_op.register - def _visit_unitary_and_state_prep( - self, - op: quantum.CustomOp, - ) -> None: - """Generic handler for unitary gates and quantum state preparation operations.""" - pass - - # ============================================= - # 4. QUANTUM MEASUREMENT HANDLERS - # ============================================= - - @visit_op.register - def _visit_state_op(self, op: quantum.StateOp) -> None: - """Handler for the terminal state measurement operation.""" - pass - - @visit_op.register - def _visit_statistical_measurement_ops( - self, - op: quantum.ExpvalOp | quantum.VarianceOp | quantum.ProbsOp | quantum.SampleOp, - ) -> None: - """Handler for statistical measurement operations.""" - pass - - @visit_op.register - def _visit_projective_measure_op(self, op: quantum.MeasureOp) -> None: - """Handler for the single-qubit projective measurement operation.""" - pass - - # ========================= - # 5. CONTROL FLOW HANDLERS - # ========================= - - @visit_op.register - def _visit_for_op(self, op: scf.ForOp) -> None: - """Handle an xDSL ForOp operation.""" - pass - - @visit_op.register - def _visit_while_op(self, op: scf.WhileOp) -> None: - """Handle an xDSL WhileOp operation.""" - pass - - @visit_op.register - def _visit_if_op(self, op: scf.IfOp) -> None: - """Handle an xDSL IfOp operation.""" - pass diff --git a/frontend/test/pytest/python_interface/visualization/test_circuit_dag_analysis_pass.py b/frontend/test/pytest/python_interface/visualization/test_circuit_dag_analysis_pass.py deleted file mode 100644 index e69de29bb2..0000000000 From d6db96550c425f8f0b5e33e4f6484b8077fc7e82 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 14:14:07 -0500 Subject: [PATCH 026/111] add back --- .../visualization/construct_circuit_dag.py | 124 ++++++++++++++++++ .../test_construct_circuit_dag.py | 0 2 files changed, 124 insertions(+) create mode 100644 frontend/catalyst/python_interface/visualization/construct_circuit_dag.py create mode 100644 frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py new file mode 100644 index 0000000000..1237f94fed --- /dev/null +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -0,0 +1,124 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module.""" + +from functools import singledispatchmethod +from typing import TYPE_CHECKING, Any + +import xdsl +from xdsl.dialects import builtin, func, scf +from xdsl.ir import Block, Region + +from catalyst.python_interface.dialects import quantum +from catalyst.python_interface.visualization.dag_builder import DAGBuilder + + +class ConstructCircuitDAG: + """A tool that analyzes an xDSL module and constructs a Directed Acyclic Graph (DAG) + using an injected DAGBuilder instance. This tool does not mutate the xDSL module.""" + + def __init__(self, dag_builder: DAGBuilder) -> None: + """Initialize the analysis pass by injecting the DAG builder dependency. + + Args: + dag_builder (DAGBuilder): The concrete builder instance used for graph construction. + """ + self.dag_builder: DAGBuilder = dag_builder + + # ================================= + # 1. CORE DISPATCH AND ENTRY POINT + # ================================= + + @singledispatchmethod + def visit_op(self, op: Any) -> None: + """Central dispatch method (Visitor Pattern). Routes the operation 'op' + to the specialized handler registered for its type.""" + raise NotImplementedError(f"Dispatch not registered for operator of type {type(op)}") + + def construct(self, module: builtin.ModuleOp) -> None: + """Constructs the DAG from the module.""" + for op in module.ops: + self.visit_op(op) + + # ======================= + # 2. HIERARCHY TRAVERSAL + # ======================= + # These methods navigate the recursive IR hierarchy (Op -> Region -> Block -> Op). + + @visit_op.register + def visit_region(self, region: Region) -> None: + """Visit an xDSL Region operation.""" + for block in region.blocks: + self.visit_block(block) + + @visit_op.register + def visit_block(self, block: Block) -> None: + """Visit an xDSL Block operation, dispatching handling for each contained Operation.""" + for op in block.ops: + self.visit_op(op) + + # ====================================== + # 3. QUANTUM GATE & STATE PREP HANDLERS + # ====================================== + # Handlers for operations that apply unitary transformations or set-up the quantum state. + + @visit_op.register + def _visit_unitary_and_state_prep( + self, + op: quantum.CustomOp, + ) -> None: + """Generic handler for unitary gates and quantum state preparation operations.""" + pass + + # ============================================= + # 4. QUANTUM MEASUREMENT HANDLERS + # ============================================= + + @visit_op.register + def _visit_state_op(self, op: quantum.StateOp) -> None: + """Handler for the terminal state measurement operation.""" + pass + + @visit_op.register + def _visit_statistical_measurement_ops( + self, + op: quantum.ExpvalOp | quantum.VarianceOp | quantum.ProbsOp | quantum.SampleOp, + ) -> None: + """Handler for statistical measurement operations.""" + pass + + @visit_op.register + def _visit_projective_measure_op(self, op: quantum.MeasureOp) -> None: + """Handler for the single-qubit projective measurement operation.""" + pass + + # ========================= + # 5. CONTROL FLOW HANDLERS + # ========================= + + @visit_op.register + def _visit_for_op(self, op: scf.ForOp) -> None: + """Handle an xDSL ForOp operation.""" + pass + + @visit_op.register + def _visit_while_op(self, op: scf.WhileOp) -> None: + """Handle an xDSL WhileOp operation.""" + pass + + @visit_op.register + def _visit_if_op(self, op: scf.IfOp) -> None: + """Handle an xDSL IfOp operation.""" + pass diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py new file mode 100644 index 0000000000..e69de29bb2 From 11b08f026682398898550161c2a31f6a3bac6015 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Thu, 20 Nov 2025 14:53:22 -0500 Subject: [PATCH 027/111] Update frontend/catalyst/python_interface/visualization/pydot_dag_builder.py Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- .../python_interface/visualization/pydot_dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index cff97b0148..4c7c40a579 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -49,7 +49,7 @@ def __init__( # - graph_type="digraph": Create a directed graph (edges have arrows). # - rankdir="TB": Set layout direction from Top to Bottom. # - compound="true": Allow edges to connect directly to clusters/subgraphs. - # - strict=True": Prevent duplicate edges (e.g., A -> B added twice). + # - strict=True: Prevent duplicate edges (e.g., A -> B added twice). self.graph: pydot.Dot = pydot.Dot( graph_type="digraph", rankdir="TB", compound="true", strict=True ) From a76b5bdfc4efb9f7fef884255d60c757f0803e4e Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Thu, 20 Nov 2025 14:53:27 -0500 Subject: [PATCH 028/111] Update frontend/catalyst/python_interface/visualization/pydot_dag_builder.py Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- .../python_interface/visualization/pydot_dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 4c7c40a579..384537c93e 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -139,7 +139,7 @@ def add_cluster( Args: cluster_id (str): Unique cluster ID to identify this cluster. - node_label (str | None): The text to display on the information node on the cluster when rendered. + node_label (str | None): The text to display on the information node within the cluster when rendered. parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. **cluster_attrs (Any): Any additional styling keyword arguments. From 90e4ebd4e7281deab5ad1615e94474384c777ec5 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 15:16:24 -0500 Subject: [PATCH 029/111] fix --- .../visualization/construct_circuit_dag.py | 48 +++++++++++-------- 1 file changed, 27 insertions(+), 21 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 1237f94fed..27b7a2c631 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -19,7 +19,7 @@ import xdsl from xdsl.dialects import builtin, func, scf -from xdsl.ir import Block, Region +from xdsl.ir import Block, Operation, Region from catalyst.python_interface.dialects import quantum from catalyst.python_interface.visualization.dag_builder import DAGBuilder @@ -42,40 +42,46 @@ def __init__(self, dag_builder: DAGBuilder) -> None: # ================================= @singledispatchmethod - def visit_op(self, op: Any) -> None: + def visit(self, op: Any) -> None: """Central dispatch method (Visitor Pattern). Routes the operation 'op' to the specialized handler registered for its type.""" - raise NotImplementedError(f"Dispatch not registered for operator of type {type(op)}") + pass def construct(self, module: builtin.ModuleOp) -> None: """Constructs the DAG from the module.""" for op in module.ops: - self.visit_op(op) + self.visit(op) # ======================= # 2. HIERARCHY TRAVERSAL # ======================= # These methods navigate the recursive IR hierarchy (Op -> Region -> Block -> Op). - @visit_op.register + @visit.register + def visit_operation(self, operation: Operation) -> None: + """Visit an xDSL Operation.""" + for region in operation.regions: + self.visit_region(region) + + @visit.register def visit_region(self, region: Region) -> None: """Visit an xDSL Region operation.""" for block in region.blocks: self.visit_block(block) - @visit_op.register + @visit.register def visit_block(self, block: Block) -> None: """Visit an xDSL Block operation, dispatching handling for each contained Operation.""" for op in block.ops: - self.visit_op(op) + self.visit(op) # ====================================== # 3. QUANTUM GATE & STATE PREP HANDLERS # ====================================== # Handlers for operations that apply unitary transformations or set-up the quantum state. - @visit_op.register - def _visit_unitary_and_state_prep( + @visit.register + def _unitary_and_state_prep( self, op: quantum.CustomOp, ) -> None: @@ -86,21 +92,21 @@ def _visit_unitary_and_state_prep( # 4. QUANTUM MEASUREMENT HANDLERS # ============================================= - @visit_op.register - def _visit_state_op(self, op: quantum.StateOp) -> None: + @visit.register + def _state_op(self, op: quantum.StateOp) -> None: """Handler for the terminal state measurement operation.""" pass - @visit_op.register - def _visit_statistical_measurement_ops( + @visit.register + def _statistical_measurement_ops( self, op: quantum.ExpvalOp | quantum.VarianceOp | quantum.ProbsOp | quantum.SampleOp, ) -> None: """Handler for statistical measurement operations.""" pass - @visit_op.register - def _visit_projective_measure_op(self, op: quantum.MeasureOp) -> None: + @visit.register + def _projective_measure_op(self, op: quantum.MeasureOp) -> None: """Handler for the single-qubit projective measurement operation.""" pass @@ -108,17 +114,17 @@ def _visit_projective_measure_op(self, op: quantum.MeasureOp) -> None: # 5. CONTROL FLOW HANDLERS # ========================= - @visit_op.register - def _visit_for_op(self, op: scf.ForOp) -> None: + @visit.register + def _for_op(self, op: scf.ForOp) -> None: """Handle an xDSL ForOp operation.""" pass - @visit_op.register - def _visit_while_op(self, op: scf.WhileOp) -> None: + @visit.register + def _while_op(self, op: scf.WhileOp) -> None: """Handle an xDSL WhileOp operation.""" pass - @visit_op.register - def _visit_if_op(self, op: scf.IfOp) -> None: + @visit.register + def _if_op(self, op: scf.IfOp) -> None: """Handle an xDSL IfOp operation.""" pass From 3d9e4bb25387631a0adc943958110ef4f2efc9ff Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 15:16:54 -0500 Subject: [PATCH 030/111] fix --- .../python_interface/visualization/construct_circuit_dag.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 27b7a2c631..3150732024 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -15,10 +15,9 @@ """Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module.""" from functools import singledispatchmethod -from typing import TYPE_CHECKING, Any +from typing import Any -import xdsl -from xdsl.dialects import builtin, func, scf +from xdsl.dialects import builtin, scf from xdsl.ir import Block, Operation, Region from catalyst.python_interface.dialects import quantum From 2504c1304406ff6536f9ab012865ab40f1e542c8 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 21 Nov 2025 10:21:39 -0500 Subject: [PATCH 031/111] clean-up --- .../visualization/construct_circuit_dag.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 3150732024..9ab65459cb 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -57,19 +57,19 @@ def construct(self, module: builtin.ModuleOp) -> None: # These methods navigate the recursive IR hierarchy (Op -> Region -> Block -> Op). @visit.register - def visit_operation(self, operation: Operation) -> None: + def _operation(self, operation: Operation) -> None: """Visit an xDSL Operation.""" for region in operation.regions: - self.visit_region(region) + self.visit(region) @visit.register - def visit_region(self, region: Region) -> None: + def _region(self, region: Region) -> None: """Visit an xDSL Region operation.""" for block in region.blocks: - self.visit_block(block) + self.visit(block) @visit.register - def visit_block(self, block: Block) -> None: + def _block(self, block: Block) -> None: """Visit an xDSL Block operation, dispatching handling for each contained Operation.""" for op in block.ops: self.visit(op) From 4e78777df34b15e36583c1585cbecb935fa57928 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 21 Nov 2025 10:58:46 -0500 Subject: [PATCH 032/111] remove unnecessary stuff --- .../visualization/construct_circuit_dag.py | 54 ------------------- 1 file changed, 54 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 9ab65459cb..85b5c85f0c 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -73,57 +73,3 @@ def _block(self, block: Block) -> None: """Visit an xDSL Block operation, dispatching handling for each contained Operation.""" for op in block.ops: self.visit(op) - - # ====================================== - # 3. QUANTUM GATE & STATE PREP HANDLERS - # ====================================== - # Handlers for operations that apply unitary transformations or set-up the quantum state. - - @visit.register - def _unitary_and_state_prep( - self, - op: quantum.CustomOp, - ) -> None: - """Generic handler for unitary gates and quantum state preparation operations.""" - pass - - # ============================================= - # 4. QUANTUM MEASUREMENT HANDLERS - # ============================================= - - @visit.register - def _state_op(self, op: quantum.StateOp) -> None: - """Handler for the terminal state measurement operation.""" - pass - - @visit.register - def _statistical_measurement_ops( - self, - op: quantum.ExpvalOp | quantum.VarianceOp | quantum.ProbsOp | quantum.SampleOp, - ) -> None: - """Handler for statistical measurement operations.""" - pass - - @visit.register - def _projective_measure_op(self, op: quantum.MeasureOp) -> None: - """Handler for the single-qubit projective measurement operation.""" - pass - - # ========================= - # 5. CONTROL FLOW HANDLERS - # ========================= - - @visit.register - def _for_op(self, op: scf.ForOp) -> None: - """Handle an xDSL ForOp operation.""" - pass - - @visit.register - def _while_op(self, op: scf.WhileOp) -> None: - """Handle an xDSL WhileOp operation.""" - pass - - @visit.register - def _if_op(self, op: scf.IfOp) -> None: - """Handle an xDSL IfOp operation.""" - pass From e17c1112364a6056d0cc2a739b998ab69b48f2dc Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 21 Nov 2025 11:07:29 -0500 Subject: [PATCH 033/111] add test skeleton --- .../test_construct_circuit_dag.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index e69de29bb2..336df14c08 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -0,0 +1,24 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Unit tests for the ConstructCircuitDAG utility.""" + +import pytest + +pytestmark = pytest.mark.usefixtures("requires_xdsl") + +# pylint: disable=wrong-import-position +# This import needs to be after pytest in order to prevent ImportErrors +from catalyst.python_interface.visualization.construct_circuit_dag import ConstructCircuitDAG +from catalyst.python_interface.visualization.dag_builder import DAGBuilder + From 19ca1b5e11ec95a08552fd92aca8be87200d8fff Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 21 Nov 2025 11:28:28 -0500 Subject: [PATCH 034/111] add basic tests --- .../test_construct_circuit_dag.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 336df14c08..3be7ac78f5 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -13,12 +13,34 @@ # limitations under the License. """Unit tests for the ConstructCircuitDAG utility.""" +from unittest.mock import MagicMock, Mock + import pytest pytestmark = pytest.mark.usefixtures("requires_xdsl") + # pylint: disable=wrong-import-position # This import needs to be after pytest in order to prevent ImportErrors -from catalyst.python_interface.visualization.construct_circuit_dag import ConstructCircuitDAG -from catalyst.python_interface.visualization.dag_builder import DAGBuilder +from catalyst.python_interface.visualization.construct_circuit_dag import ( + ConstructCircuitDAG, +) +from catalyst.python_interface.visualization.dag_builder import DAGBuilder + + +class TestInitialization: + """Tests that the state is correctly initialized.""" + + def test_dependency_injection(self): + """Tests that relevant dependencies are injected.""" + + mock_dag_builder = Mock(DAGBuilder) + utility = ConstructCircuitDAG(mock_dag_builder) + assert utility.dag_builder is mock_dag_builder + + +class TestRecursiveTraversal: + """Tests that the recursive traversal logic works correctly.""" + def test_entire_module_is_traversed(self): + """Tests that the entire module heirarchy is traversed correctly.""" From 7f88834e3164da84cbbf67d933a2a3777b9a050d Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 21 Nov 2025 13:27:05 -0500 Subject: [PATCH 035/111] basic test idea --- .../test_construct_circuit_dag.py | 41 +++++++++++++++++-- 1 file changed, 38 insertions(+), 3 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 3be7ac78f5..bad7d9c947 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -13,19 +13,22 @@ # limitations under the License. """Unit tests for the ConstructCircuitDAG utility.""" -from unittest.mock import MagicMock, Mock +from unittest import mock +from unittest.mock import MagicMock, Mock, call import pytest pytestmark = pytest.mark.usefixtures("requires_xdsl") - # pylint: disable=wrong-import-position # This import needs to be after pytest in order to prevent ImportErrors from catalyst.python_interface.visualization.construct_circuit_dag import ( ConstructCircuitDAG, ) from catalyst.python_interface.visualization.dag_builder import DAGBuilder +from xdsl.dialects import test +from xdsl.dialects.builtin import ModuleOp +from xdsl.ir.core import Block, Region class TestInitialization: @@ -43,4 +46,36 @@ class TestRecursiveTraversal: """Tests that the recursive traversal logic works correctly.""" def test_entire_module_is_traversed(self): - """Tests that the entire module heirarchy is traversed correctly.""" + """Tests that the entire module hierarchy is traversed correctly.""" + + # Create block containing some ops + op = test.TestOp() + block = Block(ops=[op, op]) + # Create region containing some blocks + region = Region(blocks=[block, block]) + # Create op containing the regions + container_op = test.TestOp(regions=[region, region]) + # Create module op to house it all + module_op = ModuleOp(ops=[container_op]) + + mock_dag_builder = Mock(DAGBuilder) + utility = ConstructCircuitDAG(mock_dag_builder) + + # Mock out the visit dispatcher + utility.visit = Mock() + + utility.construct(module_op) + + assert utility.visit.call_count == 5 + + expected_calls = [ + call(container_op), + call(region), + call(region), + call(block), + call(block), + call(op), + call(op), + ] + + utility.visit.assert_has_calls(expected_calls, any_order=False) From 3d48c9291739844e7b0783d2837214e81aed5a8a Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 21 Nov 2025 13:43:07 -0500 Subject: [PATCH 036/111] make visit private --- .../visualization/construct_circuit_dag.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 85b5c85f0c..f662e1b7b3 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -41,7 +41,7 @@ def __init__(self, dag_builder: DAGBuilder) -> None: # ================================= @singledispatchmethod - def visit(self, op: Any) -> None: + def _visit(self, op: Any) -> None: """Central dispatch method (Visitor Pattern). Routes the operation 'op' to the specialized handler registered for its type.""" pass @@ -49,27 +49,27 @@ def visit(self, op: Any) -> None: def construct(self, module: builtin.ModuleOp) -> None: """Constructs the DAG from the module.""" for op in module.ops: - self.visit(op) + self._visit(op) # ======================= # 2. HIERARCHY TRAVERSAL # ======================= # These methods navigate the recursive IR hierarchy (Op -> Region -> Block -> Op). - @visit.register + @_visit.register def _operation(self, operation: Operation) -> None: """Visit an xDSL Operation.""" for region in operation.regions: - self.visit(region) + self._visit(region) - @visit.register + @_visit.register def _region(self, region: Region) -> None: """Visit an xDSL Region operation.""" for block in region.blocks: - self.visit(block) + self._visit(block) - @visit.register + @_visit.register def _block(self, block: Block) -> None: """Visit an xDSL Block operation, dispatching handling for each contained Operation.""" for op in block.ops: - self.visit(op) + self._visit(op) From 49b3834a37b18b113149f15d1b95d5994d33a645 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 21 Nov 2025 13:47:34 -0500 Subject: [PATCH 037/111] make visit private --- .../visualization/test_construct_circuit_dag.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index bad7d9c947..88e638ee31 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -62,11 +62,11 @@ def test_entire_module_is_traversed(self): utility = ConstructCircuitDAG(mock_dag_builder) # Mock out the visit dispatcher - utility.visit = Mock() + utility._visit = Mock() utility.construct(module_op) - assert utility.visit.call_count == 5 + assert utility._visit.call_count == 7 expected_calls = [ call(container_op), @@ -78,4 +78,4 @@ def test_entire_module_is_traversed(self): call(op), ] - utility.visit.assert_has_calls(expected_calls, any_order=False) + utility._visit.assert_has_calls(expected_calls, any_order=False) From 80ca59d71df779e98086f4309b22d8b6c50e9011 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 21 Nov 2025 14:12:23 -0500 Subject: [PATCH 038/111] fix tests --- .../test_construct_circuit_dag.py | 23 ++++--------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 88e638ee31..39a74578bb 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -50,32 +50,17 @@ def test_entire_module_is_traversed(self): # Create block containing some ops op = test.TestOp() - block = Block(ops=[op, op]) + block = Block(ops=[op]) # Create region containing some blocks - region = Region(blocks=[block, block]) + region = Region(blocks=[block]) # Create op containing the regions - container_op = test.TestOp(regions=[region, region]) + container_op = test.TestOp(regions=[region]) # Create module op to house it all module_op = ModuleOp(ops=[container_op]) mock_dag_builder = Mock(DAGBuilder) utility = ConstructCircuitDAG(mock_dag_builder) - # Mock out the visit dispatcher - utility._visit = Mock() - utility.construct(module_op) - assert utility._visit.call_count == 7 - - expected_calls = [ - call(container_op), - call(region), - call(region), - call(block), - call(block), - call(op), - call(op), - ] - - utility._visit.assert_has_calls(expected_calls, any_order=False) + # Assert visit was dispatched correct number of times with the correct inputs From 430ceb810fc0015182b624b93d77ba1cc03955b4 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 21 Nov 2025 19:37:29 -0500 Subject: [PATCH 039/111] fix tests --- .../test_construct_circuit_dag.py | 42 +++++++++---------- 1 file changed, 20 insertions(+), 22 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 39a74578bb..6d4b7106ea 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -42,25 +42,23 @@ def test_dependency_injection(self): assert utility.dag_builder is mock_dag_builder -class TestRecursiveTraversal: - """Tests that the recursive traversal logic works correctly.""" - - def test_entire_module_is_traversed(self): - """Tests that the entire module hierarchy is traversed correctly.""" - - # Create block containing some ops - op = test.TestOp() - block = Block(ops=[op]) - # Create region containing some blocks - region = Region(blocks=[block]) - # Create op containing the regions - container_op = test.TestOp(regions=[region]) - # Create module op to house it all - module_op = ModuleOp(ops=[container_op]) - - mock_dag_builder = Mock(DAGBuilder) - utility = ConstructCircuitDAG(mock_dag_builder) - - utility.construct(module_op) - - # Assert visit was dispatched correct number of times with the correct inputs +def test_does_not_mutate_module(): + """Test that the module is not mutated.""" + + # Create block containing some ops + op = test.TestOp() + block = Block(ops=[op]) + # Create region containing some blocks + region = Region(blocks=[block]) + # Create op containing the regions + container_op = test.TestOp(regions=[region]) + # Create module op to house it all + module_op = ModuleOp(ops=[container_op]) + + module_op_str_before = str(module_op) + + mock_dag_builder = Mock(DAGBuilder) + utility = ConstructCircuitDAG(mock_dag_builder) + utility.construct(module_op) + + assert str(module_op) == module_op_str_before From 0aff4e5cc57bede900876d6c413dd1f0b7387c2f Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 10:08:55 -0500 Subject: [PATCH 040/111] clean-up --- .../visualization/test_construct_circuit_dag.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 6d4b7106ea..5a55cd0ba7 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -13,8 +13,7 @@ # limitations under the License. """Unit tests for the ConstructCircuitDAG utility.""" -from unittest import mock -from unittest.mock import MagicMock, Mock, call +from unittest.mock import Mock import pytest From e22cbd2499f0706bccb639f3112d3e4480693d1e Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 13:08:27 -0500 Subject: [PATCH 041/111] fix: upgrade DAG builders to have get_ methods --- .../visualization/dag_builder.py | 46 +++++++++++++++++-- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py index aeb8b217b6..d59a81b6f4 100644 --- a/frontend/catalyst/python_interface/visualization/dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -14,7 +14,10 @@ """File that defines the DAGBuilder abstract base class.""" from abc import ABC, abstractmethod -from typing import Any +from typing import Any, TypeAlias + +ClusterID: TypeAlias = str +NodeID: TypeAlias = str class DAGBuilder(ABC): @@ -28,7 +31,11 @@ class DAGBuilder(ABC): @abstractmethod def add_node( - self, node_id: str, node_label: str, parent_graph_id: str | None = None, **node_attrs: Any + self, + node_id: NodeID, + node_label: str, + parent_graph_id: str | None = None, + **node_attrs: Any, ) -> None: """Add a single node to the graph. @@ -42,7 +49,9 @@ def add_node( raise NotImplementedError @abstractmethod - def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> None: + def add_edge( + self, from_node_id: NodeID, to_node_id: NodeID, **edge_attrs: Any + ) -> None: """Add a single directed edge between nodes in the graph. Args: @@ -56,9 +65,9 @@ def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> Non @abstractmethod def add_cluster( self, - cluster_id: str, + cluster_id: ClusterID, node_label: str | None = None, - parent_graph_id: str | None = None, + parent_graph_id: ClusterID | None = None, **cluster_attrs: Any, ) -> None: """Add a single cluster to the graph. @@ -75,6 +84,33 @@ def add_cluster( """ raise NotImplementedError + @abstractmethod + def get_nodes(self) -> dict[NodeID, dict[str, Any]]: + """Retrieve the current set of nodes in the graph. + + Returns: + nodes (dict[str, dict[str, Any]]): A dictionary that maps the node's ID to it's node information. + """ + raise NotImplementedError + + @abstractmethod + def get_edges(self) -> list[dict[str, Any]]: + """Retrieve the current set of edges in the graph. + + Returns: + edges (list[dict[str, Any]]): A list of edges where each edge contains a dictionary of information for a given edge. + """ + raise NotImplementedError + + @abstractmethod + def get_clusters(self) -> dict[ClusterID, dict[str, Any]]: + """Retrieve the current set of clusters in the graph. + + Returns: + clusters (dict[str, dict[str, Any]]): A dictionary that maps the cluster's ID to it's cluster information. + """ + raise NotImplementedError + @abstractmethod def to_file(self, output_filename: str) -> None: """Save the graph to a file. From acf3da79912c7a0fb5bd493319fb6855eca70ebb Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 13:12:48 -0500 Subject: [PATCH 042/111] cl --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index ee7c22c7e3..6e8a6465ff 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -4,6 +4,7 @@ * Compiled programs can be visualized. [(#2213)](https://github.com/PennyLaneAI/catalyst/pull/2213) + [(#2229)](https://github.com/PennyLaneAI/catalyst/pull/2229) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) From bc978adad3e3c51da154549bd24824314860ccfc Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 13:24:25 -0500 Subject: [PATCH 043/111] update pydot to adhere to new base class methods --- .../visualization/pydot_dag_builder.py | 52 ++++++++++++++++++- 1 file changed, 51 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 384537c93e..6f756bf40c 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -57,7 +57,14 @@ def __init__( self._subgraphs: dict[str, pydot.Graph] = {} self._subgraphs["__base__"] = self.graph - _default_attrs: dict = {"fontname": "Helvetica", "penwidth": 2} if attrs is None else attrs + # Internal state for graph structure + self._nodes: dict[str, dict[str, Any]] = {} + self._edges: list[dict[str, Any]] = [] + self._clusters: dict[str, dict[str, Any]] = {} + + _default_attrs: dict = ( + {"fontname": "Helvetica", "penwidth": 2} if attrs is None else attrs + ) self._default_node_attrs: dict = ( { **_default_attrs, @@ -111,6 +118,13 @@ def add_node( self._subgraphs[parent_graph_id].add_node(node) + self._nodes[node_id] = { + "id": node_id, + "label": node_label, + "parent_id": parent_graph_id, + "attrs": dict(node_attrs), + } + def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> None: """Add a single directed edge between nodes in the graph. @@ -125,6 +139,10 @@ def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> Non edge = pydot.Edge(from_node_id, to_node_id, **edge_attrs) self.graph.add_edge(edge) + self._edges.append( + {"from_id": from_node_id, "to_id": to_node_id, "attrs": dict(edge_attrs)} + ) + def add_cluster( self, cluster_id: str, @@ -177,6 +195,38 @@ def add_cluster( parent_graph_id = "__base__" if parent_graph_id is None else parent_graph_id self._subgraphs[parent_graph_id].add_subgraph(cluster) + self._clusters[cluster_id] = { + "id": cluster_id, + "cluster_label": cluster_attrs.get("label"), + "node_label": node_label, + "parent_id": parent_graph_id, + "attrs": dict(cluster_attrs), + } + + def get_nodes(self) -> dict[str, dict[str, Any]]: + """Retrieve the current set of nodes in the graph. + + Returns: + nodes (dict[str, dict[str, Any]]): A dictionary that maps the node's ID to it's node information. + """ + return self._nodes + + def get_edges(self) -> list[dict[str, Any]]: + """Retrieve the current set of edges in the graph. + + Returns: + edges (list[dict[str, Any]]): A list of edges where each edge contains a dictionary of information for a given edge. + """ + return self._edges + + def get_clusters(self) -> dict[str, dict[str, Any]]: + """Retrieve the current set of clusters in the graph. + + Returns: + clusters (dict[str, dict[str, Any]]): A dictionary that maps the cluster's ID to it's cluster information. + """ + return self._clusters + def to_file(self, output_filename: str) -> None: """Save the graph to a file. From dc4cb1c622ccddba3940e4f3a9e8658dcbebb983 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 13:37:21 -0500 Subject: [PATCH 044/111] add test skeletons --- .../visualization/pydot_dag_builder.py | 6 +++--- .../visualization/test_dag_builder.py | 15 +++++++++++++++ .../visualization/test_pydot_dag_builder.py | 17 ++++++++++++++--- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 6f756bf40c..9a81dc55d0 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -112,7 +112,7 @@ def add_node( """ # Use ChainMap so you don't need to construct a new dictionary - node_attrs = ChainMap(node_attrs, self._default_node_attrs) + node_attrs: ChainMap = ChainMap(node_attrs, self._default_node_attrs) node = pydot.Node(node_id, label=node_label, **node_attrs) parent_graph_id = "__base__" if parent_graph_id is None else parent_graph_id @@ -135,7 +135,7 @@ def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> Non """ # Use ChainMap so you don't need to construct a new dictionary - edge_attrs = ChainMap(edge_attrs, self._default_edge_attrs) + edge_attrs: ChainMap = ChainMap(edge_attrs, self._default_edge_attrs) edge = pydot.Edge(from_node_id, to_node_id, **edge_attrs) self.graph.add_edge(edge) @@ -163,7 +163,7 @@ def add_cluster( """ # Use ChainMap so you don't need to construct a new dictionary - cluster_attrs = ChainMap(cluster_attrs, self._default_cluster_attrs) + cluster_attrs: ChainMap = ChainMap(cluster_attrs, self._default_cluster_attrs) cluster = pydot.Cluster(graph_name=cluster_id, **cluster_attrs) # Puts the label in a node within the cluster. diff --git a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py index 2e935bae1b..7191cf001c 100644 --- a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py @@ -54,6 +54,15 @@ def add_cluster( ) -> None: return + def get_nodes(self) -> dict[str, dict[str, Any]]: + return {} + + def get_edges(self) -> list[dict[str, Any]]: + return [] + + def get_clusters(self) -> dict[str, dict[str, Any]]: + return {} + def to_file(self, output_filename: str) -> None: return @@ -65,12 +74,18 @@ def to_string(self) -> str: node = dag_builder.add_node("0", "node0") edge = dag_builder.add_edge("0", "1") cluster = dag_builder.add_cluster("0") + nodes = dag_builder.get_nodes() + edges = dag_builder.get_edges() + clusters = dag_builder.get_clusters() render = dag_builder.to_file("test.png") string = dag_builder.to_string() assert node is None + assert nodes == {} assert edge is None + assert edges == [] assert cluster is None + assert clusters == {} assert render is None assert string == "test" diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 11e5e4ac47..426239a929 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -231,9 +231,7 @@ def test_add_cluster_with_attrs(self): assert cluster1.get("penwidth") == 2 assert cluster1.get("fontname") == "Helvetica" - dag_builder.add_cluster( - "1", style="filled", penwidth=10, fillcolor="red" - ) + dag_builder.add_cluster("1", style="filled", penwidth=10, fillcolor="red") cluster2 = dag_builder.graph.get_subgraph("cluster_1")[0] # Make sure we can override @@ -245,6 +243,19 @@ def test_add_cluster_with_attrs(self): assert cluster2.get("fontname") == "Helvetica" +class TestGetMethods: + """Tests the get_* methods.""" + + def test_get_nodes(self): + """Tests that get_nodes works.""" + + def test_get_edges(self): + """Tests that get_edges works.""" + + def test_get_clusters(self): + """Tests that get_clusters works.""" + + class TestOutput: """Test that the graph can be outputted correctly.""" From 7e825d6cb71e5b15ad346c31abee49b87ea88ca3 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 13:47:47 -0500 Subject: [PATCH 045/111] add tests --- .../visualization/test_pydot_dag_builder.py | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 426239a929..81cff2db77 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -248,13 +248,67 @@ class TestGetMethods: def test_get_nodes(self): """Tests that get_nodes works.""" + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0") + dag_builder.add_cluster("c0", "my_info_node", label="my_cluster") + dag_builder.add_node("1", "node1", parent_graph_id="c0") + + nodes = dag_builder.get_nodes() + assert len(nodes) == 2 + assert len(nodes["0"]) == 4 + assert nodes["0"]["id"] == "0" + assert nodes["0"]["label"] == "node0" + assert nodes["0"]["parent_id"] == "__base__" + assert nodes["0"]["attrs"] == {} + + assert nodes["1"]["id"] == "1" + assert nodes["1"]["label"] == "node1" + assert nodes["1"]["parent_id"] == "c0" + assert nodes["1"]["attrs"] == {} def test_get_edges(self): """Tests that get_edges works.""" + dag_builder = PyDotDAGBuilder() + dag_builder.add_node("0", "node0") + dag_builder.add_node("1", "node1") + dag_builder.add_edge("0", "1") + + edges = dag_builder.get_edges() + assert len(edges) == 1 + assert edges[0]["from_id"] == "0" + assert edges[0]["to_id"] == "0" + assert edges[0]["attrs"] == {} + def test_get_clusters(self): """Tests that get_clusters works.""" + dag_builder = PyDotDAGBuilder() + dag_builder.add_cluster("0", "my_info_node", label="my_cluster") + + clusters = dag_builder.get_clusters() + assert len(clusters) == 1 + assert len(clusters["0"]) == 5 + assert clusters["0"]["id"] == "0" + assert clusters["0"]["cluster_label"] == "my_cluster" + assert clusters["0"]["node_label"] == "my_info_node" + assert clusters["0"]["parent_id"] == "__base__" + assert clusters["0"]["attrs"] == {} + + dag_builder.add_cluster( + "1", "my_other_info_node", parent_graph_id="0", label="my_nested_cluster" + ) + + clusters = dag_builder.get_clusters() + assert len(clusters) == 2 + assert len(clusters["1"]) == 5 + assert clusters["1"]["id"] == "0" + assert clusters["1"]["cluster_label"] == "my_nested_cluster" + assert clusters["1"]["node_label"] == "my_other_info_node" + assert clusters["1"]["parent_id"] == "0" + assert clusters["1"]["attrs"] == {} + class TestOutput: """Test that the graph can be outputted correctly.""" From 41fc4d183e6699c1520150826e58cb0f2f8c7345 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Mon, 24 Nov 2025 13:48:14 -0500 Subject: [PATCH 046/111] Apply suggestion from @andrijapau --- .../python_interface/visualization/test_pydot_dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 81cff2db77..d06b548478 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -278,7 +278,7 @@ def test_get_edges(self): edges = dag_builder.get_edges() assert len(edges) == 1 assert edges[0]["from_id"] == "0" - assert edges[0]["to_id"] == "0" + assert edges[0]["to_id"] == "1" assert edges[0]["attrs"] == {} def test_get_clusters(self): From 1868346de42eac6ba9bb8e5956df92bce67edf29 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 13:49:07 -0500 Subject: [PATCH 047/111] add tests --- .../visualization/test_pydot_dag_builder.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index d06b548478..805bacb6ab 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -255,8 +255,10 @@ def test_get_nodes(self): dag_builder.add_node("1", "node1", parent_graph_id="c0") nodes = dag_builder.get_nodes() + assert len(nodes) == 2 assert len(nodes["0"]) == 4 + assert nodes["0"]["id"] == "0" assert nodes["0"]["label"] == "node0" assert nodes["0"]["parent_id"] == "__base__" @@ -276,7 +278,9 @@ def test_get_edges(self): dag_builder.add_edge("0", "1") edges = dag_builder.get_edges() + assert len(edges) == 1 + assert edges[0]["from_id"] == "0" assert edges[0]["to_id"] == "1" assert edges[0]["attrs"] == {} @@ -288,7 +292,13 @@ def test_get_clusters(self): dag_builder.add_cluster("0", "my_info_node", label="my_cluster") clusters = dag_builder.get_clusters() - assert len(clusters) == 1 + + dag_builder.add_cluster( + "1", "my_other_info_node", parent_graph_id="0", label="my_nested_cluster" + ) + clusters = dag_builder.get_clusters() + assert len(clusters) == 2 + assert len(clusters["0"]) == 5 assert clusters["0"]["id"] == "0" assert clusters["0"]["cluster_label"] == "my_cluster" @@ -296,12 +306,6 @@ def test_get_clusters(self): assert clusters["0"]["parent_id"] == "__base__" assert clusters["0"]["attrs"] == {} - dag_builder.add_cluster( - "1", "my_other_info_node", parent_graph_id="0", label="my_nested_cluster" - ) - - clusters = dag_builder.get_clusters() - assert len(clusters) == 2 assert len(clusters["1"]) == 5 assert clusters["1"]["id"] == "0" assert clusters["1"]["cluster_label"] == "my_nested_cluster" From fd8d72116276b3e56e9b8f60078ddd2117845391 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 13:50:00 -0500 Subject: [PATCH 048/111] fix tests --- .../python_interface/visualization/test_pydot_dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 805bacb6ab..28b04f5f41 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -251,7 +251,7 @@ def test_get_nodes(self): dag_builder = PyDotDAGBuilder() dag_builder.add_node("0", "node0") - dag_builder.add_cluster("c0", "my_info_node", label="my_cluster") + dag_builder.add_cluster("c0") dag_builder.add_node("1", "node1", parent_graph_id="c0") nodes = dag_builder.get_nodes() From 284ba07eab92d8e6e14d16ac00259f78d23a2909 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 13:59:50 -0500 Subject: [PATCH 049/111] update tests --- .../visualization/test_pydot_dag_builder.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 28b04f5f41..2730efa412 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -250,7 +250,7 @@ def test_get_nodes(self): """Tests that get_nodes works.""" dag_builder = PyDotDAGBuilder() - dag_builder.add_node("0", "node0") + dag_builder.add_node("0", "node0", fillcolor="red") dag_builder.add_cluster("c0") dag_builder.add_node("1", "node1", parent_graph_id="c0") @@ -262,7 +262,7 @@ def test_get_nodes(self): assert nodes["0"]["id"] == "0" assert nodes["0"]["label"] == "node0" assert nodes["0"]["parent_id"] == "__base__" - assert nodes["0"]["attrs"] == {} + assert nodes["0"]["attrs"] == {"fillcolor": "red"} assert nodes["1"]["id"] == "1" assert nodes["1"]["label"] == "node1" @@ -275,7 +275,7 @@ def test_get_edges(self): dag_builder = PyDotDAGBuilder() dag_builder.add_node("0", "node0") dag_builder.add_node("1", "node1") - dag_builder.add_edge("0", "1") + dag_builder.add_edge("0", "1", penwidth=10) edges = dag_builder.get_edges() @@ -283,13 +283,13 @@ def test_get_edges(self): assert edges[0]["from_id"] == "0" assert edges[0]["to_id"] == "1" - assert edges[0]["attrs"] == {} + assert edges[0]["attrs"] == {"penwidth": 10} def test_get_clusters(self): """Tests that get_clusters works.""" dag_builder = PyDotDAGBuilder() - dag_builder.add_cluster("0", "my_info_node", label="my_cluster") + dag_builder.add_cluster("0", "my_info_node", label="my_cluster", penwidth=10) clusters = dag_builder.get_clusters() @@ -304,7 +304,7 @@ def test_get_clusters(self): assert clusters["0"]["cluster_label"] == "my_cluster" assert clusters["0"]["node_label"] == "my_info_node" assert clusters["0"]["parent_id"] == "__base__" - assert clusters["0"]["attrs"] == {} + assert clusters["0"]["attrs"] == {"penwidth": 10} assert len(clusters["1"]) == 5 assert clusters["1"]["id"] == "0" From 95ace870e83f989a1c1e24aafde9465ce22b00b0 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 14:02:05 -0500 Subject: [PATCH 050/111] fix documentation --- .../catalyst/python_interface/visualization/dag_builder.py | 3 ++- .../python_interface/visualization/pydot_dag_builder.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py index d59a81b6f4..492976ed5f 100644 --- a/frontend/catalyst/python_interface/visualization/dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -78,7 +78,8 @@ def add_cluster( Args: cluster_id (str): Unique cluster ID to identify this cluster. node_label (str | None): The text to display on an information node within the cluster when rendered. - parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. + parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. If `None`, the cluster will be + placed on the base graph. **cluster_attrs (Any): Any additional styling keyword arguments. """ diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 9a81dc55d0..7073f9a52e 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -158,7 +158,7 @@ def add_cluster( Args: cluster_id (str): Unique cluster ID to identify this cluster. node_label (str | None): The text to display on the information node within the cluster when rendered. - parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. + parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. If `None`, the cluster will be positioned on the base graph. **cluster_attrs (Any): Any additional styling keyword arguments. """ From ec9caaeddb1b10d07dc8d3843a2a91f25b04d6e4 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 14:04:18 -0500 Subject: [PATCH 051/111] fix documentation --- frontend/catalyst/python_interface/visualization/dag_builder.py | 2 +- .../python_interface/visualization/pydot_dag_builder.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py index 492976ed5f..db118b4dfb 100644 --- a/frontend/catalyst/python_interface/visualization/dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -99,7 +99,7 @@ def get_edges(self) -> list[dict[str, Any]]: """Retrieve the current set of edges in the graph. Returns: - edges (list[dict[str, Any]]): A list of edges where each edge contains a dictionary of information for a given edge. + edges (list[dict[str, Any]]): A list of edges where each element in the list contains a dictionary of edge information. """ raise NotImplementedError diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 7073f9a52e..2e0ba196a4 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -215,7 +215,7 @@ def get_edges(self) -> list[dict[str, Any]]: """Retrieve the current set of edges in the graph. Returns: - edges (list[dict[str, Any]]): A list of edges where each edge contains a dictionary of information for a given edge. + edges (list[dict[str, Any]]): A list of edges where each element in the list contains a dictionary of edge information. """ return self._edges From 99bd602308c03d44d763affd9025de75b8b5d62d Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 14:19:07 -0500 Subject: [PATCH 052/111] fix tests --- .../visualization/test_pydot_dag_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 2730efa412..cff4f72207 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -262,7 +262,7 @@ def test_get_nodes(self): assert nodes["0"]["id"] == "0" assert nodes["0"]["label"] == "node0" assert nodes["0"]["parent_id"] == "__base__" - assert nodes["0"]["attrs"] == {"fillcolor": "red"} + assert nodes["0"]["attrs"]["fillcolor"] == "red" assert nodes["1"]["id"] == "1" assert nodes["1"]["label"] == "node1" @@ -283,7 +283,7 @@ def test_get_edges(self): assert edges[0]["from_id"] == "0" assert edges[0]["to_id"] == "1" - assert edges[0]["attrs"] == {"penwidth": 10} + assert edges[0]["attrs"]["penwidth"] == 10 def test_get_clusters(self): """Tests that get_clusters works.""" @@ -304,7 +304,7 @@ def test_get_clusters(self): assert clusters["0"]["cluster_label"] == "my_cluster" assert clusters["0"]["node_label"] == "my_info_node" assert clusters["0"]["parent_id"] == "__base__" - assert clusters["0"]["attrs"] == {"penwidth": 10} + assert clusters["0"]["attrs"]["penwidth"] == 10 assert len(clusters["1"]) == 5 assert clusters["1"]["id"] == "0" From 7622b502560dd67f282162f1b4539ef19e568985 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 14:35:24 -0500 Subject: [PATCH 053/111] whoops --- .../python_interface/visualization/test_pydot_dag_builder.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index cff4f72207..dbec2ced99 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -267,7 +267,6 @@ def test_get_nodes(self): assert nodes["1"]["id"] == "1" assert nodes["1"]["label"] == "node1" assert nodes["1"]["parent_id"] == "c0" - assert nodes["1"]["attrs"] == {} def test_get_edges(self): """Tests that get_edges works.""" @@ -311,7 +310,6 @@ def test_get_clusters(self): assert clusters["1"]["cluster_label"] == "my_nested_cluster" assert clusters["1"]["node_label"] == "my_other_info_node" assert clusters["1"]["parent_id"] == "0" - assert clusters["1"]["attrs"] == {} class TestOutput: From 86e8389feaa02b555ec1a1498558a46d4a17a8f4 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 14:58:15 -0500 Subject: [PATCH 054/111] whoops --- .../python_interface/visualization/test_pydot_dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index dbec2ced99..40fce1d641 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -306,7 +306,7 @@ def test_get_clusters(self): assert clusters["0"]["attrs"]["penwidth"] == 10 assert len(clusters["1"]) == 5 - assert clusters["1"]["id"] == "0" + assert clusters["1"]["id"] == "1" assert clusters["1"]["cluster_label"] == "my_nested_cluster" assert clusters["1"]["node_label"] == "my_other_info_node" assert clusters["1"]["parent_id"] == "0" From a5dc1ac0f51812b1d40d06c4fa3185d329499574 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 15:12:55 -0500 Subject: [PATCH 055/111] rename a bunch of stuff --- .../visualization/dag_builder.py | 28 +++++----- .../visualization/pydot_dag_builder.py | 52 +++++++++---------- .../visualization/test_dag_builder.py | 14 +++-- .../visualization/test_pydot_dag_builder.py | 12 ++--- 4 files changed, 51 insertions(+), 55 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py index db118b4dfb..6ea69de6e7 100644 --- a/frontend/catalyst/python_interface/visualization/dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -32,31 +32,29 @@ class DAGBuilder(ABC): @abstractmethod def add_node( self, - node_id: NodeID, - node_label: str, - parent_graph_id: str | None = None, + id: NodeID, + label: str, + cluster_id: ClusterID | None = None, **node_attrs: Any, ) -> None: """Add a single node to the graph. Args: - node_id (str): Unique node ID to identify this node. - node_label (str): The text to display on the node when rendered. - parent_graph_id (str | None): Optional ID of the cluster this node belongs to. + id (str): Unique node ID to identify this node. + label (str): The text to display on the node when rendered. + cluster_id (str | None): Optional ID of the cluster this node belongs to. **node_attrs (Any): Any additional styling keyword arguments. """ raise NotImplementedError @abstractmethod - def add_edge( - self, from_node_id: NodeID, to_node_id: NodeID, **edge_attrs: Any - ) -> None: + def add_edge(self, from_id: NodeID, to_id: NodeID, **edge_attrs: Any) -> None: """Add a single directed edge between nodes in the graph. Args: - from_node_id (str): The unique ID of the source node. - to_node_id (str): The unique ID of the destination node. + from_id (str): The unique ID of the source node. + to_id (str): The unique ID of the destination node. **edge_attrs (Any): Any additional styling keyword arguments. """ @@ -65,9 +63,9 @@ def add_edge( @abstractmethod def add_cluster( self, - cluster_id: ClusterID, + id: ClusterID, node_label: str | None = None, - parent_graph_id: ClusterID | None = None, + cluster_id: ClusterID | None = None, **cluster_attrs: Any, ) -> None: """Add a single cluster to the graph. @@ -76,9 +74,9 @@ def add_cluster( within it are visually and logically grouped. Args: - cluster_id (str): Unique cluster ID to identify this cluster. + id (str): Unique cluster ID to identify this cluster. node_label (str | None): The text to display on an information node within the cluster when rendered. - parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. If `None`, the cluster will be + cluster_id (str | None): Optional ID of the cluster this cluster belongs to. If `None`, the cluster will be placed on the base graph. **cluster_attrs (Any): Any additional styling keyword arguments. diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 2e0ba196a4..57b488b530 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -97,57 +97,57 @@ def __init__( def add_node( self, - node_id: str, - node_label: str, - parent_graph_id: str | None = None, + id: str, + label: str, + cluster_id: str | None = None, **node_attrs: Any, ) -> None: """Add a single node to the graph. Args: - node_id (str): Unique node ID to identify this node. - node_label (str): The text to display on the node when rendered. - parent_graph_id (str | None): Optional ID of the cluster this node belongs to. + id (str): Unique node ID to identify this node. + label (str): The text to display on the node when rendered. + cluster_id (str | None): Optional ID of the cluster this node belongs to. **node_attrs (Any): Any additional styling keyword arguments. """ # Use ChainMap so you don't need to construct a new dictionary node_attrs: ChainMap = ChainMap(node_attrs, self._default_node_attrs) - node = pydot.Node(node_id, label=node_label, **node_attrs) - parent_graph_id = "__base__" if parent_graph_id is None else parent_graph_id + node = pydot.Node(id, label=label, **node_attrs) + parent_graph_id = "__base__" if cluster_id is None else cluster_id self._subgraphs[parent_graph_id].add_node(node) - self._nodes[node_id] = { - "id": node_id, - "label": node_label, - "parent_id": parent_graph_id, + self._nodes[id] = { + "id": id, + "label": label, + "cluster_id": cluster_id, "attrs": dict(node_attrs), } - def add_edge(self, from_node_id: str, to_node_id: str, **edge_attrs: Any) -> None: + def add_edge(self, from_id: str, to_id: str, **edge_attrs: Any) -> None: """Add a single directed edge between nodes in the graph. Args: - from_node_id (str): The unique ID of the source node. - to_node_id (str): The unique ID of the destination node. + from_id (str): The unique ID of the source node. + to_id (str): The unique ID of the destination node. **edge_attrs (Any): Any additional styling keyword arguments. """ # Use ChainMap so you don't need to construct a new dictionary edge_attrs: ChainMap = ChainMap(edge_attrs, self._default_edge_attrs) - edge = pydot.Edge(from_node_id, to_node_id, **edge_attrs) + edge = pydot.Edge(from_id, to_id, **edge_attrs) self.graph.add_edge(edge) self._edges.append( - {"from_id": from_node_id, "to_id": to_node_id, "attrs": dict(edge_attrs)} + {"from_id": from_id, "to_id": to_id, "attrs": dict(edge_attrs)} ) def add_cluster( self, - cluster_id: str, + id: str, node_label: str | None = None, - parent_graph_id: str | None = None, + cluster_id: str | None = None, **cluster_attrs: Any, ) -> None: """Add a single cluster to the graph. @@ -156,15 +156,15 @@ def add_cluster( within it are visually and logically grouped. Args: - cluster_id (str): Unique cluster ID to identify this cluster. + id (str): Unique cluster ID to identify this cluster. node_label (str | None): The text to display on the information node within the cluster when rendered. - parent_graph_id (str | None): Optional ID of the cluster this cluster belongs to. If `None`, the cluster will be positioned on the base graph. + cluster_id (str | None): Optional ID of the cluster this cluster belongs to. If `None`, the cluster will be positioned on the base graph. **cluster_attrs (Any): Any additional styling keyword arguments. """ # Use ChainMap so you don't need to construct a new dictionary cluster_attrs: ChainMap = ChainMap(cluster_attrs, self._default_cluster_attrs) - cluster = pydot.Cluster(graph_name=cluster_id, **cluster_attrs) + cluster = pydot.Cluster(graph_name=id, **cluster_attrs) # Puts the label in a node within the cluster. # Ensures that any edges connecting nodes through the cluster @@ -190,13 +190,13 @@ def add_cluster( cluster.add_subgraph(rank_subgraph) cluster.add_node(node) - self._subgraphs[cluster_id] = cluster + self._subgraphs[id] = cluster - parent_graph_id = "__base__" if parent_graph_id is None else parent_graph_id + parent_graph_id = "__base__" if cluster_id is None else cluster_id self._subgraphs[parent_graph_id].add_subgraph(cluster) - self._clusters[cluster_id] = { - "id": cluster_id, + self._clusters[id] = { + "id": id, "cluster_label": cluster_attrs.get("label"), "node_label": node_label, "parent_id": parent_graph_id, diff --git a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py index 7191cf001c..3b7d61c49d 100644 --- a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py @@ -33,23 +33,21 @@ class ConcreteDAGBuilder(DAGBuilder): def add_node( self, - node_id: str, - node_label: str, - parent_graph_id: str | None = None, + id: str, + label: str, + cluster_id: str | None = None, **node_attrs: Any, ) -> None: return - def add_edge( - self, from_node_id: str, to_node_id: str, **edge_attrs: Any - ) -> None: + def add_edge(self, from_id: str, to_id: str, **edge_attrs: Any) -> None: return def add_cluster( self, - cluster_id: str, + id: str, node_label: str | None = None, - parent_graph_id: str | None = None, + cluster_id: str | None = None, **cluster_attrs: Any, ) -> None: return diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 40fce1d641..b0a17441ea 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -90,7 +90,7 @@ def test_add_node_to_parent_graph(self): dag_builder.add_cluster("c0") # Create node inside cluster - dag_builder.add_node("1", "node1", parent_graph_id="c0") + dag_builder.add_node("1", "node1", cluster_id="c0") # Verify graph structure root_graph = dag_builder.graph @@ -120,11 +120,11 @@ def test_add_cluster_to_parent_graph(self): dag_builder.add_cluster("c0") # Level 1 (Inside c0): Add node on outer cluster and create new cluster on top - dag_builder.add_node("n_outer", "node_outer", parent_graph_id="c0") - dag_builder.add_cluster("c1", parent_graph_id="c0") + dag_builder.add_node("n_outer", "node_outer", cluster_id="c0") + dag_builder.add_cluster("c1", cluster_id="c0") # Level 2 (Inside c1): Add node on second cluster - dag_builder.add_node("n_inner", "node_inner", parent_graph_id="c1") + dag_builder.add_node("n_inner", "node_inner", cluster_id="c1") root_graph = dag_builder.graph @@ -252,7 +252,7 @@ def test_get_nodes(self): dag_builder.add_node("0", "node0", fillcolor="red") dag_builder.add_cluster("c0") - dag_builder.add_node("1", "node1", parent_graph_id="c0") + dag_builder.add_node("1", "node1", cluster_id="c0") nodes = dag_builder.get_nodes() @@ -293,7 +293,7 @@ def test_get_clusters(self): clusters = dag_builder.get_clusters() dag_builder.add_cluster( - "1", "my_other_info_node", parent_graph_id="0", label="my_nested_cluster" + "1", "my_other_info_node", cluster_id="0", label="my_nested_cluster" ) clusters = dag_builder.get_clusters() assert len(clusters) == 2 From fe7ef4772df9f9337b776cf46998e7cf32b2e7be Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 15:13:51 -0500 Subject: [PATCH 056/111] fix documentation --- .../catalyst/python_interface/visualization/dag_builder.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py index 6ea69de6e7..becd926b91 100644 --- a/frontend/catalyst/python_interface/visualization/dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -42,7 +42,8 @@ def add_node( Args: id (str): Unique node ID to identify this node. label (str): The text to display on the node when rendered. - cluster_id (str | None): Optional ID of the cluster this node belongs to. + cluster_id (str | None): Optional ID of the cluster this node belongs to. If `None`, this node gets + added on the base graph. **node_attrs (Any): Any additional styling keyword arguments. """ From b2a13cfc8fd64fc9ab94f2be3005407a6617a8a1 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 15:14:49 -0500 Subject: [PATCH 057/111] rename a bunch of stuff --- .../visualization/pydot_dag_builder.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 57b488b530..831b181fec 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -114,9 +114,9 @@ def add_node( # Use ChainMap so you don't need to construct a new dictionary node_attrs: ChainMap = ChainMap(node_attrs, self._default_node_attrs) node = pydot.Node(id, label=label, **node_attrs) - parent_graph_id = "__base__" if cluster_id is None else cluster_id + cluster_id = "__base__" if cluster_id is None else cluster_id - self._subgraphs[parent_graph_id].add_node(node) + self._subgraphs[cluster_id].add_node(node) self._nodes[id] = { "id": id, @@ -192,14 +192,14 @@ def add_cluster( self._subgraphs[id] = cluster - parent_graph_id = "__base__" if cluster_id is None else cluster_id - self._subgraphs[parent_graph_id].add_subgraph(cluster) + cluster_id = "__base__" if cluster_id is None else cluster_id + self._subgraphs[cluster_id].add_subgraph(cluster) self._clusters[id] = { "id": id, "cluster_label": cluster_attrs.get("label"), "node_label": node_label, - "parent_id": parent_graph_id, + "parent_id": cluster_id, "attrs": dict(cluster_attrs), } From abfd9328fb046846ee9de01332d8d3ad6258fefb Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 15:19:09 -0500 Subject: [PATCH 058/111] add dev comment --- .../catalyst/python_interface/visualization/pydot_dag_builder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 831b181fec..4cd63de03f 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -54,6 +54,7 @@ def __init__( graph_type="digraph", rankdir="TB", compound="true", strict=True ) # Create cache for easy look-up + # TODO: Get rid of this and use self._clusters if possible self._subgraphs: dict[str, pydot.Graph] = {} self._subgraphs["__base__"] = self.graph From a481a048748c4b4ed0e27a1b584a593cc3f44a4d Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 15:22:03 -0500 Subject: [PATCH 059/111] rename --- .../visualization/dag_builder.py | 12 ++++++------ .../visualization/pydot_dag_builder.py | 18 +++++++++--------- .../visualization/test_dag_builder.py | 6 +++--- .../visualization/test_pydot_dag_builder.py | 6 +++--- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py index becd926b91..8a2fff6960 100644 --- a/frontend/catalyst/python_interface/visualization/dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -35,7 +35,7 @@ def add_node( id: NodeID, label: str, cluster_id: ClusterID | None = None, - **node_attrs: Any, + **attrs: Any, ) -> None: """Add a single node to the graph. @@ -44,19 +44,19 @@ def add_node( label (str): The text to display on the node when rendered. cluster_id (str | None): Optional ID of the cluster this node belongs to. If `None`, this node gets added on the base graph. - **node_attrs (Any): Any additional styling keyword arguments. + **attrs (Any): Any additional styling keyword arguments. """ raise NotImplementedError @abstractmethod - def add_edge(self, from_id: NodeID, to_id: NodeID, **edge_attrs: Any) -> None: + def add_edge(self, from_id: NodeID, to_id: NodeID, **attrs: Any) -> None: """Add a single directed edge between nodes in the graph. Args: from_id (str): The unique ID of the source node. to_id (str): The unique ID of the destination node. - **edge_attrs (Any): Any additional styling keyword arguments. + **attrs (Any): Any additional styling keyword arguments. """ raise NotImplementedError @@ -67,7 +67,7 @@ def add_cluster( id: ClusterID, node_label: str | None = None, cluster_id: ClusterID | None = None, - **cluster_attrs: Any, + **attrs: Any, ) -> None: """Add a single cluster to the graph. @@ -79,7 +79,7 @@ def add_cluster( node_label (str | None): The text to display on an information node within the cluster when rendered. cluster_id (str | None): Optional ID of the cluster this cluster belongs to. If `None`, the cluster will be placed on the base graph. - **cluster_attrs (Any): Any additional styling keyword arguments. + **attrs (Any): Any additional styling keyword arguments. """ raise NotImplementedError diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 4cd63de03f..47e9f584f9 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -101,7 +101,7 @@ def add_node( id: str, label: str, cluster_id: str | None = None, - **node_attrs: Any, + **attrs: Any, ) -> None: """Add a single node to the graph. @@ -109,11 +109,11 @@ def add_node( id (str): Unique node ID to identify this node. label (str): The text to display on the node when rendered. cluster_id (str | None): Optional ID of the cluster this node belongs to. - **node_attrs (Any): Any additional styling keyword arguments. + **attrs (Any): Any additional styling keyword arguments. """ # Use ChainMap so you don't need to construct a new dictionary - node_attrs: ChainMap = ChainMap(node_attrs, self._default_node_attrs) + node_attrs: ChainMap = ChainMap(attrs, self._default_node_attrs) node = pydot.Node(id, label=label, **node_attrs) cluster_id = "__base__" if cluster_id is None else cluster_id @@ -126,17 +126,17 @@ def add_node( "attrs": dict(node_attrs), } - def add_edge(self, from_id: str, to_id: str, **edge_attrs: Any) -> None: + def add_edge(self, from_id: str, to_id: str, **attrs: Any) -> None: """Add a single directed edge between nodes in the graph. Args: from_id (str): The unique ID of the source node. to_id (str): The unique ID of the destination node. - **edge_attrs (Any): Any additional styling keyword arguments. + **attrs (Any): Any additional styling keyword arguments. """ # Use ChainMap so you don't need to construct a new dictionary - edge_attrs: ChainMap = ChainMap(edge_attrs, self._default_edge_attrs) + edge_attrs: ChainMap = ChainMap(attrs, self._default_edge_attrs) edge = pydot.Edge(from_id, to_id, **edge_attrs) self.graph.add_edge(edge) @@ -149,7 +149,7 @@ def add_cluster( id: str, node_label: str | None = None, cluster_id: str | None = None, - **cluster_attrs: Any, + **attrs: Any, ) -> None: """Add a single cluster to the graph. @@ -160,11 +160,11 @@ def add_cluster( id (str): Unique cluster ID to identify this cluster. node_label (str | None): The text to display on the information node within the cluster when rendered. cluster_id (str | None): Optional ID of the cluster this cluster belongs to. If `None`, the cluster will be positioned on the base graph. - **cluster_attrs (Any): Any additional styling keyword arguments. + **attrs (Any): Any additional styling keyword arguments. """ # Use ChainMap so you don't need to construct a new dictionary - cluster_attrs: ChainMap = ChainMap(cluster_attrs, self._default_cluster_attrs) + cluster_attrs: ChainMap = ChainMap(attrs, self._default_cluster_attrs) cluster = pydot.Cluster(graph_name=id, **cluster_attrs) # Puts the label in a node within the cluster. diff --git a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py index 3b7d61c49d..2ff5b8255f 100644 --- a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py @@ -36,11 +36,11 @@ def add_node( id: str, label: str, cluster_id: str | None = None, - **node_attrs: Any, + **attrs: Any, ) -> None: return - def add_edge(self, from_id: str, to_id: str, **edge_attrs: Any) -> None: + def add_edge(self, from_id: str, to_id: str, **attrs: Any) -> None: return def add_cluster( @@ -48,7 +48,7 @@ def add_cluster( id: str, node_label: str | None = None, cluster_id: str | None = None, - **cluster_attrs: Any, + **attrs: Any, ) -> None: return diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index b0a17441ea..808dcedc87 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -176,7 +176,7 @@ def test_default_graph_attrs(self): def test_add_node_with_attrs(self): """Tests that default attributes are applied and can be overridden.""" dag_builder = PyDotDAGBuilder( - node_attrs={"fillcolor": "lightblue", "penwidth": 3} + attrs={"fillcolor": "lightblue", "penwidth": 3} ) # Defaults @@ -194,7 +194,7 @@ def test_add_node_with_attrs(self): @pytest.mark.unit def test_add_edge_with_attrs(self): """Tests that default attributes are applied and can be overridden.""" - dag_builder = PyDotDAGBuilder(edge_attrs={"color": "lightblue4", "penwidth": 3}) + dag_builder = PyDotDAGBuilder(attrs={"color": "lightblue4", "penwidth": 3}) dag_builder.add_node("0", "node0") dag_builder.add_node("1", "node1") @@ -214,7 +214,7 @@ def test_add_edge_with_attrs(self): def test_add_cluster_with_attrs(self): """Tests that default cluster attributes are applied and can be overridden.""" dag_builder = PyDotDAGBuilder( - cluster_attrs={ + attrs={ "style": "solid", "fillcolor": None, "penwidth": 2, From 9a18e5c25bb408e41828e3d28878df8831ba0a86 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 16:10:16 -0500 Subject: [PATCH 060/111] update test --- .../python_interface/visualization/test_pydot_dag_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 808dcedc87..12c44df287 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -261,12 +261,12 @@ def test_get_nodes(self): assert nodes["0"]["id"] == "0" assert nodes["0"]["label"] == "node0" - assert nodes["0"]["parent_id"] == "__base__" + assert nodes["0"]["cluster_id"] == "__base__" assert nodes["0"]["attrs"]["fillcolor"] == "red" assert nodes["1"]["id"] == "1" assert nodes["1"]["label"] == "node1" - assert nodes["1"]["parent_id"] == "c0" + assert nodes["1"]["cluster_id"] == "c0" def test_get_edges(self): """Tests that get_edges works.""" From 0f6ab76e63e5573df041ffaa80605c374e9e07c6 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 16:56:56 -0500 Subject: [PATCH 061/111] add immutability tests --- .../visualization/pydot_dag_builder.py | 6 +- .../visualization/test_pydot_dag_builder.py | 59 ++++++++++++++++++- 2 files changed, 59 insertions(+), 6 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 47e9f584f9..ea6417d2d4 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -210,7 +210,7 @@ def get_nodes(self) -> dict[str, dict[str, Any]]: Returns: nodes (dict[str, dict[str, Any]]): A dictionary that maps the node's ID to it's node information. """ - return self._nodes + return self._nodes.copy() def get_edges(self) -> list[dict[str, Any]]: """Retrieve the current set of edges in the graph. @@ -218,7 +218,7 @@ def get_edges(self) -> list[dict[str, Any]]: Returns: edges (list[dict[str, Any]]): A list of edges where each element in the list contains a dictionary of edge information. """ - return self._edges + return self._edges.copy() def get_clusters(self) -> dict[str, dict[str, Any]]: """Retrieve the current set of clusters in the graph. @@ -226,7 +226,7 @@ def get_clusters(self) -> dict[str, dict[str, Any]]: Returns: clusters (dict[str, dict[str, Any]]): A dictionary that maps the cluster's ID to it's cluster information. """ - return self._clusters + return self._clusters.copy() def to_file(self, output_filename: str) -> None: """Save the graph to a file. diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 12c44df287..db6e9fb4da 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -17,6 +17,8 @@ import pytest +from frontend.catalyst.python_interface.visualization import dag_builder + pydot = pytest.importorskip("pydot") pytestmark = pytest.mark.usefixtures("requires_xdsl") # pylint: disable=wrong-import-position @@ -175,9 +177,7 @@ def test_default_graph_attrs(self): @pytest.mark.unit def test_add_node_with_attrs(self): """Tests that default attributes are applied and can be overridden.""" - dag_builder = PyDotDAGBuilder( - attrs={"fillcolor": "lightblue", "penwidth": 3} - ) + dag_builder = PyDotDAGBuilder(attrs={"fillcolor": "lightblue", "penwidth": 3}) # Defaults dag_builder.add_node("0", "node0") @@ -268,6 +268,23 @@ def test_get_nodes(self): assert nodes["1"]["label"] == "node1" assert nodes["1"]["cluster_id"] == "c0" + def test_get_nodes_doesnt_mutate(self): + """Tests that get_nodes doesn't mutate state""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0") + + old_nodes = dag_builder.get_nodes() + + dag_builder.add_node("1", "node1") + + new_nodes = dag_builder.get_nodes() + + assert old_nodes is not new_nodes + assert len(old_nodes) == 1 + assert len(new_nodes) == 2 + def test_get_edges(self): """Tests that get_edges works.""" @@ -284,6 +301,25 @@ def test_get_edges(self): assert edges[0]["to_id"] == "1" assert edges[0]["attrs"]["penwidth"] == 10 + def test_get_edges_doesnt_mutate(self): + """Tests that get_edges doesn't mutated.""" + + dag_builder = PyDotDAGBuilder() + dag_builder.add_node("0", "node0") + dag_builder.add_node("1", "node1") + dag_builder.add_edge("0", "1") + + old_edges = dag_builder.get_edges() + + dag_builder.add_node("2", "node2") + dag_builder.add_edge("1", "2") + + new_edges = dag_builder.get_edges() + + assert old_edges is not new_edges + assert len(old_edges) == 1 + assert len(new_edges) == 2 + def test_get_clusters(self): """Tests that get_clusters works.""" @@ -311,6 +347,23 @@ def test_get_clusters(self): assert clusters["1"]["node_label"] == "my_other_info_node" assert clusters["1"]["parent_id"] == "0" + def test_get_clusters_doesnt_mutate(self): + """Tests that get_clusters doesn't mutate state""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_cluster("0") + + old_clusters = dag_builder.get_clusters() + + dag_builder.add_cluster("1") + + new_clusters = dag_builder.get_clusters() + + assert old_clusters is not new_clusters + assert len(old_clusters) == 1 + assert len(new_clusters) == 2 + class TestOutput: """Test that the graph can be outputted correctly.""" From 6d15d7be495875fe325e7b0d7b010ed97ca19e7f Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 08:54:05 -0500 Subject: [PATCH 062/111] clean-up --- .../python_interface/visualization/construct_circuit_dag.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index f662e1b7b3..4ab442e8cb 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -52,9 +52,8 @@ def construct(self, module: builtin.ModuleOp) -> None: self._visit(op) # ======================= - # 2. HIERARCHY TRAVERSAL + # 2. IR TRAVERSAL # ======================= - # These methods navigate the recursive IR hierarchy (Op -> Region -> Block -> Op). @_visit.register def _operation(self, operation: Operation) -> None: From 888d025dafdb04b2d74314c0d5d00aef61e59be9 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 08:55:58 -0500 Subject: [PATCH 063/111] clean-up --- .../python_interface/visualization/construct_circuit_dag.py | 1 - 1 file changed, 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 4ab442e8cb..419943c361 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -44,7 +44,6 @@ def __init__(self, dag_builder: DAGBuilder) -> None: def _visit(self, op: Any) -> None: """Central dispatch method (Visitor Pattern). Routes the operation 'op' to the specialized handler registered for its type.""" - pass def construct(self, module: builtin.ModuleOp) -> None: """Constructs the DAG from the module.""" From d4b34f96d5995c8252faa63c61d84745a6597461 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:01:28 -0500 Subject: [PATCH 064/111] whoops --- .../python_interface/visualization/pydot_dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index ea6417d2d4..e8abda83da 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -200,7 +200,7 @@ def add_cluster( "id": id, "cluster_label": cluster_attrs.get("label"), "node_label": node_label, - "parent_id": cluster_id, + "cluster_id": cluster_id, "attrs": dict(cluster_attrs), } From 33aa334ba1286f8f55b6980a7951bbc8a0747bae Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:15:51 -0500 Subject: [PATCH 065/111] fix test --- .../python_interface/visualization/test_pydot_dag_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index db6e9fb4da..761be2da21 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -338,14 +338,14 @@ def test_get_clusters(self): assert clusters["0"]["id"] == "0" assert clusters["0"]["cluster_label"] == "my_cluster" assert clusters["0"]["node_label"] == "my_info_node" - assert clusters["0"]["parent_id"] == "__base__" + assert clusters["0"]["cluster_id"] == "__base__" assert clusters["0"]["attrs"]["penwidth"] == 10 assert len(clusters["1"]) == 5 assert clusters["1"]["id"] == "1" assert clusters["1"]["cluster_label"] == "my_nested_cluster" assert clusters["1"]["node_label"] == "my_other_info_node" - assert clusters["1"]["parent_id"] == "0" + assert clusters["1"]["cluster_id"] == "0" def test_get_clusters_doesnt_mutate(self): """Tests that get_clusters doesn't mutate state""" From d50cfdc75bd2fe6fdbe851bd4927b45b846204b5 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:16:58 -0500 Subject: [PATCH 066/111] whoops --- .../python_interface/visualization/test_pydot_dag_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index db6e9fb4da..761be2da21 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -338,14 +338,14 @@ def test_get_clusters(self): assert clusters["0"]["id"] == "0" assert clusters["0"]["cluster_label"] == "my_cluster" assert clusters["0"]["node_label"] == "my_info_node" - assert clusters["0"]["parent_id"] == "__base__" + assert clusters["0"]["cluster_id"] == "__base__" assert clusters["0"]["attrs"]["penwidth"] == 10 assert len(clusters["1"]) == 5 assert clusters["1"]["id"] == "1" assert clusters["1"]["cluster_label"] == "my_nested_cluster" assert clusters["1"]["node_label"] == "my_other_info_node" - assert clusters["1"]["parent_id"] == "0" + assert clusters["1"]["cluster_id"] == "0" def test_get_clusters_doesnt_mutate(self): """Tests that get_clusters doesn't mutate state""" From a364869a27f8a78490be34403fe30038af4fd24f Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:22:41 -0500 Subject: [PATCH 067/111] cleanup --- .../python_interface/visualization/construct_circuit_dag.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 419943c361..d497a95d2e 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -17,10 +17,9 @@ from functools import singledispatchmethod from typing import Any -from xdsl.dialects import builtin, scf +from xdsl.dialects import builtin from xdsl.ir import Block, Operation, Region -from catalyst.python_interface.dialects import quantum from catalyst.python_interface.visualization.dag_builder import DAGBuilder From cddb234a7b3c20d7d6657e0e4aebfdfb5468b80a Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:24:28 -0500 Subject: [PATCH 068/111] clean-up --- .../visualization/construct_circuit_dag.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index d497a95d2e..0b6765ba3b 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -24,11 +24,11 @@ class ConstructCircuitDAG: - """A tool that analyzes an xDSL module and constructs a Directed Acyclic Graph (DAG) - using an injected DAGBuilder instance. This tool does not mutate the xDSL module.""" + """A tool that traverses an xDSL module and constructs a Directed Acyclic Graph (DAG) + of it's quantum program using an injected DAGBuilder instance. This tool does not mutate the xDSL module.""" def __init__(self, dag_builder: DAGBuilder) -> None: - """Initialize the analysis pass by injecting the DAG builder dependency. + """Initialize the utility by injecting the DAG builder dependency. Args: dag_builder (DAGBuilder): The concrete builder instance used for graph construction. @@ -45,7 +45,12 @@ def _visit(self, op: Any) -> None: to the specialized handler registered for its type.""" def construct(self, module: builtin.ModuleOp) -> None: - """Constructs the DAG from the module.""" + """Constructs the DAG from the module. + + Args: + module (xdsl.builtin.ModuleOp): The module containing the quantum program to visualize. + + """ for op in module.ops: self._visit(op) From 1b5210c512192826b1c8e5e61162eda0fc108bd4 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:26:05 -0500 Subject: [PATCH 069/111] fix formatting issue --- .../python_interface/visualization/construct_circuit_dag.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 0b6765ba3b..45a15cc2d2 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -25,7 +25,8 @@ class ConstructCircuitDAG: """A tool that traverses an xDSL module and constructs a Directed Acyclic Graph (DAG) - of it's quantum program using an injected DAGBuilder instance. This tool does not mutate the xDSL module.""" + of it's quantum program using an injected DAGBuilder instance. This tool does not mutate the xDSL module. + """ def __init__(self, dag_builder: DAGBuilder) -> None: """Initialize the utility by injecting the DAG builder dependency. From 14b28bb5fb2be009c3ea237dce271e3cc4942de5 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:28:48 -0500 Subject: [PATCH 070/111] isort --- .../visualization/test_construct_circuit_dag.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 5a55cd0ba7..ae449bb4c5 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -19,15 +19,16 @@ pytestmark = pytest.mark.usefixtures("requires_xdsl") +from xdsl.dialects import test +from xdsl.dialects.builtin import ModuleOp +from xdsl.ir.core import Block, Region + # pylint: disable=wrong-import-position # This import needs to be after pytest in order to prevent ImportErrors from catalyst.python_interface.visualization.construct_circuit_dag import ( ConstructCircuitDAG, ) from catalyst.python_interface.visualization.dag_builder import DAGBuilder -from xdsl.dialects import test -from xdsl.dialects.builtin import ModuleOp -from xdsl.ir.core import Block, Region class TestInitialization: From a1e9211922769cbd3ca1bd4276d945a81315cd98 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:40:31 -0500 Subject: [PATCH 071/111] update tests --- .../test_construct_circuit_dag.py | 97 +++++++++++++++---- 1 file changed, 78 insertions(+), 19 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index ae449bb4c5..a4b9f63765 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -19,46 +19,105 @@ pytestmark = pytest.mark.usefixtures("requires_xdsl") -from xdsl.dialects import test -from xdsl.dialects.builtin import ModuleOp -from xdsl.ir.core import Block, Region - # pylint: disable=wrong-import-position # This import needs to be after pytest in order to prevent ImportErrors from catalyst.python_interface.visualization.construct_circuit_dag import ( ConstructCircuitDAG, ) from catalyst.python_interface.visualization.dag_builder import DAGBuilder +from xdsl.dialects import test +from xdsl.dialects.builtin import ModuleOp +from xdsl.ir.core import Block, Region -class TestInitialization: - """Tests that the state is correctly initialized.""" - - def test_dependency_injection(self): - """Tests that relevant dependencies are injected.""" - - mock_dag_builder = Mock(DAGBuilder) - utility = ConstructCircuitDAG(mock_dag_builder) - assert utility.dag_builder is mock_dag_builder - - +class FakeDAGBuilder(DAGBuilder): + """ + A concrete implementation of DAGBuilder used ONLY for testing. + It stores all graph manipulation calls in simple Python dictionaries + for easy assertion of the final graph state. + """ + + def __init__(self): + self._nodes = {} + self._edges = [] + self._clusters = {} + + def add_node(self, id, label, cluster_id=None, **attrs) -> None: + cluster_id = "__base__" if cluster_id is None else cluster_id + self._nodes[id] = { + "id": id, + "label": label, + "cluster_id": cluster_id, + "attrs": attrs, + } + + def add_edge(self, from_id: str, to_id: str, **attrs) -> None: + self._edges.append( + { + "from": from_id, + "to": to_id, + "attrs": attrs, + } + ) + + def add_cluster( + self, + id, + node_label=None, + cluster_id=None, + **attrs, + ) -> None: + cluster_id = "__base__" if cluster_id is None else cluster_id + self._clusters[id] = { + "id": id, + "label": node_label, + "cluster_id": cluster_id, + "attrs": attrs, + } + + def get_nodes(self): + return self._nodes.copy() + + def get_edges(self): + return self._edges.copy() + + def get_clusters(self): + return self._clusters.copy() + + def to_file(self, output_filename): + pass + + def to_string(self) -> str: + return "graph" + + +@pytest.mark.unit +def test_dependency_injection(): + """Tests that relevant dependencies are injected.""" + + dag_builder = FakeDAGBuilder() + utility = ConstructCircuitDAG(dag_builder) + assert utility.dag_builder is dag_builder + + +@pytest.mark.unit def test_does_not_mutate_module(): """Test that the module is not mutated.""" - # Create block containing some ops + # Create module op = test.TestOp() block = Block(ops=[op]) - # Create region containing some blocks region = Region(blocks=[block]) - # Create op containing the regions container_op = test.TestOp(regions=[region]) - # Create module op to house it all module_op = ModuleOp(ops=[container_op]) + # Save state before module_op_str_before = str(module_op) + # Process module mock_dag_builder = Mock(DAGBuilder) utility = ConstructCircuitDAG(mock_dag_builder) utility.construct(module_op) + # Ensure not mutated assert str(module_op) == module_op_str_before From aad7449103d4baf71a5213a92f21b34fd01832bb Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Tue, 25 Nov 2025 09:42:11 -0500 Subject: [PATCH 072/111] Apply suggestion from @andrijapau --- .../visualization/test_construct_circuit_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index a4b9f63765..d9d677c153 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -33,7 +33,7 @@ class FakeDAGBuilder(DAGBuilder): """ A concrete implementation of DAGBuilder used ONLY for testing. - It stores all graph manipulation calls in simple Python dictionaries + It stores all graph manipulation calls in data structures for easy assertion of the final graph state. """ From 8225658dde58941bf298985fef93798dc3448315 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:45:07 -0500 Subject: [PATCH 073/111] isort --- .../visualization/test_construct_circuit_dag.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index d9d677c153..cd44eef0cb 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -19,15 +19,16 @@ pytestmark = pytest.mark.usefixtures("requires_xdsl") +from xdsl.dialects import test +from xdsl.dialects.builtin import ModuleOp +from xdsl.ir.core import Block, Region + # pylint: disable=wrong-import-position # This import needs to be after pytest in order to prevent ImportErrors from catalyst.python_interface.visualization.construct_circuit_dag import ( ConstructCircuitDAG, ) from catalyst.python_interface.visualization.dag_builder import DAGBuilder -from xdsl.dialects import test -from xdsl.dialects.builtin import ModuleOp -from xdsl.ir.core import Block, Region class FakeDAGBuilder(DAGBuilder): From 07a059774729df3cd328c7642d1a222e8ecb1b71 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Tue, 25 Nov 2025 16:25:44 -0500 Subject: [PATCH 074/111] Update frontend/catalyst/python_interface/visualization/construct_circuit_dag.py Co-authored-by: Mudit Pandey <18223836+mudit2812@users.noreply.github.com> --- .../python_interface/visualization/construct_circuit_dag.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 45a15cc2d2..0526e70b0d 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -29,11 +29,6 @@ class ConstructCircuitDAG: """ def __init__(self, dag_builder: DAGBuilder) -> None: - """Initialize the utility by injecting the DAG builder dependency. - - Args: - dag_builder (DAGBuilder): The concrete builder instance used for graph construction. - """ self.dag_builder: DAGBuilder = dag_builder # ================================= From 40b1eb4885996f505dfdeeec1719c788946aa2f7 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 16:26:30 -0500 Subject: [PATCH 075/111] move things around --- .../visualization/construct_circuit_dag.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 0526e70b0d..401a5d761c 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -31,15 +31,6 @@ class ConstructCircuitDAG: def __init__(self, dag_builder: DAGBuilder) -> None: self.dag_builder: DAGBuilder = dag_builder - # ================================= - # 1. CORE DISPATCH AND ENTRY POINT - # ================================= - - @singledispatchmethod - def _visit(self, op: Any) -> None: - """Central dispatch method (Visitor Pattern). Routes the operation 'op' - to the specialized handler registered for its type.""" - def construct(self, module: builtin.ModuleOp) -> None: """Constructs the DAG from the module. @@ -54,6 +45,11 @@ def construct(self, module: builtin.ModuleOp) -> None: # 2. IR TRAVERSAL # ======================= + @singledispatchmethod + def _visit(self, op: Any) -> None: + """Central dispatch method (Visitor Pattern). Routes the operation 'op' + to the specialized handler registered for its type.""" + @_visit.register def _operation(self, operation: Operation) -> None: """Visit an xDSL Operation.""" From 77dd502c7cb8765d269a265f42a2fcd0cd7a3899 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 16:27:36 -0500 Subject: [PATCH 076/111] minor cleanup --- .../python_interface/visualization/construct_circuit_dag.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 401a5d761c..92e0febaed 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -41,9 +41,9 @@ def construct(self, module: builtin.ModuleOp) -> None: for op in module.ops: self._visit(op) - # ======================= - # 2. IR TRAVERSAL - # ======================= + # ============= + # IR TRAVERSAL + # ============= @singledispatchmethod def _visit(self, op: Any) -> None: From b4ccd611bf89a3b9efc3ddc1b7d2a3099542de0c Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 17:08:32 -0500 Subject: [PATCH 077/111] refactor the get_ to properties --- .../visualization/dag_builder.py | 11 ++- .../visualization/pydot_dag_builder.py | 15 ++-- .../visualization/test_dag_builder.py | 15 ++-- .../visualization/test_pydot_dag_builder.py | 73 +++---------------- 4 files changed, 35 insertions(+), 79 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py index 8a2fff6960..a268e24351 100644 --- a/frontend/catalyst/python_interface/visualization/dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -13,7 +13,7 @@ # limitations under the License. """File that defines the DAGBuilder abstract base class.""" -from abc import ABC, abstractmethod +from abc import ABC, abstractmethod, abstractproperty from typing import Any, TypeAlias ClusterID: TypeAlias = str @@ -84,8 +84,9 @@ def add_cluster( """ raise NotImplementedError + @property @abstractmethod - def get_nodes(self) -> dict[NodeID, dict[str, Any]]: + def nodes(self) -> dict[NodeID, dict[str, Any]]: """Retrieve the current set of nodes in the graph. Returns: @@ -93,8 +94,9 @@ def get_nodes(self) -> dict[NodeID, dict[str, Any]]: """ raise NotImplementedError + @property @abstractmethod - def get_edges(self) -> list[dict[str, Any]]: + def edges(self) -> list[dict[str, Any]]: """Retrieve the current set of edges in the graph. Returns: @@ -102,8 +104,9 @@ def get_edges(self) -> list[dict[str, Any]]: """ raise NotImplementedError + @property @abstractmethod - def get_clusters(self) -> dict[ClusterID, dict[str, Any]]: + def clusters(self) -> dict[ClusterID, dict[str, Any]]: """Retrieve the current set of clusters in the graph. Returns: diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index e8abda83da..77ee6d9fd7 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -204,29 +204,32 @@ def add_cluster( "attrs": dict(cluster_attrs), } - def get_nodes(self) -> dict[str, dict[str, Any]]: + @property + def nodes(self) -> dict[str, dict[str, Any]]: """Retrieve the current set of nodes in the graph. Returns: nodes (dict[str, dict[str, Any]]): A dictionary that maps the node's ID to it's node information. """ - return self._nodes.copy() + return self._nodes - def get_edges(self) -> list[dict[str, Any]]: + @property + def edges(self) -> list[dict[str, Any]]: """Retrieve the current set of edges in the graph. Returns: edges (list[dict[str, Any]]): A list of edges where each element in the list contains a dictionary of edge information. """ - return self._edges.copy() + return self._edges - def get_clusters(self) -> dict[str, dict[str, Any]]: + @property + def clusters(self) -> dict[str, dict[str, Any]]: """Retrieve the current set of clusters in the graph. Returns: clusters (dict[str, dict[str, Any]]): A dictionary that maps the cluster's ID to it's cluster information. """ - return self._clusters.copy() + return self._clusters def to_file(self, output_filename: str) -> None: """Save the graph to a file. diff --git a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py index 2ff5b8255f..df3a431ae2 100644 --- a/frontend/test/pytest/python_interface/visualization/test_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_dag_builder.py @@ -52,13 +52,16 @@ def add_cluster( ) -> None: return - def get_nodes(self) -> dict[str, dict[str, Any]]: + @property + def nodes(self) -> dict[str, dict[str, Any]]: return {} - def get_edges(self) -> list[dict[str, Any]]: + @property + def edges(self) -> list[dict[str, Any]]: return [] - def get_clusters(self) -> dict[str, dict[str, Any]]: + @property + def clusters(self) -> dict[str, dict[str, Any]]: return {} def to_file(self, output_filename: str) -> None: @@ -72,9 +75,9 @@ def to_string(self) -> str: node = dag_builder.add_node("0", "node0") edge = dag_builder.add_edge("0", "1") cluster = dag_builder.add_cluster("0") - nodes = dag_builder.get_nodes() - edges = dag_builder.get_edges() - clusters = dag_builder.get_clusters() + nodes = dag_builder.nodes + edges = dag_builder.edges + clusters = dag_builder.clusters render = dag_builder.to_file("test.png") string = dag_builder.to_string() diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 761be2da21..4833c0cc2c 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -243,18 +243,18 @@ def test_add_cluster_with_attrs(self): assert cluster2.get("fontname") == "Helvetica" -class TestGetMethods: - """Tests the get_* methods.""" +class TestProperties: + """Tests the properties.""" - def test_get_nodes(self): - """Tests that get_nodes works.""" + def test_nodes(self): + """Tests that nodes works.""" dag_builder = PyDotDAGBuilder() dag_builder.add_node("0", "node0", fillcolor="red") dag_builder.add_cluster("c0") dag_builder.add_node("1", "node1", cluster_id="c0") - nodes = dag_builder.get_nodes() + nodes = dag_builder.nodes assert len(nodes) == 2 assert len(nodes["0"]) == 4 @@ -268,32 +268,15 @@ def test_get_nodes(self): assert nodes["1"]["label"] == "node1" assert nodes["1"]["cluster_id"] == "c0" - def test_get_nodes_doesnt_mutate(self): - """Tests that get_nodes doesn't mutate state""" - - dag_builder = PyDotDAGBuilder() - - dag_builder.add_node("0", "node0") - - old_nodes = dag_builder.get_nodes() - - dag_builder.add_node("1", "node1") - - new_nodes = dag_builder.get_nodes() - - assert old_nodes is not new_nodes - assert len(old_nodes) == 1 - assert len(new_nodes) == 2 - - def test_get_edges(self): - """Tests that get_edges works.""" + def test_edges(self): + """Tests that edges works.""" dag_builder = PyDotDAGBuilder() dag_builder.add_node("0", "node0") dag_builder.add_node("1", "node1") dag_builder.add_edge("0", "1", penwidth=10) - edges = dag_builder.get_edges() + edges = dag_builder.edges assert len(edges) == 1 @@ -301,37 +284,18 @@ def test_get_edges(self): assert edges[0]["to_id"] == "1" assert edges[0]["attrs"]["penwidth"] == 10 - def test_get_edges_doesnt_mutate(self): - """Tests that get_edges doesn't mutated.""" - - dag_builder = PyDotDAGBuilder() - dag_builder.add_node("0", "node0") - dag_builder.add_node("1", "node1") - dag_builder.add_edge("0", "1") - - old_edges = dag_builder.get_edges() - - dag_builder.add_node("2", "node2") - dag_builder.add_edge("1", "2") - - new_edges = dag_builder.get_edges() - - assert old_edges is not new_edges - assert len(old_edges) == 1 - assert len(new_edges) == 2 - def test_get_clusters(self): """Tests that get_clusters works.""" dag_builder = PyDotDAGBuilder() dag_builder.add_cluster("0", "my_info_node", label="my_cluster", penwidth=10) - clusters = dag_builder.get_clusters() + clusters = dag_builder.clusters dag_builder.add_cluster( "1", "my_other_info_node", cluster_id="0", label="my_nested_cluster" ) - clusters = dag_builder.get_clusters() + clusters = dag_builder.clusters assert len(clusters) == 2 assert len(clusters["0"]) == 5 @@ -347,23 +311,6 @@ def test_get_clusters(self): assert clusters["1"]["node_label"] == "my_other_info_node" assert clusters["1"]["cluster_id"] == "0" - def test_get_clusters_doesnt_mutate(self): - """Tests that get_clusters doesn't mutate state""" - - dag_builder = PyDotDAGBuilder() - - dag_builder.add_cluster("0") - - old_clusters = dag_builder.get_clusters() - - dag_builder.add_cluster("1") - - new_clusters = dag_builder.get_clusters() - - assert old_clusters is not new_clusters - assert len(old_clusters) == 1 - assert len(new_clusters) == 2 - class TestOutput: """Test that the graph can be outputted correctly.""" From d14c15bd52da7cda2fb982d1b6f3dcc432e32b69 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 17:13:00 -0500 Subject: [PATCH 078/111] update fake dag builder --- .../visualization/test_construct_circuit_dag.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index cd44eef0cb..8e6bdaef7c 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -76,14 +76,17 @@ def add_cluster( "attrs": attrs, } - def get_nodes(self): - return self._nodes.copy() + @property + def nodes(self): + return self._nodes - def get_edges(self): - return self._edges.copy() + @property + def edges(self): + return self._edges - def get_clusters(self): - return self._clusters.copy() + @property + def clusters(self): + return self._clusters def to_file(self, output_filename): pass From 0e0ddfdc77044d2d35ed4aa0c21d6592221252c1 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 09:24:58 -0500 Subject: [PATCH 079/111] attempt to get rid of _subgraphs --- .../visualization/pydot_dag_builder.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 77ee6d9fd7..9a2922f6ec 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -53,10 +53,6 @@ def __init__( self.graph: pydot.Dot = pydot.Dot( graph_type="digraph", rankdir="TB", compound="true", strict=True ) - # Create cache for easy look-up - # TODO: Get rid of this and use self._clusters if possible - self._subgraphs: dict[str, pydot.Graph] = {} - self._subgraphs["__base__"] = self.graph # Internal state for graph structure self._nodes: dict[str, dict[str, Any]] = {} @@ -115,9 +111,14 @@ def add_node( # Use ChainMap so you don't need to construct a new dictionary node_attrs: ChainMap = ChainMap(attrs, self._default_node_attrs) node = pydot.Node(id, label=label, **node_attrs) - cluster_id = "__base__" if cluster_id is None else cluster_id - self._subgraphs[cluster_id].add_node(node) + # Add node to cluster + if cluster_id is None: + self.graph.add_node(node) + else: + # Use cluster ID to look up the subgraph + assert len(self.graph.get_subgraph(cluster_id)) == 1 + self.graph.get_subgraph(cluster_id)[0].add_node(node) self._nodes[id] = { "id": id, @@ -191,10 +192,13 @@ def add_cluster( cluster.add_subgraph(rank_subgraph) cluster.add_node(node) - self._subgraphs[id] = cluster - - cluster_id = "__base__" if cluster_id is None else cluster_id - self._subgraphs[cluster_id].add_subgraph(cluster) + # Add cluster to parent cluster + if cluster_id is None: + self.graph.add_subgraph(cluster) + else: + # Use cluster ID to look up the subgraph + assert len(self.graph.get_subgraph(cluster_id)) == 1 + self.graph.get_subgraph(cluster_id)[0].add_subgraph(cluster) self._clusters[id] = { "id": id, From c41adb9c16b2cb8733ba6aa3ceb4879412605e66 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 09:25:47 -0500 Subject: [PATCH 080/111] clean-up test --- .../python_interface/visualization/test_pydot_dag_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 4833c0cc2c..36f25ea990 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -284,8 +284,8 @@ def test_edges(self): assert edges[0]["to_id"] == "1" assert edges[0]["attrs"]["penwidth"] == 10 - def test_get_clusters(self): - """Tests that get_clusters works.""" + def test_clusters(self): + """Tests that clusters property works.""" dag_builder = PyDotDAGBuilder() dag_builder.add_cluster("0", "my_info_node", label="my_cluster", penwidth=10) From e8035430b87243871fc40365b1687d9003115bb6 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 09:29:32 -0500 Subject: [PATCH 081/111] rename __base__ to None --- .../python_interface/visualization/test_pydot_dag_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 36f25ea990..436ebc8319 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -261,7 +261,7 @@ def test_nodes(self): assert nodes["0"]["id"] == "0" assert nodes["0"]["label"] == "node0" - assert nodes["0"]["cluster_id"] == "__base__" + assert nodes["0"]["cluster_id"] == None assert nodes["0"]["attrs"]["fillcolor"] == "red" assert nodes["1"]["id"] == "1" @@ -302,7 +302,7 @@ def test_clusters(self): assert clusters["0"]["id"] == "0" assert clusters["0"]["cluster_label"] == "my_cluster" assert clusters["0"]["node_label"] == "my_info_node" - assert clusters["0"]["cluster_id"] == "__base__" + assert clusters["0"]["cluster_id"] == None assert clusters["0"]["attrs"]["penwidth"] == 10 assert len(clusters["1"]) == 5 From 085ec57f34919d38792e8f3f362663576bd6727e Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 09:59:35 -0500 Subject: [PATCH 082/111] clean-up --- .../visualization/pydot_dag_builder.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 9a2922f6ec..f034131a91 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -22,6 +22,7 @@ has_pydot = True try: import pydot + from pydot import Cluster, Dot, Edge, Node, Subgraph except ImportError: has_pydot = False @@ -50,7 +51,7 @@ def __init__( # - rankdir="TB": Set layout direction from Top to Bottom. # - compound="true": Allow edges to connect directly to clusters/subgraphs. # - strict=True: Prevent duplicate edges (e.g., A -> B added twice). - self.graph: pydot.Dot = pydot.Dot( + self.graph: Dot = Dot( graph_type="digraph", rankdir="TB", compound="true", strict=True ) @@ -110,15 +111,16 @@ def add_node( """ # Use ChainMap so you don't need to construct a new dictionary node_attrs: ChainMap = ChainMap(attrs, self._default_node_attrs) - node = pydot.Node(id, label=label, **node_attrs) + node = Node(id, label=label, **node_attrs) # Add node to cluster if cluster_id is None: self.graph.add_node(node) else: # Use cluster ID to look up the subgraph - assert len(self.graph.get_subgraph(cluster_id)) == 1 - self.graph.get_subgraph(cluster_id)[0].add_node(node) + parent_clusters = self.graph.get_subgraph(cluster_id)[0] + assert len(parent_clusters) == 1 + parent_clusters[0].add_node(node) self._nodes[id] = { "id": id, @@ -138,7 +140,8 @@ def add_edge(self, from_id: str, to_id: str, **attrs: Any) -> None: """ # Use ChainMap so you don't need to construct a new dictionary edge_attrs: ChainMap = ChainMap(attrs, self._default_edge_attrs) - edge = pydot.Edge(from_id, to_id, **edge_attrs) + edge = Edge(from_id, to_id, **edge_attrs) + self.graph.add_edge(edge) self._edges.append( @@ -166,7 +169,7 @@ def add_cluster( """ # Use ChainMap so you don't need to construct a new dictionary cluster_attrs: ChainMap = ChainMap(attrs, self._default_cluster_attrs) - cluster = pydot.Cluster(graph_name=id, **cluster_attrs) + cluster = Cluster(graph_name=id, **cluster_attrs) # Puts the label in a node within the cluster. # Ensures that any edges connecting nodes through the cluster @@ -179,8 +182,8 @@ def add_cluster( # └───────────┘ if node_label: node_id = f"{cluster_id}_info_node" - rank_subgraph = pydot.Subgraph() - node = pydot.Node( + rank_subgraph = Subgraph() + node = Node( node_id, label=node_label, shape="rectangle", @@ -197,8 +200,9 @@ def add_cluster( self.graph.add_subgraph(cluster) else: # Use cluster ID to look up the subgraph - assert len(self.graph.get_subgraph(cluster_id)) == 1 - self.graph.get_subgraph(cluster_id)[0].add_subgraph(cluster) + parent_clusters = self.graph.get_subgraph(cluster_id)[0] + assert len(parent_clusters) == 1 + parent_clusters[0].add_subgraph(cluster) self._clusters[id] = { "id": id, From c08a84cc82b7ab58a215b977547e3922f0b134b2 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 10:01:27 -0500 Subject: [PATCH 083/111] whoops --- .../python_interface/visualization/pydot_dag_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index f034131a91..aad95a2faf 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -118,7 +118,7 @@ def add_node( self.graph.add_node(node) else: # Use cluster ID to look up the subgraph - parent_clusters = self.graph.get_subgraph(cluster_id)[0] + parent_clusters = self.graph.get_subgraph(cluster_id) assert len(parent_clusters) == 1 parent_clusters[0].add_node(node) @@ -200,7 +200,7 @@ def add_cluster( self.graph.add_subgraph(cluster) else: # Use cluster ID to look up the subgraph - parent_clusters = self.graph.get_subgraph(cluster_id)[0] + parent_clusters = self.graph.get_subgraph(cluster_id) assert len(parent_clusters) == 1 parent_clusters[0].add_subgraph(cluster) From edc20765161feb4af228518072b7f32f6f7aa0b7 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 10:03:37 -0500 Subject: [PATCH 084/111] add cluster_ prefix --- .../python_interface/visualization/pydot_dag_builder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index aad95a2faf..9e617460d0 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -118,7 +118,7 @@ def add_node( self.graph.add_node(node) else: # Use cluster ID to look up the subgraph - parent_clusters = self.graph.get_subgraph(cluster_id) + parent_clusters = self.graph.get_subgraph("cluster_"+cluster_id) assert len(parent_clusters) == 1 parent_clusters[0].add_node(node) @@ -200,7 +200,7 @@ def add_cluster( self.graph.add_subgraph(cluster) else: # Use cluster ID to look up the subgraph - parent_clusters = self.graph.get_subgraph(cluster_id) + parent_clusters = self.graph.get_subgraph("cluster_"+cluster_id) assert len(parent_clusters) == 1 parent_clusters[0].add_subgraph(cluster) From 4c71876b8e2f0bbe5eddb361b14a1756c36d0af1 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 10:22:12 -0500 Subject: [PATCH 085/111] add debug string --- .../visualization/pydot_dag_builder.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 9e617460d0..4960a1507a 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -118,8 +118,10 @@ def add_node( self.graph.add_node(node) else: # Use cluster ID to look up the subgraph - parent_clusters = self.graph.get_subgraph("cluster_"+cluster_id) - assert len(parent_clusters) == 1 + parent_clusters = self.graph.get_subgraph("cluster_" + cluster_id) + assert len(parent_clusters) == 1, ( + f"Found {len(parent_clusters)} parent clusters with id {'cluster_' + cluster_id}" + ) parent_clusters[0].add_node(node) self._nodes[id] = { @@ -169,7 +171,7 @@ def add_cluster( """ # Use ChainMap so you don't need to construct a new dictionary cluster_attrs: ChainMap = ChainMap(attrs, self._default_cluster_attrs) - cluster = Cluster(graph_name=id, **cluster_attrs) + cluster = Cluster(id, **cluster_attrs) # Puts the label in a node within the cluster. # Ensures that any edges connecting nodes through the cluster @@ -200,8 +202,10 @@ def add_cluster( self.graph.add_subgraph(cluster) else: # Use cluster ID to look up the subgraph - parent_clusters = self.graph.get_subgraph("cluster_"+cluster_id) - assert len(parent_clusters) == 1 + parent_clusters = self.graph.get_subgraph("cluster_" + cluster_id) + assert len(parent_clusters) == 1, ( + f"Found {len(parent_clusters)} parent clusters with id {'cluster_' + cluster_id}" + ) parent_clusters[0].add_subgraph(cluster) self._clusters[id] = { From f1e5849b202f7af05a99ad238a614952d532e54b Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 10:32:42 -0500 Subject: [PATCH 086/111] bring back cache --- .../visualization/pydot_dag_builder.py | 21 +++++++------------ .../visualization/test_pydot_dag_builder.py | 8 +++---- 2 files changed, 11 insertions(+), 18 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 4960a1507a..740b00f378 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -22,7 +22,7 @@ has_pydot = True try: import pydot - from pydot import Cluster, Dot, Edge, Node, Subgraph + from pydot import Cluster, Dot, Edge, Graph, Node, Subgraph except ImportError: has_pydot = False @@ -55,6 +55,9 @@ def __init__( graph_type="digraph", rankdir="TB", compound="true", strict=True ) + # Use internal cache that maps cluster ID to actual pydot (Dot or Cluster) object + self._subgraph_cache: dict[str, Graph] = {} + # Internal state for graph structure self._nodes: dict[str, dict[str, Any]] = {} self._edges: list[dict[str, Any]] = [] @@ -117,12 +120,7 @@ def add_node( if cluster_id is None: self.graph.add_node(node) else: - # Use cluster ID to look up the subgraph - parent_clusters = self.graph.get_subgraph("cluster_" + cluster_id) - assert len(parent_clusters) == 1, ( - f"Found {len(parent_clusters)} parent clusters with id {'cluster_' + cluster_id}" - ) - parent_clusters[0].add_node(node) + parent_cluster = self._subgraph_cache[cluster_id].add_node(node) self._nodes[id] = { "id": id, @@ -197,16 +195,11 @@ def add_cluster( cluster.add_subgraph(rank_subgraph) cluster.add_node(node) - # Add cluster to parent cluster + # Add node to cluster if cluster_id is None: self.graph.add_subgraph(cluster) else: - # Use cluster ID to look up the subgraph - parent_clusters = self.graph.get_subgraph("cluster_" + cluster_id) - assert len(parent_clusters) == 1, ( - f"Found {len(parent_clusters)} parent clusters with id {'cluster_' + cluster_id}" - ) - parent_clusters[0].add_subgraph(cluster) + parent_cluster = self._subgraph_cache[cluster_id].add_node(cluster) self._clusters[id] = { "id": id, diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 436ebc8319..6474b42d68 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -119,13 +119,13 @@ def test_add_cluster_to_parent_graph(self): # Level 0 (Root): Adds cluster on top of base graph dag_builder.add_node("n_root", "node_root") - dag_builder.add_cluster("c0") - # Level 1 (Inside c0): Add node on outer cluster and create new cluster on top + # Level 1 (c0): Add node on outer cluster + dag_builder.add_cluster("c0") dag_builder.add_node("n_outer", "node_outer", cluster_id="c0") - dag_builder.add_cluster("c1", cluster_id="c0") - # Level 2 (Inside c1): Add node on second cluster + # Level 2 (c1): Add node on inner cluster + dag_builder.add_cluster("c1", cluster_id="c0") dag_builder.add_node("n_inner", "node_inner", cluster_id="c1") root_graph = dag_builder.graph From 3c9ca3ae18112853ad21a0a83312205a501c7073 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 10:34:27 -0500 Subject: [PATCH 087/111] add good dev comment --- .../python_interface/visualization/pydot_dag_builder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 740b00f378..d620b94987 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -56,6 +56,8 @@ def __init__( ) # Use internal cache that maps cluster ID to actual pydot (Dot or Cluster) object + # NOTE: This is needed so we don't need to traverse the graph to find the relevant + # cluster object to modify self._subgraph_cache: dict[str, Graph] = {} # Internal state for graph structure From bc54dbf2fd54d55f3c1437527785f020e414fc9b Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 26 Nov 2025 10:44:03 -0500 Subject: [PATCH 088/111] Apply suggestion from @andrijapau --- .../python_interface/visualization/pydot_dag_builder.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index d620b94987..008bdd3f7b 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -197,6 +197,9 @@ def add_cluster( cluster.add_subgraph(rank_subgraph) cluster.add_node(node) + # Record new cluster + self._subgraph_cache[id] = cluster + # Add node to cluster if cluster_id is None: self.graph.add_subgraph(cluster) From ee57f404802930981d675bdbae0b3aa82c96b52e Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 26 Nov 2025 10:44:17 -0500 Subject: [PATCH 089/111] Apply suggestion from @andrijapau --- .../python_interface/visualization/pydot_dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 008bdd3f7b..ed690eb2ce 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -204,7 +204,7 @@ def add_cluster( if cluster_id is None: self.graph.add_subgraph(cluster) else: - parent_cluster = self._subgraph_cache[cluster_id].add_node(cluster) + parent_cluster = self._subgraph_cache[cluster_id].add_subgraph(cluster) self._clusters[id] = { "id": id, From c34185874df2365396001ceefd024036ced0ab16 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 10:45:07 -0500 Subject: [PATCH 090/111] whoops --- .../python_interface/visualization/pydot_dag_builder.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index ed690eb2ce..ff15f32f43 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -122,7 +122,7 @@ def add_node( if cluster_id is None: self.graph.add_node(node) else: - parent_cluster = self._subgraph_cache[cluster_id].add_node(node) + self._subgraph_cache[cluster_id].add_node(node) self._nodes[id] = { "id": id, @@ -199,12 +199,12 @@ def add_cluster( # Record new cluster self._subgraph_cache[id] = cluster - + # Add node to cluster if cluster_id is None: self.graph.add_subgraph(cluster) else: - parent_cluster = self._subgraph_cache[cluster_id].add_subgraph(cluster) + self._subgraph_cache[cluster_id].add_subgraph(cluster) self._clusters[id] = { "id": id, From 0460cac9ad28ce6e0e17382b6cdf74ef00ab131c Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 11:03:51 -0500 Subject: [PATCH 091/111] refactor singledispatch --- .../visualization/construct_circuit_dag.py | 23 +++++++------------ 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 92e0febaed..0b0f960064 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -39,31 +39,24 @@ def construct(self, module: builtin.ModuleOp) -> None: """ for op in module.ops: - self._visit(op) + self._visit_operation(op) # ============= # IR TRAVERSAL # ============= @singledispatchmethod - def _visit(self, op: Any) -> None: - """Central dispatch method (Visitor Pattern). Routes the operation 'op' - to the specialized handler registered for its type.""" - - @_visit.register - def _operation(self, operation: Operation) -> None: - """Visit an xDSL Operation.""" + def _visit_operation(self, operation: Operation) -> None: + """Visit an xDSL Operation. Default to visiting each region contained in the operation.""" for region in operation.regions: - self._visit(region) + self._visit_region(region) - @_visit.register - def _region(self, region: Region) -> None: + def _visit_region(self, region: Region) -> None: """Visit an xDSL Region operation.""" for block in region.blocks: - self._visit(block) + self._visit_block(block) - @_visit.register - def _block(self, block: Block) -> None: + def _visit_block(self, block: Block) -> None: """Visit an xDSL Block operation, dispatching handling for each contained Operation.""" for op in block.ops: - self._visit(op) + self._visit_operation(op) From 0cbe9afa9fcfde84518f3a8c8c2efedaddeb5821 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 26 Nov 2025 11:47:01 -0500 Subject: [PATCH 092/111] Update frontend/catalyst/python_interface/visualization/dag_builder.py Co-authored-by: Mudit Pandey <18223836+mudit2812@users.noreply.github.com> --- frontend/catalyst/python_interface/visualization/dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py index a268e24351..4cfe44854a 100644 --- a/frontend/catalyst/python_interface/visualization/dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -13,7 +13,7 @@ # limitations under the License. """File that defines the DAGBuilder abstract base class.""" -from abc import ABC, abstractmethod, abstractproperty +from abc import ABC, abstractmethod from typing import Any, TypeAlias ClusterID: TypeAlias = str From e7523295c075a55626f030192624bd0af6dc64f9 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 12:16:58 -0500 Subject: [PATCH 093/111] add more details to docstring --- .../visualization/construct_circuit_dag.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 0b0f960064..6b7500562d 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -24,8 +24,18 @@ class ConstructCircuitDAG: - """A tool that traverses an xDSL module and constructs a Directed Acyclic Graph (DAG) + """Utility tool following the director pattern to build a DAG representation of a compiled quantum program. + + This tool traverses an xDSL module and constructs a Directed Acyclic Graph (DAG) of it's quantum program using an injected DAGBuilder instance. This tool does not mutate the xDSL module. + + **Example** + + >>> builder = PyDotDAGBuilder() + >>> director = ConstructCircuitDAG(builder) + >>> director.construct(module) + >>> director.dag_builder.to_string() + ... """ def __init__(self, dag_builder: DAGBuilder) -> None: From 86b566287f755c89305d298dddec10e326c738de Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 26 Nov 2025 13:47:07 -0500 Subject: [PATCH 094/111] Update frontend/catalyst/python_interface/visualization/dag_builder.py Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- frontend/catalyst/python_interface/visualization/dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py index 4cfe44854a..ea1cb5438d 100644 --- a/frontend/catalyst/python_interface/visualization/dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -90,7 +90,7 @@ def nodes(self) -> dict[NodeID, dict[str, Any]]: """Retrieve the current set of nodes in the graph. Returns: - nodes (dict[str, dict[str, Any]]): A dictionary that maps the node's ID to it's node information. + nodes (dict[str, dict[str, Any]]): A dictionary that maps the node's ID to its node information. """ raise NotImplementedError From 9c26efdc019ba3b4c258510827eebc8157d0b298 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 26 Nov 2025 13:47:14 -0500 Subject: [PATCH 095/111] Update frontend/catalyst/python_interface/visualization/pydot_dag_builder.py Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- .../python_interface/visualization/pydot_dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index ff15f32f43..06e9e00fb5 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -219,7 +219,7 @@ def nodes(self) -> dict[str, dict[str, Any]]: """Retrieve the current set of nodes in the graph. Returns: - nodes (dict[str, dict[str, Any]]): A dictionary that maps the node's ID to it's node information. + nodes (dict[str, dict[str, Any]]): A dictionary that maps the node's ID to its node information. """ return self._nodes From 81a9aa6969c8c937fa17aa55009fb327e0e96cd7 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 26 Nov 2025 13:47:28 -0500 Subject: [PATCH 096/111] Update frontend/catalyst/python_interface/visualization/pydot_dag_builder.py Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- .../python_interface/visualization/pydot_dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 06e9e00fb5..7fd23660cb 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -183,7 +183,7 @@ def add_cluster( # │ │ # └───────────┘ if node_label: - node_id = f"{cluster_id}_info_node" + node_id = f"{id}_info_node" rank_subgraph = Subgraph() node = Node( node_id, From ab06276ac071751e92e236c56036ba5e3292fdf8 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 26 Nov 2025 13:47:44 -0500 Subject: [PATCH 097/111] Update frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- .../python_interface/visualization/test_pydot_dag_builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 6474b42d68..f84fa80773 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -17,7 +17,6 @@ import pytest -from frontend.catalyst.python_interface.visualization import dag_builder pydot = pytest.importorskip("pydot") pytestmark = pytest.mark.usefixtures("requires_xdsl") From 3e4102b1139ff222644b95421c9e954aa0241f11 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 26 Nov 2025 13:48:08 -0500 Subject: [PATCH 098/111] Update frontend/catalyst/python_interface/visualization/pydot_dag_builder.py Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- .../python_interface/visualization/pydot_dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 7fd23660cb..5fd1cf7d74 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -237,7 +237,7 @@ def clusters(self) -> dict[str, dict[str, Any]]: """Retrieve the current set of clusters in the graph. Returns: - clusters (dict[str, dict[str, Any]]): A dictionary that maps the cluster's ID to it's cluster information. + clusters (dict[str, dict[str, Any]]): A dictionary that maps the cluster's ID to its cluster information. """ return self._clusters From 194f14a55d86cb713584a7ac722105cb2f12bd48 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 26 Nov 2025 13:48:16 -0500 Subject: [PATCH 099/111] Update frontend/catalyst/python_interface/visualization/dag_builder.py Co-authored-by: Mehrdad Malek <39844030+mehrdad2m@users.noreply.github.com> --- frontend/catalyst/python_interface/visualization/dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/dag_builder.py b/frontend/catalyst/python_interface/visualization/dag_builder.py index ea1cb5438d..70c5806616 100644 --- a/frontend/catalyst/python_interface/visualization/dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/dag_builder.py @@ -110,7 +110,7 @@ def clusters(self) -> dict[ClusterID, dict[str, Any]]: """Retrieve the current set of clusters in the graph. Returns: - clusters (dict[str, dict[str, Any]]): A dictionary that maps the cluster's ID to it's cluster information. + clusters (dict[str, dict[str, Any]]): A dictionary that maps the cluster's ID to its cluster information. """ raise NotImplementedError From 685842c5360c91705d08896ab28a535fb2dba7a0 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 14:02:38 -0500 Subject: [PATCH 100/111] fix fakebackend --- .../visualization/test_construct_circuit_dag.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 8e6bdaef7c..4a98739c5a 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -44,7 +44,6 @@ def __init__(self): self._clusters = {} def add_node(self, id, label, cluster_id=None, **attrs) -> None: - cluster_id = "__base__" if cluster_id is None else cluster_id self._nodes[id] = { "id": id, "label": label, @@ -55,8 +54,8 @@ def add_node(self, id, label, cluster_id=None, **attrs) -> None: def add_edge(self, from_id: str, to_id: str, **attrs) -> None: self._edges.append( { - "from": from_id, - "to": to_id, + "from_id": from_id, + "to_id": to_id, "attrs": attrs, } ) @@ -68,10 +67,10 @@ def add_cluster( cluster_id=None, **attrs, ) -> None: - cluster_id = "__base__" if cluster_id is None else cluster_id self._clusters[id] = { "id": id, - "label": node_label, + "node_label": node_label, + "cluster_label": attrs.get("label"), "cluster_id": cluster_id, "attrs": attrs, } From 8c64d8187698a2288a0d9797d9046041c7251b6e Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 14:04:05 -0500 Subject: [PATCH 101/111] isort --- .../python_interface/visualization/test_pydot_dag_builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index f84fa80773..b87b283903 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -17,7 +17,6 @@ import pytest - pydot = pytest.importorskip("pydot") pytestmark = pytest.mark.usefixtures("requires_xdsl") # pylint: disable=wrong-import-position From f666a9fe8876d4feac6e3a95414eb24c19ce04ab Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 15:20:10 -0500 Subject: [PATCH 102/111] add exceptions check --- .../visualization/pydot_dag_builder.py | 13 +++++ .../visualization/test_pydot_dag_builder.py | 49 ++++++++++++++++++- 2 files changed, 61 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index 5fd1cf7d74..e14241d073 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -114,6 +114,9 @@ def add_node( **attrs (Any): Any additional styling keyword arguments. """ + if id in self.nodes: + raise ValueError(f"Node ID {id} already present in graph.") + # Use ChainMap so you don't need to construct a new dictionary node_attrs: ChainMap = ChainMap(attrs, self._default_node_attrs) node = Node(id, label=label, **node_attrs) @@ -140,6 +143,13 @@ def add_edge(self, from_id: str, to_id: str, **attrs: Any) -> None: **attrs (Any): Any additional styling keyword arguments. """ + if from_id == to_id: + raise ValueError("Edges must connect two unique IDs.") + if from_id not in self.nodes: + raise ValueError("Source is not found in the graph.") + if to_id not in self.nodes: + raise ValueError("Destination is not found in the graph.") + # Use ChainMap so you don't need to construct a new dictionary edge_attrs: ChainMap = ChainMap(attrs, self._default_edge_attrs) edge = Edge(from_id, to_id, **edge_attrs) @@ -169,6 +179,9 @@ def add_cluster( **attrs (Any): Any additional styling keyword arguments. """ + if id in self.clusters: + raise ValueError(f"Cluster ID {id} already present in graph.") + # Use ChainMap so you don't need to construct a new dictionary cluster_attrs: ChainMap = ChainMap(attrs, self._default_cluster_attrs) cluster = Cluster(id, **cluster_attrs) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index b87b283903..5219889655 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -40,6 +40,53 @@ def test_initialization_defaults(): assert dag_builder.graph.obj_dict["strict"] is True +class TestExceptions: + """Tests the various exceptions defined in the class.""" + + def test_double_node_id(self): + """Tests that a ValueError is raised for duplicate nodes.""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0") + with pytest.raises(ValueError, match="Node ID 0 already present in graph."): + dag_builder.add_node("0", "node1") + + def test_edge_duplicate_source_destination(self): + """Tests that a ValueError is raised when an edge is created with the + same source and destination""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0") + with pytest.raises(ValueError, match="Edges must connect two uniques IDs."): + dag_builder.add_edge("0", "0") + + def test_edge_missing_ids(self): + """Tests that an error is raised if IDs are missing.""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("0", "node0") + with pytest.raises(ValueError, match="Destination is not found in the graph."): + dag_builder.add_edge("0", "1") + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_node("1", "node1") + with pytest.raises(ValueError, match="Source is not found in the graph."): + dag_builder.add_edge("0", "1") + + def test_duplicate_cluster_id(self): + """Tests that an exception is raised if an ID is already present.""" + + dag_builder = PyDotDAGBuilder() + + dag_builder.add_cluster("0") + with pytest.raises(ValueError, match="Cluster ID 0 already present in graph."): + dag_builder.add_cluster("0") + + class TestAddMethods: """Test that elements can be added to the graph.""" @@ -118,7 +165,7 @@ def test_add_cluster_to_parent_graph(self): # Level 0 (Root): Adds cluster on top of base graph dag_builder.add_node("n_root", "node_root") - # Level 1 (c0): Add node on outer cluster + # Level 1 (c0): Add node on outer cluster dag_builder.add_cluster("c0") dag_builder.add_node("n_outer", "node_outer", cluster_id="c0") From e28b2b71acbdc4aa3b7ae6cf402bdf9b43348526 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 15:23:43 -0500 Subject: [PATCH 103/111] add better documentation --- .../visualization/pydot_dag_builder.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py index e14241d073..ea01ecdf21 100644 --- a/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py +++ b/frontend/catalyst/python_interface/visualization/pydot_dag_builder.py @@ -113,6 +113,9 @@ def add_node( cluster_id (str | None): Optional ID of the cluster this node belongs to. **attrs (Any): Any additional styling keyword arguments. + Raises: + ValueError: Node ID is already present in the graph. + """ if id in self.nodes: raise ValueError(f"Node ID {id} already present in graph.") @@ -142,6 +145,11 @@ def add_edge(self, from_id: str, to_id: str, **attrs: Any) -> None: to_id (str): The unique ID of the destination node. **attrs (Any): Any additional styling keyword arguments. + Raises: + ValueError: Source and destination have the same ID + ValueError: Source is not found in the graph. + ValueError: Destination is not found in the graph. + """ if from_id == to_id: raise ValueError("Edges must connect two unique IDs.") @@ -178,6 +186,8 @@ def add_cluster( cluster_id (str | None): Optional ID of the cluster this cluster belongs to. If `None`, the cluster will be positioned on the base graph. **attrs (Any): Any additional styling keyword arguments. + Raises: + ValueError: Cluster ID is already present in the graph. """ if id in self.clusters: raise ValueError(f"Cluster ID {id} already present in graph.") From 952fd7f21c05911bd1398e4c5935724d138b2501 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 15:36:35 -0500 Subject: [PATCH 104/111] fix typo --- .../python_interface/visualization/test_pydot_dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 5219889655..197be4af70 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -59,7 +59,7 @@ def test_edge_duplicate_source_destination(self): dag_builder = PyDotDAGBuilder() dag_builder.add_node("0", "node0") - with pytest.raises(ValueError, match="Edges must connect two uniques IDs."): + with pytest.raises(ValueError, match="Edges must connect two unique IDs."): dag_builder.add_edge("0", "0") def test_edge_missing_ids(self): From 8f2dc98e0f987edef577902d7a3bd7303fb23399 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Wed, 26 Nov 2025 15:38:59 -0500 Subject: [PATCH 105/111] Apply suggestion from @andrijapau --- .../python_interface/visualization/test_pydot_dag_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py index 197be4af70..d9a17f5e3b 100644 --- a/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py +++ b/frontend/test/pytest/python_interface/visualization/test_pydot_dag_builder.py @@ -43,7 +43,7 @@ def test_initialization_defaults(): class TestExceptions: """Tests the various exceptions defined in the class.""" - def test_double_node_id(self): + def test_duplicate_node_ids(self): """Tests that a ValueError is raised for duplicate nodes.""" dag_builder = PyDotDAGBuilder() From ffc9726bef6decc285fc10b7e4b1f4abfebbd4e5 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 26 Nov 2025 17:10:50 -0500 Subject: [PATCH 106/111] fix naming --- .../visualization/test_construct_circuit_dag.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 4a98739c5a..14b7892e6b 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -47,7 +47,7 @@ def add_node(self, id, label, cluster_id=None, **attrs) -> None: self._nodes[id] = { "id": id, "label": label, - "cluster_id": cluster_id, + "parent_cluster_id": cluster_id, "attrs": attrs, } @@ -71,7 +71,7 @@ def add_cluster( "id": id, "node_label": node_label, "cluster_label": attrs.get("label"), - "cluster_id": cluster_id, + "parent_cluster_id": cluster_id, "attrs": attrs, } From e675299d9d01d6dc10e17955908e5650145ab217 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 27 Nov 2025 18:36:04 -0500 Subject: [PATCH 107/111] feat: add frontend draw_graph API --- .../python_interface/visualization/draw.py | 84 ++++++++++++++++++- 1 file changed, 83 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/draw.py b/frontend/catalyst/python_interface/visualization/draw.py index ac26977404..90f3984fa3 100644 --- a/frontend/catalyst/python_interface/visualization/draw.py +++ b/frontend/catalyst/python_interface/visualization/draw.py @@ -15,10 +15,13 @@ from __future__ import annotations +import io import warnings from functools import wraps from typing import TYPE_CHECKING +import matplotlib.image as mpimg +import matplotlib.pyplot as plt from pennylane.tape import QuantumScript from catalyst import qjit @@ -26,6 +29,8 @@ from catalyst.python_interface.compiler import Compiler from .collector import QMLCollector +from .construct_circuit_dag import ConstructCircuitDAG +from .pydot_dag_builder import PyDotDAGBuilder if TYPE_CHECKING: from pennylane.typing import Callable @@ -79,7 +84,9 @@ def _draw_callback(previous_pass, module, next_pass, pass_level=0): collector = QMLCollector(module) ops, meas = collector.collect() tape = QuantumScript(ops, meas) - pass_name = pass_instance.name if hasattr(pass_instance, "name") else pass_instance + pass_name = ( + pass_instance.name if hasattr(pass_instance, "name") else pass_instance + ) cache[pass_level] = ( tape.draw(show_matrices=False), pass_name if pass_level else "No transforms", @@ -103,3 +110,78 @@ def wrapper(*args, **kwargs): return cache.get(level, cache[max(cache.keys())])[0] return wrapper + + +def draw_graph(qnode: QNode, *, level: None | int = None) -> Callable: + """ + Draw the QNode at the specified level. + + This function can be used to visualize the QNode at different stages of the transformation + pipeline when xDSL or Catalyst compilation passes are applied. + If the specified level is not available, the highest level will be used as a fallback. + + The provided QNode is assumed to be decorated with compilation passes. + If no passes are applied, the original QNode is visualized. + + Args: + qnode (.QNode): the input QNode that is to be visualized. The QNode is assumed to be + compiled with ``qjit``. + level (None | int): the level of transformation to visualize. If `None`, the final + level is visualized. + + + Returns: + Callable: A wrapper function that visualizes the QNode at the specified level. + + """ + cache: dict[int, tuple[str, str]] = _cache_store.setdefault(qnode, {}) + + def _draw_callback(previous_pass, module, next_pass, pass_level=0): + """Callback function for circuit drawing.""" + + pass_instance = previous_pass if previous_pass else next_pass + utility = ConstructCircuitDAG(PyDotDAGBuilder()) + utility.construct(module) + svg_str = utility.dag_builder.graph.create_svg(prog="dot") + pass_name = ( + pass_instance.name if hasattr(pass_instance, "name") else pass_instance + ) + cache[pass_level] = ( + svg_str, + pass_name if pass_level else "No transforms", + ) + + @wraps(qnode) + def wrapper(*args, **kwargs): + if args or kwargs: + warnings.warn( + "The `draw` function does not yet support dynamic arguments.\n" + "To visualize the circuit with dynamic parameters or wires, please use the\n" + "`compiler.python_compiler.visualization.generate_mlir_graph` function instead.", + UserWarning, + ) + mlir_module = _get_mlir_module(qnode, args, kwargs) + Compiler.run(mlir_module, callback=_draw_callback) + + if not cache: + return None + + # Retrieve the SVG string (or the high-DPI PNG string) + image_data = cache.get(level, cache[max(cache.keys())])[0] + + if image_data.startswith(b" Date: Fri, 28 Nov 2025 12:19:55 -0500 Subject: [PATCH 108/111] clean up --- .../catalyst/python_interface/visualization/draw.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/draw.py b/frontend/catalyst/python_interface/visualization/draw.py index 90f3984fa3..3d5ef5822d 100644 --- a/frontend/catalyst/python_interface/visualization/draw.py +++ b/frontend/catalyst/python_interface/visualization/draw.py @@ -131,7 +131,7 @@ def draw_graph(qnode: QNode, *, level: None | int = None) -> Callable: Returns: - Callable: A wrapper function that visualizes the QNode at the specified level. + ???? """ cache: dict[int, tuple[str, str]] = _cache_store.setdefault(qnode, {}) @@ -140,6 +140,7 @@ def _draw_callback(previous_pass, module, next_pass, pass_level=0): """Callback function for circuit drawing.""" pass_instance = previous_pass if previous_pass else next_pass + # Process module to build DAG utility = ConstructCircuitDAG(PyDotDAGBuilder()) utility.construct(module) svg_str = utility.dag_builder.graph.create_svg(prog="dot") @@ -153,13 +154,6 @@ def _draw_callback(previous_pass, module, next_pass, pass_level=0): @wraps(qnode) def wrapper(*args, **kwargs): - if args or kwargs: - warnings.warn( - "The `draw` function does not yet support dynamic arguments.\n" - "To visualize the circuit with dynamic parameters or wires, please use the\n" - "`compiler.python_compiler.visualization.generate_mlir_graph` function instead.", - UserWarning, - ) mlir_module = _get_mlir_module(qnode, args, kwargs) Compiler.run(mlir_module, callback=_draw_callback) @@ -182,6 +176,6 @@ def wrapper(*args, **kwargs): fig, ax = plt.subplots() ax.imshow(img) ax.set_axis_off() - return fig + return fig, ax return wrapper From 498fb1822b930af541aaf55f62ab57a64bb9da47 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 3 Dec 2025 15:05:07 -0500 Subject: [PATCH 109/111] cl --- doc/releases/changelog-dev.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3b2ededd4c..063de07b42 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -5,8 +5,9 @@ * Compiled programs can be visualized. [(#2213)](https://github.com/PennyLaneAI/catalyst/pull/2213) [(#2229)](https://github.com/PennyLaneAI/catalyst/pull/2229) - [(#2246)](https://github.com/PennyLaneAI/catalyst/pull/2246) [(#2214)](https://github.com/PennyLaneAI/catalyst/pull/2214) + [(#2246)](https://github.com/PennyLaneAI/catalyst/pull/2246) + [(#2243)](https://github.com/PennyLaneAI/catalyst/pull/2243) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) From f849bc4f4dd7344469b9dd914acbcbdc0cbb87a0 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 3 Dec 2025 15:05:39 -0500 Subject: [PATCH 110/111] format --- frontend/catalyst/python_interface/inspection/draw.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/frontend/catalyst/python_interface/inspection/draw.py b/frontend/catalyst/python_interface/inspection/draw.py index ad2c5a00ba..1e6056a81c 100644 --- a/frontend/catalyst/python_interface/inspection/draw.py +++ b/frontend/catalyst/python_interface/inspection/draw.py @@ -71,9 +71,7 @@ def _draw_callback(previous_pass, module, next_pass, pass_level=0): collector = QMLCollector(module) ops, meas = collector.collect() tape = QuantumScript(ops, meas) - pass_name = ( - pass_instance.name if hasattr(pass_instance, "name") else pass_instance - ) + pass_name = pass_instance.name if hasattr(pass_instance, "name") else pass_instance cache[pass_level] = ( tape.draw(show_matrices=False), pass_name if pass_level else "No transforms", @@ -131,9 +129,7 @@ def _draw_callback(previous_pass, module, next_pass, pass_level=0): utility = ConstructCircuitDAG(PyDotDAGBuilder()) utility.construct(module) svg_str = utility.dag_builder.graph.create_svg(prog="dot") - pass_name = ( - pass_instance.name if hasattr(pass_instance, "name") else pass_instance - ) + pass_name = pass_instance.name if hasattr(pass_instance, "name") else pass_instance cache[pass_level] = ( svg_str, pass_name if pass_level else "No transforms", From 50aceefcd2771d56212d757f3a4e3a1604b5cf30 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Wed, 3 Dec 2025 15:25:46 -0500 Subject: [PATCH 111/111] improve frontend --- .../python_interface/inspection/draw.py | 59 +++++++------------ 1 file changed, 21 insertions(+), 38 deletions(-) diff --git a/frontend/catalyst/python_interface/inspection/draw.py b/frontend/catalyst/python_interface/inspection/draw.py index 1e6056a81c..4dc251a280 100644 --- a/frontend/catalyst/python_interface/inspection/draw.py +++ b/frontend/catalyst/python_interface/inspection/draw.py @@ -99,25 +99,7 @@ def wrapper(*args, **kwargs): def draw_graph(qnode: QNode, *, level: None | int = None) -> Callable: """ - Draw the QNode at the specified level. - - This function can be used to visualize the QNode at different stages of the transformation - pipeline when xDSL or Catalyst compilation passes are applied. - If the specified level is not available, the highest level will be used as a fallback. - - The provided QNode is assumed to be decorated with compilation passes. - If no passes are applied, the original QNode is visualized. - - Args: - qnode (.QNode): the input QNode that is to be visualized. The QNode is assumed to be - compiled with ``qjit``. - level (None | int): the level of transformation to visualize. If `None`, the final - level is visualized. - - - Returns: - ???? - + ??? """ cache: dict[int, tuple[str, str]] = _cache_store.setdefault(qnode, {}) @@ -128,10 +110,11 @@ def _draw_callback(previous_pass, module, next_pass, pass_level=0): # Process module to build DAG utility = ConstructCircuitDAG(PyDotDAGBuilder()) utility.construct(module) - svg_str = utility.dag_builder.graph.create_svg(prog="dot") + # Store DAG in cache + image_bytes = utility.dag_builder.graph.create_png(prog="dot") pass_name = pass_instance.name if hasattr(pass_instance, "name") else pass_instance cache[pass_level] = ( - svg_str, + image_bytes, pass_name if pass_level else "No transforms", ) @@ -143,22 +126,22 @@ def wrapper(*args, **kwargs): if not cache: return None - # Retrieve the SVG string (or the high-DPI PNG string) - image_data = cache.get(level, cache[max(cache.keys())])[0] - - if image_data.startswith(b"