From 5ac829c110f54dfa374a9edc8dda57e8a212eae6 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 14:16:46 -0500 Subject: [PATCH 01/75] cl --- doc/releases/changelog-dev.md | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index fa2edc7412..65df9b174f 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -5,6 +5,7 @@ * Compiled programs can be visualized. [(#2213)](https://github.com/PennyLaneAI/catalyst/pull/2213) [(#2214)](https://github.com/PennyLaneAI/catalyst/pull/2214) + [(#)]() * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) From 855fa4c132a467d47b9ac79f753698b5f816f867 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Thu, 20 Nov 2025 14:17:38 -0500 Subject: [PATCH 02/75] Apply suggestion from @andrijapau --- doc/releases/changelog-dev.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 65df9b174f..c5d7803d15 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -5,7 +5,7 @@ * Compiled programs can be visualized. [(#2213)](https://github.com/PennyLaneAI/catalyst/pull/2213) [(#2214)](https://github.com/PennyLaneAI/catalyst/pull/2214) - [(#)]() + [(#2118)](https://github.com/PennyLaneAI/catalyst/pull/2218) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) From 13cc8fc0632f9761b599475425a5b780b0996326 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 14:50:15 -0500 Subject: [PATCH 03/75] initial commit --- .../visualization/construct_circuit_dag.py | 85 +++++++++++++++++-- 1 file changed, 76 insertions(+), 9 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 1237f94fed..37eb7e3f77 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -14,11 +14,10 @@ """Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module.""" -from functools import singledispatchmethod -from typing import TYPE_CHECKING, Any +from functools import singledispatch, singledispatchmethod +from typing import Any -import xdsl -from xdsl.dialects import builtin, func, scf +from xdsl.dialects import builtin, scf from xdsl.ir import Block, Region from catalyst.python_interface.dialects import quantum @@ -37,6 +36,14 @@ def __init__(self, dag_builder: DAGBuilder) -> None: """ self.dag_builder: DAGBuilder = dag_builder + # Record clusters seen as a stack + # beginning with the base graph (None) + self._cluster_stack: list[str | None] = [None] + + def _reset(self) -> None: + """Resets the instance.""" + self._cluster_stack: list[str | None] = [] + # ================================= # 1. CORE DISPATCH AND ENTRY POINT # ================================= @@ -45,10 +52,14 @@ def __init__(self, dag_builder: DAGBuilder) -> None: def visit_op(self, op: Any) -> None: """Central dispatch method (Visitor Pattern). Routes the operation 'op' to the specialized handler registered for its type.""" - raise NotImplementedError(f"Dispatch not registered for operator of type {type(op)}") + raise NotImplementedError( + f"Dispatch not registered for operator of type {type(op)}" + ) def construct(self, module: builtin.ModuleOp) -> None: """Constructs the DAG from the module.""" + self._reset() + for op in module.ops: self.visit_op(op) @@ -80,7 +91,12 @@ def _visit_unitary_and_state_prep( op: quantum.CustomOp, ) -> None: """Generic handler for unitary gates and quantum state preparation operations.""" - pass + + # Build node on graph + self.dag_builder.add_node( + node_id=f"node_{id(op)}", + node_label=get_label(op), + ) # ============================================= # 4. QUANTUM MEASUREMENT HANDLERS @@ -89,7 +105,12 @@ def _visit_unitary_and_state_prep( @visit_op.register def _visit_state_op(self, op: quantum.StateOp) -> None: """Handler for the terminal state measurement operation.""" - pass + + # Build node on graph + self.dag_builder.add_node( + node_id=f"node_{id(op)}", + node_label=get_label(op), + ) @visit_op.register def _visit_statistical_measurement_ops( @@ -97,12 +118,22 @@ def _visit_statistical_measurement_ops( op: quantum.ExpvalOp | quantum.VarianceOp | quantum.ProbsOp | quantum.SampleOp, ) -> None: """Handler for statistical measurement operations.""" - pass + + # Build node on graph + self.dag_builder.add_node( + node_id=f"node_{id(op)}", + node_label=get_label(op), + ) @visit_op.register def _visit_projective_measure_op(self, op: quantum.MeasureOp) -> None: """Handler for the single-qubit projective measurement operation.""" - pass + + # Build node on graph + self.dag_builder.add_node( + node_id=f"node_{id(op)}", + node_label=get_label(op), + ) # ========================= # 5. CONTROL FLOW HANDLERS @@ -122,3 +153,39 @@ def _visit_while_op(self, op: scf.WhileOp) -> None: def _visit_if_op(self, op: scf.IfOp) -> None: """Handle an xDSL IfOp operation.""" pass + + +@singledispatch +def get_label(op: Any) -> str: + """Gets a human readable label for a given xDSL operation. + + Returns: + label (str): The appropriate label for a given xDSL operation. Defaults + to the class name. + """ + return type(op).__name__.replace("Op", "") + + +@get_label.register +def _get_custom_op_label(op: quantum.CustomOp) -> str: + return op.gate_name.data + + +@get_label.register +def _get_state_op_label(op: quantum.StateOp) -> str: + return op.name + + +@get_label.register +def _get_statistical_measurement_op_label( + op: quantum.ExpvalOp | quantum.VarianceOp | quantum.ProbsOp | quantum.SampleOp, +) -> str: + # e.g. op -> expval(Z(0)) + obs = op.obs + mp = op.name.split(".")[-1] # quantum.expval -> expval + return f"{mp}({obs})" + + +@get_label.register +def _get_projective_measurement_op_label(op: quantum.MeasureOp) -> str: + return op.name.split(".")[-1] # quantum.measure -> measure From 6d46558ffe76a8799387fc092e95159b127b0393 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 14:51:11 -0500 Subject: [PATCH 04/75] whoops, forgot parent graph id --- .../python_interface/visualization/construct_circuit_dag.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 37eb7e3f77..6c19d71cbe 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -96,6 +96,7 @@ def _visit_unitary_and_state_prep( self.dag_builder.add_node( node_id=f"node_{id(op)}", node_label=get_label(op), + parent_graph_id=self._cluster_stack[-1], ) # ============================================= @@ -110,6 +111,7 @@ def _visit_state_op(self, op: quantum.StateOp) -> None: self.dag_builder.add_node( node_id=f"node_{id(op)}", node_label=get_label(op), + parent_graph_id=self._cluster_stack[-1], ) @visit_op.register @@ -123,6 +125,7 @@ def _visit_statistical_measurement_ops( self.dag_builder.add_node( node_id=f"node_{id(op)}", node_label=get_label(op), + parent_graph_id=self._cluster_stack[-1], ) @visit_op.register @@ -133,6 +136,7 @@ def _visit_projective_measure_op(self, op: quantum.MeasureOp) -> None: self.dag_builder.add_node( node_id=f"node_{id(op)}", node_label=get_label(op), + parent_graph_id=self._cluster_stack[-1], ) # ========================= From 518f756d38608f50d9f28a105d1ea71ac0948c17 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 15:45:03 -0500 Subject: [PATCH 05/75] only get labels for easy stuff for now --- .../visualization/construct_circuit_dag.py | 23 ++++++------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index fc16758a71..d512bcd1ef 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -171,7 +171,7 @@ def get_label(op: Any) -> str: label (str): The appropriate label for a given xDSL operation. Defaults to the class name. """ - return type(op).__name__.replace("Op", "") + return type(op).__name__ @get_label.register @@ -179,21 +179,12 @@ def _get_custom_op_label(op: quantum.CustomOp) -> str: return op.gate_name.data -@get_label.register -def _get_state_op_label(op: quantum.StateOp) -> str: - return op.name - - @get_label.register def _get_statistical_measurement_op_label( - op: quantum.ExpvalOp | quantum.VarianceOp | quantum.ProbsOp | quantum.SampleOp, + op: quantum.ExpvalOp | quantum.VarianceOp, ) -> str: - # e.g. op -> expval(Z(0)) - obs = op.obs - mp = op.name.split(".")[-1] # quantum.expval -> expval - return f"{mp}({obs})" - - -@get_label.register -def _get_projective_measurement_op_label(op: quantum.MeasureOp) -> str: - return op.name.split(".")[-1] # quantum.measure -> measure + # e.g. expval(Z(0)) should be the output + mp: str = op.name.split(".")[-1] # quantum.expval -> expval + obs: str = op.obs.owner.properties.get("type").data.value + wires: str = "" + return f"{mp}({obs}({wires}))" From aa5f7b6c0b664f8357f13065e9d352582991616d Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 15:50:15 -0500 Subject: [PATCH 06/75] typo --- .../python_interface/visualization/construct_circuit_dag.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index d512bcd1ef..add07e8eab 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -185,6 +185,7 @@ def _get_statistical_measurement_op_label( ) -> str: # e.g. expval(Z(0)) should be the output mp: str = op.name.split(".")[-1] # quantum.expval -> expval - obs: str = op.obs.owner.properties.get("type").data.value + obs_op = op.obs.owner + obs_name: str = obs_op.properties.get("type").data.value wires: str = "" - return f"{mp}({obs}({wires}))" + return f"{mp}({obs_name}({wires}))" From 5ffb87f32789741c37f00ec64c7214ed56b2a9c5 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 16:48:05 -0500 Subject: [PATCH 07/75] whoops --- .../python_interface/visualization/construct_circuit_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index add07e8eab..057f25f83e 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -42,7 +42,7 @@ def __init__(self, dag_builder: DAGBuilder) -> None: def _reset(self) -> None: """Resets the instance.""" - self._cluster_stack: list[str | None] = [] + self._cluster_stack: list[str | None] = [None] # ================================= # 1. CORE DISPATCH AND ENTRY POINT From 3a1f077d528ecebc34521b2b436927a2fd9a51a4 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 16:50:24 -0500 Subject: [PATCH 08/75] whoops --- .../python_interface/visualization/construct_circuit_dag.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 057f25f83e..8d80823f43 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -176,7 +176,9 @@ def get_label(op: Any) -> str: @get_label.register def _get_custom_op_label(op: quantum.CustomOp) -> str: - return op.gate_name.data + op_name: str = op.gate_name.data + op_wires: str = "" + return f"{op_name}({op_wires})" @get_label.register From f9a6100644ac6e31e9783cbf9c65626c1680d2c3 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 20 Nov 2025 16:50:53 -0500 Subject: [PATCH 09/75] whoops --- .../python_interface/visualization/construct_circuit_dag.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 8d80823f43..41501742f6 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -176,9 +176,9 @@ def get_label(op: Any) -> str: @get_label.register def _get_custom_op_label(op: quantum.CustomOp) -> str: - op_name: str = op.gate_name.data - op_wires: str = "" - return f"{op_name}({op_wires})" + name: str = op.gate_name.data + wires: str = "" + return f"{name}({wires})" @get_label.register From c85d5923ad0deac95f0f70c0f6dd3c16e9f849e4 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 21 Nov 2025 10:20:40 -0500 Subject: [PATCH 10/75] add more support --- .../visualization/construct_circuit_dag.py | 57 ++++++------------- 1 file changed, 18 insertions(+), 39 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 41501742f6..d1192c1530 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -22,6 +22,10 @@ from catalyst.python_interface.dialects import quantum from catalyst.python_interface.visualization.dag_builder import DAGBuilder +from catalyst.python_interface.visualization.xdsl_conversion import ( + xdsl_to_qml_measurement, + xdsl_to_qml_op, +) class ConstructCircuitDAG: @@ -67,19 +71,19 @@ def construct(self, module: builtin.ModuleOp) -> None: # These methods navigate the recursive IR hierarchy (Op -> Region -> Block -> Op). @visit.register - def visit_operation(self, operation: Operation) -> None: + def _operation(self, operation: Operation) -> None: """Visit an xDSL Operation.""" for region in operation.regions: - self.visit_region(region) + self.visit(region) @visit.register - def visit_region(self, region: Region) -> None: + def _region(self, region: Region) -> None: """Visit an xDSL Region operation.""" for block in region.blocks: - self.visit_block(block) + self.visit(block) @visit.register - def visit_block(self, block: Block) -> None: + def _block(self, block: Block) -> None: """Visit an xDSL Block operation, dispatching handling for each contained Operation.""" for op in block.ops: self.visit(op) @@ -96,10 +100,11 @@ def _unitary_and_state_prep( ) -> None: """Generic handler for unitary gates and quantum state preparation operations.""" + qml_op = xdsl_to_qml_op(op) # Build node on graph self.dag_builder.add_node( node_id=f"node_{id(op)}", - node_label=get_label(op), + node_label=str(qml_op), parent_graph_id=self._cluster_stack[-1], ) @@ -111,10 +116,11 @@ def _unitary_and_state_prep( def _state_op(self, op: quantum.StateOp) -> None: """Handler for the terminal state measurement operation.""" + meas = xdsl_to_qml_measurement(op) # Build node on graph self.dag_builder.add_node( node_id=f"node_{id(op)}", - node_label=get_label(op), + node_label=str(meas), parent_graph_id=self._cluster_stack[-1], ) @@ -125,10 +131,12 @@ def _statistical_measurement_ops( ) -> None: """Handler for statistical measurement operations.""" + obs_op = op.obs.owner + meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op)) # Build node on graph self.dag_builder.add_node( node_id=f"node_{id(op)}", - node_label=get_label(op), + node_label=str(meas), parent_graph_id=self._cluster_stack[-1], ) @@ -136,10 +144,11 @@ def _statistical_measurement_ops( def _projective_measure_op(self, op: quantum.MeasureOp) -> None: """Handler for the single-qubit projective measurement operation.""" + meas = xdsl_to_qml_measurement(op) # Build node on graph self.dag_builder.add_node( node_id=f"node_{id(op)}", - node_label=get_label(op), + node_label=str(meas), parent_graph_id=self._cluster_stack[-1], ) @@ -161,33 +170,3 @@ def _while_op(self, op: scf.WhileOp) -> None: def _if_op(self, op: scf.IfOp) -> None: """Handle an xDSL IfOp operation.""" pass - - -@singledispatch -def get_label(op: Any) -> str: - """Gets a human readable label for a given xDSL operation. - - Returns: - label (str): The appropriate label for a given xDSL operation. Defaults - to the class name. - """ - return type(op).__name__ - - -@get_label.register -def _get_custom_op_label(op: quantum.CustomOp) -> str: - name: str = op.gate_name.data - wires: str = "" - return f"{name}({wires})" - - -@get_label.register -def _get_statistical_measurement_op_label( - op: quantum.ExpvalOp | quantum.VarianceOp, -) -> str: - # e.g. expval(Z(0)) should be the output - mp: str = op.name.split(".")[-1] # quantum.expval -> expval - obs_op = op.obs.owner - obs_name: str = obs_op.properties.get("type").data.value - wires: str = "" - return f"{mp}({obs_name}({wires}))" From 0a1950a7190d69a5a467ad550f3a3db583279c5e Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 21 Nov 2025 10:57:11 -0500 Subject: [PATCH 11/75] add support for other ops --- .../visualization/construct_circuit_dag.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index d1192c1530..5fde761df5 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -96,7 +96,12 @@ def _block(self, block: Block) -> None: @visit.register def _unitary_and_state_prep( self, - op: quantum.CustomOp, + op: quantum.CustomOp + | quantum.GlobalPhaseOp + | quantum.QubitUnitaryOp + | quantum.SetStateOp + | quantum.MultiRZOp + | quantum.SetBasisStateOp, ) -> None: """Generic handler for unitary gates and quantum state preparation operations.""" From 7c1b495c84d11e6840fccc2ca432fa9f6545de6d Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 21 Nov 2025 19:49:43 -0500 Subject: [PATCH 12/75] get rid of cluster stuff --- .../visualization/construct_circuit_dag.py | 19 ------------------- 1 file changed, 19 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 1736c2dc9f..e1b7bdd321 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -156,22 +156,3 @@ def _projective_measure_op(self, op: quantum.MeasureOp) -> None: node_label=str(meas), parent_graph_id=self._cluster_stack[-1], ) - - # ========================= - # 5. CONTROL FLOW HANDLERS - # ========================= - - @visit.register - def _for_op(self, op: scf.ForOp) -> None: - """Handle an xDSL ForOp operation.""" - pass - - @visit.register - def _while_op(self, op: scf.WhileOp) -> None: - """Handle an xDSL WhileOp operation.""" - pass - - @visit.register - def _if_op(self, op: scf.IfOp) -> None: - """Handle an xDSL IfOp operation.""" - pass From f19f60580614c515137ab8ce55efdbe865c59869 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 21 Nov 2025 19:54:01 -0500 Subject: [PATCH 13/75] add test classes --- .../visualization/test_construct_circuit_dag.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 6d4b7106ea..d62b2ffc71 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -62,3 +62,9 @@ def test_does_not_mutate_module(): utility.construct(module_op) assert str(module_op) == module_op_str_before + +class TestOperatorNodes: + """Tests that operators can be visualized as nodes.""" + +class TestMeasurementNodes: + """Tests that measurements can be visualized as nodes.""" From 04413ec08d56aee444d2015b859be15ea2163a27 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 21 Nov 2025 19:58:38 -0500 Subject: [PATCH 14/75] add basic test skeleton --- .../visualization/test_construct_circuit_dag.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index d62b2ffc71..b815f04646 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -63,8 +63,16 @@ def test_does_not_mutate_module(): assert str(module_op) == module_op_str_before + class TestOperatorNodes: """Tests that operators can be visualized as nodes.""" + def test_custom_op(self): + pass + + class TestMeasurementNodes: """Tests that measurements can be visualized as nodes.""" + + def test_state_op(self): + pass From 314de1f0ae332918af7a074379900e314b9acff5 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 09:31:27 -0500 Subject: [PATCH 15/75] add more tests --- .../test_construct_circuit_dag.py | 29 ++++++++++++++++--- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index b815f04646..6d5541b1de 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -64,15 +64,36 @@ def test_does_not_mutate_module(): assert str(module_op) == module_op_str_before -class TestOperatorNodes: - """Tests that operators can be visualized as nodes.""" +class TestCreateOperatorNodes: + """Tests that operators can be created and visualized as nodes.""" def test_custom_op(self): pass + def test_global_phase_op(self): + pass + + def test_qubit_unitary_op(self): + pass + + def test_set_state_op(self): + pass + + def test_multi_rz_op(self): + pass + + def test_set_basis_state_op(self): + pass -class TestMeasurementNodes: - """Tests that measurements can be visualized as nodes.""" + +class TestCreateMeasurementNodes: + """Tests that measurements can be created and visualized as nodes.""" def test_state_op(self): pass + + def test_statistical_measurement_op(self): + pass + + def test_projective_measurement_op(self): + pass From 552aaf46062b7db260926a07ac8f9d5134c58b30 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 10:07:50 -0500 Subject: [PATCH 16/75] clean-up --- .../visualization/test_construct_circuit_dag.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 6d5541b1de..5c41387fb7 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -13,8 +13,7 @@ # limitations under the License. """Unit tests for the ConstructCircuitDAG utility.""" -from unittest import mock -from unittest.mock import MagicMock, Mock, call +from unittest.mock import Mock import pytest From c1cc80ab0710662cb2208e2c9275278ebf40dd11 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 11:29:47 -0500 Subject: [PATCH 17/75] update test --- .../test_construct_circuit_dag.py | 19 ++++++++++++++++++- 1 file changed, 18 insertions(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 5c41387fb7..52ba5e982c 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -21,10 +21,15 @@ # pylint: disable=wrong-import-position # This import needs to be after pytest in order to prevent ImportErrors +from catalyst.python_interface.dialects.quantum import CustomOp from catalyst.python_interface.visualization.construct_circuit_dag import ( ConstructCircuitDAG, ) from catalyst.python_interface.visualization.dag_builder import DAGBuilder +from pennylane.compiler.python_compiler.dialects.quantum import ( + CustomOp, + QubitType, +) from xdsl.dialects import test from xdsl.dialects.builtin import ModuleOp from xdsl.ir.core import Block, Region @@ -67,7 +72,19 @@ class TestCreateOperatorNodes: """Tests that operators can be created and visualized as nodes.""" def test_custom_op(self): - pass + """Tests that the CustomOp operation node can be created and visualized.""" + + # Create constant for wire + q0 = test.TestOp(result_types=[QubitType()]) + custom_op = CustomOp(gate_name="Test", in_qubits=q0.results) + + # Create module + module = ModuleOp(ops=[q0, custom_op]) + + mock_dag_builder = Mock(DAGBuilder) + utility = ConstructCircuitDAG(mock_dag_builder) + utility.construct(module) + print(utility.dag_builder.to_string()) def test_global_phase_op(self): pass From aa4d241ac5cabc0b9b9a88a3bacab62fa8568447 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 11:39:32 -0500 Subject: [PATCH 18/75] update test --- .../visualization/test_construct_circuit_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 52ba5e982c..af5636f4e3 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -26,7 +26,7 @@ ConstructCircuitDAG, ) from catalyst.python_interface.visualization.dag_builder import DAGBuilder -from pennylane.compiler.python_compiler.dialects.quantum import ( +from catalyst.python_interface.dialects.quantum import ( CustomOp, QubitType, ) From d418cb07772a26984e2ade1a1dc1752af7d2f1f8 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 11:46:56 -0500 Subject: [PATCH 19/75] clean up --- .../visualization/construct_circuit_dag.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index e1b7bdd321..63d1437cfd 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -93,7 +93,7 @@ def _block(self, block: Block) -> None: # ====================================== # Handlers for operations that apply unitary transformations or set-up the quantum state. - @visit.register + @_visit.register def _unitary_and_state_prep( self, op: quantum.CustomOp @@ -117,7 +117,7 @@ def _unitary_and_state_prep( # 4. QUANTUM MEASUREMENT HANDLERS # ============================================= - @visit.register + @_visit.register def _state_op(self, op: quantum.StateOp) -> None: """Handler for the terminal state measurement operation.""" @@ -129,7 +129,7 @@ def _state_op(self, op: quantum.StateOp) -> None: parent_graph_id=self._cluster_stack[-1], ) - @visit.register + @_visit.register def _statistical_measurement_ops( self, op: quantum.ExpvalOp | quantum.VarianceOp | quantum.ProbsOp | quantum.SampleOp, @@ -145,7 +145,7 @@ def _statistical_measurement_ops( parent_graph_id=self._cluster_stack[-1], ) - @visit.register + @_visit.register def _projective_measure_op(self, op: quantum.MeasureOp) -> None: """Handler for the single-qubit projective measurement operation.""" @@ -156,3 +156,4 @@ def _projective_measure_op(self, op: quantum.MeasureOp) -> None: node_label=str(meas), parent_graph_id=self._cluster_stack[-1], ) + From d3a5c0d4104af679c177df90ec7d01310e554759 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 12:54:23 -0500 Subject: [PATCH 20/75] temp test --- .../test_construct_circuit_dag.py | 29 ++++++++++++------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index af5636f4e3..a24921e2a2 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -21,15 +21,16 @@ # pylint: disable=wrong-import-position # This import needs to be after pytest in order to prevent ImportErrors -from catalyst.python_interface.dialects.quantum import CustomOp -from catalyst.python_interface.visualization.construct_circuit_dag import ( - ConstructCircuitDAG, -) -from catalyst.python_interface.visualization.dag_builder import DAGBuilder +import pennylane as qml +from catalyst.python_interface.conversion import xdsl_from_qjit from catalyst.python_interface.dialects.quantum import ( CustomOp, QubitType, ) +from catalyst.python_interface.visualization.construct_circuit_dag import ( + ConstructCircuitDAG, +) +from catalyst.python_interface.visualization.dag_builder import DAGBuilder from xdsl.dialects import test from xdsl.dialects.builtin import ModuleOp from xdsl.ir.core import Block, Region @@ -74,17 +75,23 @@ class TestCreateOperatorNodes: def test_custom_op(self): """Tests that the CustomOp operation node can be created and visualized.""" - # Create constant for wire - q0 = test.TestOp(result_types=[QubitType()]) - custom_op = CustomOp(gate_name="Test", in_qubits=q0.results) + # Build module with only a CustomOp + dev = qml.device("null.qubit", wires=1) - # Create module - module = ModuleOp(ops=[q0, custom_op]) + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.H(0) + module = my_circuit() + + # Construct DAG mock_dag_builder = Mock(DAGBuilder) utility = ConstructCircuitDAG(mock_dag_builder) utility.construct(module) - print(utility.dag_builder.to_string()) + + # Ensure DAG only has one node def test_global_phase_op(self): pass From 4bb4f38196f7ea00018dd46fccbb49c5e22f8d5c Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 16:28:45 -0500 Subject: [PATCH 21/75] update custom op test --- .../visualization/test_construct_circuit_dag.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index a24921e2a2..d191088615 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -92,6 +92,9 @@ def my_circuit(): utility.construct(module) # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 1 + assert next(iter(nodes.values()))["label"] == "H(0)" def test_global_phase_op(self): pass From 9580d9ead825e4b75b42aac3f2debbf4e93ff4c2 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 17:04:21 -0500 Subject: [PATCH 22/75] update test with fake dag builder --- .../test_construct_circuit_dag.py | 62 ++++++++++++++++++- 1 file changed, 60 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index d191088615..450e54364f 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -36,6 +36,65 @@ from xdsl.ir.core import Block, Region +class FakeDAGBuilder(DAGBuilder): + """ + A concrete implementation of DAGBuilder used ONLY for testing. + It stores all graph manipulation calls in simple Python dictionaries + for easy assertion of the final graph state. + """ + + def __init__(self): + self._nodes = {} + self._edges = [] + self._clusters = {} + + def add_node(self, id, label, cluster_id=None, **attrs) -> None: + self._nodes[id] = { + "id": id, + "label": label, + "parent_id": cluster_id, + "attrs": attrs, + } + + def add_edge(self, from_id: str, to_id: str, **attrs) -> None: + self._edges.append( + { + "from": from_id, + "to": to_id, + "attrs": attrs, + } + ) + + def add_cluster( + self, + id, + node_label=None, + cluster_id=None, + **attrs, + ) -> None: + self._clusters[id] = { + "id": id, + "label": node_label, + "parent_id": cluster_id, + "attrs": attrs, + } + + def get_nodes(self): + return self._nodes.copy() + + def get_edges(self): + return self._edges.copy() + + def get_clusters(self): + return self._clusters.copy() + + def to_file(self, output_filename): + pass + + def to_string(self) -> str: + return "graph" + + class TestInitialization: """Tests that the state is correctly initialized.""" @@ -87,8 +146,7 @@ def my_circuit(): module = my_circuit() # Construct DAG - mock_dag_builder = Mock(DAGBuilder) - utility = ConstructCircuitDAG(mock_dag_builder) + utility = ConstructCircuitDAG(FakeDAGBuilder()) utility.construct(module) # Ensure DAG only has one node From 2ff7723e42bcf4a705e53b65f9bd25d8b1e55f0a Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 17:06:07 -0500 Subject: [PATCH 23/75] parametrize over ops --- .../visualization/test_construct_circuit_dag.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 450e54364f..e48f31c98e 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -131,7 +131,9 @@ def test_does_not_mutate_module(): class TestCreateOperatorNodes: """Tests that operators can be created and visualized as nodes.""" - def test_custom_op(self): + @pytest.mark.unit + @pytest.mark.parametrize("op", [qml.H(0), qml.X(0)]) + def test_custom_op(self, op): """Tests that the CustomOp operation node can be created and visualized.""" # Build module with only a CustomOp @@ -141,7 +143,7 @@ def test_custom_op(self): @qml.qjit(autograph=True, target="mlir") @qml.qnode(dev) def my_circuit(): - qml.H(0) + qml.apply(op) module = my_circuit() From fac9f27ab001991f3e8184330006fb0437a2d5bd Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 17:21:25 -0500 Subject: [PATCH 24/75] improvements --- .../visualization/test_construct_circuit_dag.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index e48f31c98e..37d3ceea66 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -151,6 +151,12 @@ def my_circuit(): utility = ConstructCircuitDAG(FakeDAGBuilder()) utility.construct(module) + # sanity check + edges = utility.dag_builder.get_edges() + assert edges == [] + clusters = utility.dag_builder.get_clusters() + assert clusters == {} + # Ensure DAG only has one node nodes = utility.dag_builder.get_nodes() assert len(nodes) == 1 From e3f2e58206ea6f5e6f4f48edb34971ba5e49453a Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 17:23:04 -0500 Subject: [PATCH 25/75] whoops --- .../visualization/construct_circuit_dag.py | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 63d1437cfd..b0d9a47555 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -108,9 +108,9 @@ def _unitary_and_state_prep( qml_op = xdsl_to_qml_op(op) # Build node on graph self.dag_builder.add_node( - node_id=f"node_{id(op)}", - node_label=str(qml_op), - parent_graph_id=self._cluster_stack[-1], + id=f"node_{id(op)}", + label=str(qml_op), + cluster_id=self._cluster_stack[-1], ) # ============================================= @@ -124,9 +124,9 @@ def _state_op(self, op: quantum.StateOp) -> None: meas = xdsl_to_qml_measurement(op) # Build node on graph self.dag_builder.add_node( - node_id=f"node_{id(op)}", - node_label=str(meas), - parent_graph_id=self._cluster_stack[-1], + id=f"node_{id(op)}", + label=str(meas), + cluster_id=self._cluster_stack[-1], ) @_visit.register @@ -140,9 +140,9 @@ def _statistical_measurement_ops( meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op)) # Build node on graph self.dag_builder.add_node( - node_id=f"node_{id(op)}", - node_label=str(meas), - parent_graph_id=self._cluster_stack[-1], + id=f"node_{id(op)}", + label=str(meas), + cluster_id=self._cluster_stack[-1], ) @_visit.register @@ -152,8 +152,7 @@ def _projective_measure_op(self, op: quantum.MeasureOp) -> None: meas = xdsl_to_qml_measurement(op) # Build node on graph self.dag_builder.add_node( - node_id=f"node_{id(op)}", - node_label=str(meas), - parent_graph_id=self._cluster_stack[-1], + id=f"node_{id(op)}", + label=str(meas), + cluster_id=self._cluster_stack[-1], ) - From ed5316e6eb0cb6b1ebe3a8c36c7d0a1af5d16caa Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 17:38:19 -0500 Subject: [PATCH 26/75] fix and add test --- .../test_construct_circuit_dag.py | 39 +++++++++++++++++-- 1 file changed, 36 insertions(+), 3 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 37d3ceea66..5d71c077b9 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -160,10 +160,43 @@ def my_circuit(): # Ensure DAG only has one node nodes = utility.dag_builder.get_nodes() assert len(nodes) == 1 - assert next(iter(nodes.values()))["label"] == "H(0)" + assert next(iter(nodes.values()))["label"] == str(op) - def test_global_phase_op(self): - pass + @pytest.mark.unit + @pytest.mark.parametrize( + "op", + [ + qml.GlobalPhase(0.5), + qml.GlobalPhase(0.5, wires=0), + qml.GlobalPhase(0.5, wires=[0, 1]), + ], + ) + def test_global_phase_op(self, op): + # Build module with only a CustomOp + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.apply(op) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # sanity check + edges = utility.dag_builder.get_edges() + assert edges == [] + clusters = utility.dag_builder.get_clusters() + assert clusters == {} + + # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 1 + assert next(iter(nodes.values()))["label"] == str(op) def test_qubit_unitary_op(self): pass From fe25ff6c0826366d72ea75bb619846581086b939 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 17:39:58 -0500 Subject: [PATCH 27/75] add test case --- .../visualization/test_construct_circuit_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 5d71c077b9..2f9ed944f0 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -132,7 +132,7 @@ class TestCreateOperatorNodes: """Tests that operators can be created and visualized as nodes.""" @pytest.mark.unit - @pytest.mark.parametrize("op", [qml.H(0), qml.X(0)]) + @pytest.mark.parametrize("op", [qml.H(0), qml.X(0), qml.SWAP([0, 1])]) def test_custom_op(self, op): """Tests that the CustomOp operation node can be created and visualized.""" From cb899951be6582afb2d4a22684b6b92b2c3eb281 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 17:58:59 -0500 Subject: [PATCH 28/75] more tests --- .../test_construct_circuit_dag.py | 32 +++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 2f9ed944f0..b48da63bd5 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -172,7 +172,8 @@ def my_circuit(): ], ) def test_global_phase_op(self, op): - # Build module with only a CustomOp + """Test that GlobalPhase can be handled.""" + dev = qml.device("null.qubit", wires=1) @xdsl_from_qjit @@ -214,8 +215,33 @@ def test_set_basis_state_op(self): class TestCreateMeasurementNodes: """Tests that measurements can be created and visualized as nodes.""" - def test_state_op(self): - pass + @pytest.mark.parametrize("meas", [qml.state(), qml.state(wires=[0, 1])]) + def test_state_op(self, meas): + """Test that qml.state can be handled.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return meas + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # sanity check + edges = utility.dag_builder.get_edges() + assert edges == [] + clusters = utility.dag_builder.get_clusters() + assert clusters == {} + + # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 1 + assert next(iter(nodes.values()))["label"] == str(meas) def test_statistical_measurement_op(self): pass From e25b016bef95d54ba9fa37030d89093bf0867b19 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 17:59:45 -0500 Subject: [PATCH 29/75] cleanup --- .../visualization/test_construct_circuit_dag.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index b48da63bd5..5b860cb59d 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -95,15 +95,12 @@ def to_string(self) -> str: return "graph" -class TestInitialization: - """Tests that the state is correctly initialized.""" +def test_dependency_injection(): + """Tests that relevant dependencies are injected.""" - def test_dependency_injection(self): - """Tests that relevant dependencies are injected.""" - - mock_dag_builder = Mock(DAGBuilder) - utility = ConstructCircuitDAG(mock_dag_builder) - assert utility.dag_builder is mock_dag_builder + mock_dag_builder = Mock(DAGBuilder) + utility = ConstructCircuitDAG(mock_dag_builder) + assert utility.dag_builder is mock_dag_builder def test_does_not_mutate_module(): From d766ddd4cf5ff4d2951f924a4cb01b4fd861491e Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 18:21:26 -0500 Subject: [PATCH 30/75] add expval var test --- .../test_construct_circuit_dag.py | 38 ++++++++++++++++--- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 5b860cb59d..ed8faf8d76 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -212,8 +212,8 @@ def test_set_basis_state_op(self): class TestCreateMeasurementNodes: """Tests that measurements can be created and visualized as nodes.""" - @pytest.mark.parametrize("meas", [qml.state(), qml.state(wires=[0, 1])]) - def test_state_op(self, meas): + @pytest.mark.unit + def test_state_op(self): """Test that qml.state can be handled.""" dev = qml.device("null.qubit", wires=1) @@ -221,7 +221,7 @@ def test_state_op(self, meas): @qml.qjit(autograph=True, target="mlir") @qml.qnode(dev) def my_circuit(): - return meas + return qml.state() module = my_circuit() @@ -238,10 +238,36 @@ def my_circuit(): # Ensure DAG only has one node nodes = utility.dag_builder.get_nodes() assert len(nodes) == 1 - assert next(iter(nodes.values()))["label"] == str(meas) + assert next(iter(nodes.values()))["label"] == str(qml.state()) - def test_statistical_measurement_op(self): - pass + @pytest.mark.unit + @pytest.mark.parametrize("meas_fn", [qml.expval, qml.var]) + def test_expval_var_measurement_op(self, meas_fn): + """Test that statistical measurement operators can be captured as nodes.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return meas_fn(qml.Z(0)) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # sanity check + edges = utility.dag_builder.get_edges() + assert edges == [] + clusters = utility.dag_builder.get_clusters() + assert clusters == {} + + # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 1 + assert next(iter(nodes.values()))["label"] == str(qml.state()) def test_projective_measurement_op(self): pass From 1db0c6c75e16fbc7880d352f09a3ad5135206114 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 18:22:13 -0500 Subject: [PATCH 31/75] whoops typo --- .../visualization/test_construct_circuit_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index ed8faf8d76..3af5dc7839 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -267,7 +267,7 @@ def my_circuit(): # Ensure DAG only has one node nodes = utility.dag_builder.get_nodes() assert len(nodes) == 1 - assert next(iter(nodes.values()))["label"] == str(qml.state()) + assert next(iter(nodes.values()))["label"] == str(meas_fn(qml.Z(0))) def test_projective_measurement_op(self): pass From fbf9b04f38d0d2e33ef6011fd0abda35c5eb3523 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 18:24:04 -0500 Subject: [PATCH 32/75] add test for mid measure --- .../test_construct_circuit_dag.py | 27 ++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 3af5dc7839..aa307109b6 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -22,6 +22,7 @@ # pylint: disable=wrong-import-position # This import needs to be after pytest in order to prevent ImportErrors import pennylane as qml +from catalyst import measure from catalyst.python_interface.conversion import xdsl_from_qjit from catalyst.python_interface.dialects.quantum import ( CustomOp, @@ -270,4 +271,28 @@ def my_circuit(): assert next(iter(nodes.values()))["label"] == str(meas_fn(qml.Z(0))) def test_projective_measurement_op(self): - pass + """Test that projective measurements can be captured as nodes.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + measure(0) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # sanity check + edges = utility.dag_builder.get_edges() + assert edges == [] + clusters = utility.dag_builder.get_clusters() + assert clusters == {} + + # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 1 + assert "MidMeasure" in next(iter(nodes.values()))["label"] From 3576c81787633795c4ca3208e1e027d1b216af4d Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 18:25:59 -0500 Subject: [PATCH 33/75] add more tests --- .../test_construct_circuit_dag.py | 58 +++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index aa307109b6..42fd4adce3 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -270,6 +270,64 @@ def my_circuit(): assert len(nodes) == 1 assert next(iter(nodes.values()))["label"] == str(meas_fn(qml.Z(0))) + @pytest.mark.unit + def test_probs_measurement_op(self): + """Tests that the probs measurement function can be captured as a node.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return qml.probs() + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # sanity check + edges = utility.dag_builder.get_edges() + assert edges == [] + clusters = utility.dag_builder.get_clusters() + assert clusters == {} + + # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 1 + assert next(iter(nodes.values()))["label"] == str(qml.probs()) + + @pytest.mark.unit + def test_sample_measurement_op(self): + """Tests that the sample measurement function can be captured as a node.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.set_shots(10) + @qml.qnode(dev) + def my_circuit(): + return qml.sample() + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # sanity check + edges = utility.dag_builder.get_edges() + assert edges == [] + clusters = utility.dag_builder.get_clusters() + assert clusters == {} + + # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 1 + assert next(iter(nodes.values()))["label"] == str(qml.sample()) + + @pytest.mark.unit def test_projective_measurement_op(self): """Test that projective measurements can be captured as nodes.""" dev = qml.device("null.qubit", wires=1) From 9df1f89fc946cce93ced37b76a43f4f4c38f6c09 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Mon, 24 Nov 2025 18:27:23 -0500 Subject: [PATCH 34/75] clarify tests --- .../visualization/test_construct_circuit_dag.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 42fd4adce3..92694b87f0 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -126,8 +126,8 @@ def test_does_not_mutate_module(): assert str(module_op) == module_op_str_before -class TestCreateOperatorNodes: - """Tests that operators can be created and visualized as nodes.""" +class TestCreateStaticOperatorNodes: + """Tests that operators with static parameters can be created and visualized as nodes.""" @pytest.mark.unit @pytest.mark.parametrize("op", [qml.H(0), qml.X(0), qml.SWAP([0, 1])]) @@ -210,8 +210,8 @@ def test_set_basis_state_op(self): pass -class TestCreateMeasurementNodes: - """Tests that measurements can be created and visualized as nodes.""" +class TestCreateStaticMeasurementNodes: + """Tests that measurements with static parameters can be created and visualized as nodes.""" @pytest.mark.unit def test_state_op(self): From f5f1a9b3fc917e4e83923ac526c2dafb30b441a5 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 08:53:18 -0500 Subject: [PATCH 35/75] cleanup and tests --- .../visualization/construct_circuit_dag.py | 22 +++----- .../test_construct_circuit_dag.py | 56 ++++++++++++++++++- 2 files changed, 62 insertions(+), 16 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index b0d9a47555..6f5f3062c5 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -66,9 +66,8 @@ def construct(self, module: builtin.ModuleOp) -> None: self._visit(op) # ======================= - # 2. HIERARCHY TRAVERSAL + # 2. IR TRAVERSAL # ======================= - # These methods navigate the recursive IR hierarchy (Op -> Region -> Block -> Op). @_visit.register def _operation(self, operation: Operation) -> None: @@ -88,20 +87,17 @@ def _block(self, block: Block) -> None: for op in block.ops: self._visit(op) - # ====================================== - # 3. QUANTUM GATE & STATE PREP HANDLERS - # ====================================== - # Handlers for operations that apply unitary transformations or set-up the quantum state. + # ====================== + # 3. QUANTUM OPERATIONS + # ====================== @_visit.register - def _unitary_and_state_prep( + def _unitary( self, op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp - | quantum.SetStateOp - | quantum.MultiRZOp - | quantum.SetBasisStateOp, + | quantum.MultiRZOp, ) -> None: """Generic handler for unitary gates and quantum state preparation operations.""" @@ -113,9 +109,9 @@ def _unitary_and_state_prep( cluster_id=self._cluster_stack[-1], ) - # ============================================= - # 4. QUANTUM MEASUREMENT HANDLERS - # ============================================= + # ======================== + # 4. QUANTUM MEASUREMENTS + # ======================== @_visit.register def _state_op(self, op: quantum.StateOp) -> None: diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 92694b87f0..7f28d6baf8 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -198,12 +198,62 @@ def my_circuit(): assert next(iter(nodes.values()))["label"] == str(op) def test_qubit_unitary_op(self): - pass + """Test that QubitUnitary operations can be handled.""" + dev = qml.device("null.qubit", wires=1) - def test_set_state_op(self): - pass + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.QubitUnitary([[0, 1], [1, 0]], wires=0) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # sanity check + edges = utility.dag_builder.get_edges() + assert edges == [] + clusters = utility.dag_builder.get_clusters() + assert clusters == {} + + # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 1 + assert next(iter(nodes.values()))["label"] == str( + qml.QubitUnitary([[0, 1], [1, 0]], wires=0) + ) def test_multi_rz_op(self): + """Test that MultiRZ operations can be handled.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.MultiRZ(0.5, wires=[0]) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # sanity check + edges = utility.dag_builder.get_edges() + assert edges == [] + clusters = utility.dag_builder.get_clusters() + assert clusters == {} + + # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 1 + assert next(iter(nodes.values()))["label"] == str(qml.MultiRZ(0.5, wires=[0])) + + def test_set_state_op(self): pass def test_set_basis_state_op(self): From b4a0cb26397fc0bb5f98c3488053f83ed8182aab Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 08:58:16 -0500 Subject: [PATCH 36/75] clean-up --- .../visualization/test_construct_circuit_dag.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 7f28d6baf8..7fe6681f2b 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -95,7 +95,7 @@ def to_file(self, output_filename): def to_string(self) -> str: return "graph" - +@pytest.mark.unit def test_dependency_injection(): """Tests that relevant dependencies are injected.""" @@ -103,7 +103,7 @@ def test_dependency_injection(): utility = ConstructCircuitDAG(mock_dag_builder) assert utility.dag_builder is mock_dag_builder - +@pytest.mark.unit def test_does_not_mutate_module(): """Test that the module is not mutated.""" @@ -197,6 +197,7 @@ def my_circuit(): assert len(nodes) == 1 assert next(iter(nodes.values()))["label"] == str(op) + @pytest.mark.unit def test_qubit_unitary_op(self): """Test that QubitUnitary operations can be handled.""" dev = qml.device("null.qubit", wires=1) @@ -226,6 +227,7 @@ def my_circuit(): qml.QubitUnitary([[0, 1], [1, 0]], wires=0) ) + @pytest.mark.unit def test_multi_rz_op(self): """Test that MultiRZ operations can be handled.""" dev = qml.device("null.qubit", wires=1) @@ -253,12 +255,6 @@ def my_circuit(): assert len(nodes) == 1 assert next(iter(nodes.values()))["label"] == str(qml.MultiRZ(0.5, wires=[0])) - def test_set_state_op(self): - pass - - def test_set_basis_state_op(self): - pass - class TestCreateStaticMeasurementNodes: """Tests that measurements with static parameters can be created and visualized as nodes.""" From 3bdf9ffabd15e1ebaaf3c32240ccdfcd317a628d Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:01:05 -0500 Subject: [PATCH 37/75] fix mock dagbuilder --- .../visualization/test_construct_circuit_dag.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 7fe6681f2b..ec19946ffe 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -50,6 +50,7 @@ def __init__(self): self._clusters = {} def add_node(self, id, label, cluster_id=None, **attrs) -> None: + cluster_id = "__base__" if cluster_id is None else cluster_id self._nodes[id] = { "id": id, "label": label, @@ -73,6 +74,7 @@ def add_cluster( cluster_id=None, **attrs, ) -> None: + cluster_id = "__base__" if cluster_id is None else cluster_id self._clusters[id] = { "id": id, "label": node_label, @@ -95,6 +97,7 @@ def to_file(self, output_filename): def to_string(self) -> str: return "graph" + @pytest.mark.unit def test_dependency_injection(): """Tests that relevant dependencies are injected.""" @@ -103,6 +106,7 @@ def test_dependency_injection(): utility = ConstructCircuitDAG(mock_dag_builder) assert utility.dag_builder is mock_dag_builder + @pytest.mark.unit def test_does_not_mutate_module(): """Test that the module is not mutated.""" From 7dc249119bb531379560d3641b1dc03793cb97e3 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:02:14 -0500 Subject: [PATCH 38/75] whoops --- .../visualization/test_construct_circuit_dag.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index ec19946ffe..6bf274e892 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -54,7 +54,7 @@ def add_node(self, id, label, cluster_id=None, **attrs) -> None: self._nodes[id] = { "id": id, "label": label, - "parent_id": cluster_id, + "cluster_id": cluster_id, "attrs": attrs, } @@ -78,7 +78,7 @@ def add_cluster( self._clusters[id] = { "id": id, "label": node_label, - "parent_id": cluster_id, + "cluster_id": cluster_id, "attrs": attrs, } From 464ed8bac18b00d3bc1fb7e54cf7333006181259 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:10:33 -0500 Subject: [PATCH 39/75] clean-up --- .../visualization/test_construct_circuit_dag.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 6bf274e892..7e5f542732 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -102,31 +102,31 @@ def to_string(self) -> str: def test_dependency_injection(): """Tests that relevant dependencies are injected.""" - mock_dag_builder = Mock(DAGBuilder) - utility = ConstructCircuitDAG(mock_dag_builder) - assert utility.dag_builder is mock_dag_builder + dag_builder = FakeDAGBuilder() + utility = ConstructCircuitDAG(dag_builder) + assert utility.dag_builder is dag_builder @pytest.mark.unit def test_does_not_mutate_module(): """Test that the module is not mutated.""" - # Create block containing some ops + # Create module op = test.TestOp() block = Block(ops=[op]) - # Create region containing some blocks region = Region(blocks=[block]) - # Create op containing the regions container_op = test.TestOp(regions=[region]) - # Create module op to house it all module_op = ModuleOp(ops=[container_op]) + # Save state before module_op_str_before = str(module_op) + # Process module mock_dag_builder = Mock(DAGBuilder) utility = ConstructCircuitDAG(mock_dag_builder) utility.construct(module_op) + # Ensure not mutated assert str(module_op) == module_op_str_before From 92b1c19d64746a8d5e8d6b57b0da1bff9a80c6a7 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:15:01 -0500 Subject: [PATCH 40/75] improve testing --- .../visualization/test_construct_circuit_dag.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 7e5f542732..b3b69dafc1 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -126,7 +126,7 @@ def test_does_not_mutate_module(): utility = ConstructCircuitDAG(mock_dag_builder) utility.construct(module_op) - # Ensure not mutated + # Ensure not mutated assert str(module_op) == module_op_str_before @@ -163,6 +163,7 @@ def my_circuit(): nodes = utility.dag_builder.get_nodes() assert len(nodes) == 1 assert next(iter(nodes.values()))["label"] == str(op) + assert next(iter(nodes.values()))["cluster_id"] == "__base__" @pytest.mark.unit @pytest.mark.parametrize( @@ -200,6 +201,7 @@ def my_circuit(): nodes = utility.dag_builder.get_nodes() assert len(nodes) == 1 assert next(iter(nodes.values()))["label"] == str(op) + assert next(iter(nodes.values()))["cluster_id"] == "__base__" @pytest.mark.unit def test_qubit_unitary_op(self): @@ -230,6 +232,7 @@ def my_circuit(): assert next(iter(nodes.values()))["label"] == str( qml.QubitUnitary([[0, 1], [1, 0]], wires=0) ) + assert next(iter(nodes.values()))["cluster_id"] == "__base__" @pytest.mark.unit def test_multi_rz_op(self): @@ -258,6 +261,7 @@ def my_circuit(): nodes = utility.dag_builder.get_nodes() assert len(nodes) == 1 assert next(iter(nodes.values()))["label"] == str(qml.MultiRZ(0.5, wires=[0])) + assert next(iter(nodes.values()))["cluster_id"] == "__base__" class TestCreateStaticMeasurementNodes: @@ -290,6 +294,7 @@ def my_circuit(): nodes = utility.dag_builder.get_nodes() assert len(nodes) == 1 assert next(iter(nodes.values()))["label"] == str(qml.state()) + assert next(iter(nodes.values()))["cluster_id"] == "__base__" @pytest.mark.unit @pytest.mark.parametrize("meas_fn", [qml.expval, qml.var]) @@ -319,6 +324,7 @@ def my_circuit(): nodes = utility.dag_builder.get_nodes() assert len(nodes) == 1 assert next(iter(nodes.values()))["label"] == str(meas_fn(qml.Z(0))) + assert next(iter(nodes.values()))["cluster_id"] == "__base__" @pytest.mark.unit def test_probs_measurement_op(self): @@ -347,6 +353,7 @@ def my_circuit(): nodes = utility.dag_builder.get_nodes() assert len(nodes) == 1 assert next(iter(nodes.values()))["label"] == str(qml.probs()) + assert next(iter(nodes.values()))["cluster_id"] == "__base__" @pytest.mark.unit def test_sample_measurement_op(self): @@ -376,6 +383,7 @@ def my_circuit(): nodes = utility.dag_builder.get_nodes() assert len(nodes) == 1 assert next(iter(nodes.values()))["label"] == str(qml.sample()) + assert next(iter(nodes.values()))["cluster_id"] == "__base__" @pytest.mark.unit def test_projective_measurement_op(self): @@ -404,3 +412,4 @@ def my_circuit(): nodes = utility.dag_builder.get_nodes() assert len(nodes) == 1 assert "MidMeasure" in next(iter(nodes.values()))["label"] + assert next(iter(nodes.values()))["cluster_id"] == "__base__" From 5eaf7df1040b710974e7b862b4097a88d269747a Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:19:48 -0500 Subject: [PATCH 41/75] cleanup --- .../python_interface/visualization/construct_circuit_dag.py | 4 ++-- .../visualization/test_construct_circuit_dag.py | 4 ---- 2 files changed, 2 insertions(+), 6 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index c133f1ffe2..dab7b0d304 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -14,10 +14,10 @@ """Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module.""" -from functools import singledispatch, singledispatchmethod +from functools import singledispatchmethod from typing import Any -from xdsl.dialects import builtin, scf +from xdsl.dialects import builtin from xdsl.ir import Block, Operation, Region from catalyst.python_interface.dialects import quantum diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index b3b69dafc1..e03e8f1480 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -24,10 +24,6 @@ import pennylane as qml from catalyst import measure from catalyst.python_interface.conversion import xdsl_from_qjit -from catalyst.python_interface.dialects.quantum import ( - CustomOp, - QubitType, -) from catalyst.python_interface.visualization.construct_circuit_dag import ( ConstructCircuitDAG, ) From 70f8b80e178462bf4c8c210737aba85558ee2151 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Tue, 25 Nov 2025 09:21:42 -0500 Subject: [PATCH 42/75] Apply suggestion from @andrijapau --- .../python_interface/visualization/construct_circuit_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index dab7b0d304..4b008687b3 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -98,7 +98,7 @@ def _unitary( | quantum.QubitUnitaryOp | quantum.MultiRZOp, ) -> None: - """Generic handler for unitary gates and quantum state preparation operations.""" + """Generic handler for unitary gates.""" qml_op = xdsl_to_qml_op(op) # Build node on graph From 4f3814a3a7ba17e4fc5d350d6c1be08ba5d37aaf Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:26:48 -0500 Subject: [PATCH 43/75] run black --- .../python_interface/visualization/construct_circuit_dag.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 4415603b76..801263c768 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -96,10 +96,7 @@ def _block(self, block: Block) -> None: @_visit.register def _unitary( self, - op: quantum.CustomOp - | quantum.GlobalPhaseOp - | quantum.QubitUnitaryOp - | quantum.MultiRZOp, + op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp | quantum.MultiRZOp, ) -> None: """Generic handler for unitary gates.""" From 91c11a3ae671e37bfc3cda8b78f0f6db748cc413 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:30:11 -0500 Subject: [PATCH 44/75] isort --- .../visualization/test_construct_circuit_dag.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 646f0eb942..4c2430934d 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -19,13 +19,13 @@ pytestmark = pytest.mark.usefixtures("requires_xdsl") +# pylint: disable=wrong-import-position +# This import needs to be after pytest in order to prevent ImportErrors +import pennylane as qml from xdsl.dialects import test from xdsl.dialects.builtin import ModuleOp from xdsl.ir.core import Block, Region -# pylint: disable=wrong-import-position -# This import needs to be after pytest in order to prevent ImportErrors -import pennylane as qml from catalyst import measure from catalyst.python_interface.conversion import xdsl_from_qjit from catalyst.python_interface.visualization.construct_circuit_dag import ( From 41327e0c97f5135af7be9ef73859bd712a52550c Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:39:11 -0500 Subject: [PATCH 45/75] fix import --- .../python_interface/visualization/construct_circuit_dag.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 801263c768..92e375ce45 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -20,6 +20,7 @@ from xdsl.dialects import builtin from xdsl.ir import Block, Operation, Region +from catalyst.python_interface.dialects import quantum from catalyst.python_interface.visualization.dag_builder import DAGBuilder from catalyst.python_interface.visualization.xdsl_conversion import ( xdsl_to_qml_measurement, @@ -96,7 +97,10 @@ def _block(self, block: Block) -> None: @_visit.register def _unitary( self, - op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp | quantum.MultiRZOp, + op: quantum.CustomOp + | quantum.GlobalPhaseOp + | quantum.QubitUnitaryOp + | quantum.MultiRZOp, ) -> None: """Generic handler for unitary gates.""" From bcba0b5bbbbece4bfd22f528c94a05e6bf92e421 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 25 Nov 2025 09:42:50 -0500 Subject: [PATCH 46/75] black --- .../python_interface/visualization/construct_circuit_dag.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 92e375ce45..123a8addf0 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -97,10 +97,7 @@ def _block(self, block: Block) -> None: @_visit.register def _unitary( self, - op: quantum.CustomOp - | quantum.GlobalPhaseOp - | quantum.QubitUnitaryOp - | quantum.MultiRZOp, + op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp | quantum.MultiRZOp, ) -> None: """Generic handler for unitary gates.""" From e9c69d89f923b4be5c5e52c971bfda5393e9f279 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 2 Dec 2025 16:20:18 -0500 Subject: [PATCH 47/75] format --- .../python_interface/visualization/construct_circuit_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index ef43381563..2504779611 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -145,7 +145,7 @@ def _projective_measure_op(self, op: quantum.MeasureOp) -> None: label=str(meas), cluster_id=self._cluster_stack[-1], ) - + # ============= # CONTROL FLOW # ============= From f47b1a5bb63613d290f5e9207f3b771e978d34da Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 2 Dec 2025 16:24:33 -0500 Subject: [PATCH 48/75] fix id to uid --- .../visualization/construct_circuit_dag.py | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 2504779611..4f34e7758c 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -16,15 +16,14 @@ from functools import singledispatchmethod -from xdsl.dialects import builtin, func, scf -from xdsl.ir import Block, Operation, Region, SSAValue - from catalyst.python_interface.dialects import catalyst, quantum from catalyst.python_interface.visualization.dag_builder import DAGBuilder from catalyst.python_interface.visualization.xdsl_conversion import ( xdsl_to_qml_measurement, xdsl_to_qml_op, ) +from xdsl.dialects import builtin, func, scf +from xdsl.ir import Block, Operation, Region, SSAValue class ConstructCircuitDAG: @@ -90,16 +89,19 @@ def _visit_block(self, block: Block) -> None: @_visit_operation.register def _unitary( self, - op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp | quantum.MultiRZOp, + op: quantum.CustomOp + | quantum.GlobalPhaseOp + | quantum.QubitUnitaryOp + | quantum.MultiRZOp, ) -> None: """Generic handler for unitary gates.""" qml_op = xdsl_to_qml_op(op) # Build node on graph self.dag_builder.add_node( - id=f"node_{id(op)}", + uid=f"node_{id(op)}", label=str(qml_op), - cluster_id=self._cluster_stack[-1], + cluster_uid=self._cluster_uid_stack[-1], ) # ===================== @@ -113,9 +115,9 @@ def _state_op(self, op: quantum.StateOp) -> None: meas = xdsl_to_qml_measurement(op) # Build node on graph self.dag_builder.add_node( - id=f"node_{id(op)}", + uid=f"node_{id(op)}", label=str(meas), - cluster_id=self._cluster_stack[-1], + cluster_uid=self._cluster_uid_stack[-1], ) @_visit_operation.register @@ -129,9 +131,9 @@ def _statistical_measurement_ops( meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op)) # Build node on graph self.dag_builder.add_node( - id=f"node_{id(op)}", + uid=f"node_{id(op)}", label=str(meas), - cluster_id=self._cluster_stack[-1], + cluster_uid=self._cluster_uid_stack[-1], ) @_visit_operation.register @@ -141,9 +143,9 @@ def _projective_measure_op(self, op: quantum.MeasureOp) -> None: meas = xdsl_to_qml_measurement(op) # Build node on graph self.dag_builder.add_node( - id=f"node_{id(op)}", + uid=f"node_{id(op)}", label=str(meas), - cluster_id=self._cluster_stack[-1], + cluster_uid=self._cluster_uid_stack[-1], ) # ============= @@ -187,7 +189,9 @@ def _while_op(self, operation: scf.WhileOp) -> None: @_visit_operation.register def _if_op(self, operation: scf.IfOp): """Handles the scf.IfOp operation.""" - flattened_if_op: list[tuple[SSAValue | None, Region]] = _flatten_if_op(operation) + flattened_if_op: list[tuple[SSAValue | None, Region]] = _flatten_if_op( + operation + ) uid = f"cluster_{id(operation)}" self.dag_builder.add_cluster( From 5e83fe4fef603dfd6461b11ae216d83f0c27a5ce Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 2 Dec 2025 16:25:03 -0500 Subject: [PATCH 49/75] fix --- .../visualization/construct_circuit_dag.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 4f34e7758c..ce6cb0b1d4 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -16,14 +16,15 @@ from functools import singledispatchmethod +from xdsl.dialects import builtin, func, scf +from xdsl.ir import Block, Operation, Region, SSAValue + from catalyst.python_interface.dialects import catalyst, quantum from catalyst.python_interface.visualization.dag_builder import DAGBuilder from catalyst.python_interface.visualization.xdsl_conversion import ( xdsl_to_qml_measurement, xdsl_to_qml_op, ) -from xdsl.dialects import builtin, func, scf -from xdsl.ir import Block, Operation, Region, SSAValue class ConstructCircuitDAG: @@ -89,10 +90,7 @@ def _visit_block(self, block: Block) -> None: @_visit_operation.register def _unitary( self, - op: quantum.CustomOp - | quantum.GlobalPhaseOp - | quantum.QubitUnitaryOp - | quantum.MultiRZOp, + op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp | quantum.MultiRZOp, ) -> None: """Generic handler for unitary gates.""" @@ -189,9 +187,7 @@ def _while_op(self, operation: scf.WhileOp) -> None: @_visit_operation.register def _if_op(self, operation: scf.IfOp): """Handles the scf.IfOp operation.""" - flattened_if_op: list[tuple[SSAValue | None, Region]] = _flatten_if_op( - operation - ) + flattened_if_op: list[tuple[SSAValue | None, Region]] = _flatten_if_op(operation) uid = f"cluster_{id(operation)}" self.dag_builder.add_cluster( From d5a0611e27bf40f389b3929f3752b80ad8ef5e34 Mon Sep 17 00:00:00 2001 From: Andrija Paurevic <46359773+andrijapau@users.noreply.github.com> Date: Tue, 2 Dec 2025 16:50:34 -0500 Subject: [PATCH 50/75] Apply suggestion from @andrijapau --- .../python_interface/visualization/construct_circuit_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index ce6cb0b1d4..1f27dae700 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -21,7 +21,7 @@ from catalyst.python_interface.dialects import catalyst, quantum from catalyst.python_interface.visualization.dag_builder import DAGBuilder -from catalyst.python_interface.visualization.xdsl_conversion import ( +from catalyst.python_interface.inspection.xdsl_conversion import ( xdsl_to_qml_measurement, xdsl_to_qml_op, ) From 98bf7ea6b1cfbd9798211ceb06cee982127d58cb Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 2 Dec 2025 17:54:35 -0500 Subject: [PATCH 51/75] clean-up --- .../visualization/construct_circuit_dag.py | 36 +++++++++++++------ 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 1f27dae700..92064d9a11 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -16,15 +16,14 @@ from functools import singledispatchmethod -from xdsl.dialects import builtin, func, scf -from xdsl.ir import Block, Operation, Region, SSAValue - from catalyst.python_interface.dialects import catalyst, quantum -from catalyst.python_interface.visualization.dag_builder import DAGBuilder from catalyst.python_interface.inspection.xdsl_conversion import ( xdsl_to_qml_measurement, xdsl_to_qml_op, ) +from catalyst.python_interface.visualization.dag_builder import DAGBuilder +from xdsl.dialects import builtin, func, scf +from xdsl.ir import Block, Operation, Region, SSAValue class ConstructCircuitDAG: @@ -90,14 +89,20 @@ def _visit_block(self, block: Block) -> None: @_visit_operation.register def _unitary( self, - op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp | quantum.MultiRZOp, + op: quantum.CustomOp + | quantum.GlobalPhaseOp + | quantum.QubitUnitaryOp + | quantum.MultiRZOp, ) -> None: """Generic handler for unitary gates.""" + # Create PennyLane instance qml_op = xdsl_to_qml_op(op) - # Build node on graph + + # Add node to current cluster + node_uid = f"node_{id(op)}" self.dag_builder.add_node( - uid=f"node_{id(op)}", + uid=node_uid, label=str(qml_op), cluster_uid=self._cluster_uid_stack[-1], ) @@ -111,11 +116,14 @@ def _state_op(self, op: quantum.StateOp) -> None: """Handler for the terminal state measurement operation.""" meas = xdsl_to_qml_measurement(op) + node_uid = f"node_{id(op)}" # Build node on graph self.dag_builder.add_node( - uid=f"node_{id(op)}", + uid=node_uid, label=str(meas), cluster_uid=self._cluster_uid_stack[-1], + fillcolor="lightpink", + color="lightpink3", ) @_visit_operation.register @@ -127,11 +135,14 @@ def _statistical_measurement_ops( obs_op = op.obs.owner meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op)) + node_uid = f"node_{id(op)}" # Build node on graph self.dag_builder.add_node( - uid=f"node_{id(op)}", + uid=node_uid, label=str(meas), cluster_uid=self._cluster_uid_stack[-1], + fillcolor="lightpink", + color="lightpink3", ) @_visit_operation.register @@ -139,9 +150,10 @@ def _projective_measure_op(self, op: quantum.MeasureOp) -> None: """Handler for the single-qubit projective measurement operation.""" meas = xdsl_to_qml_measurement(op) + node_uid = f"node_{id(op)}" # Build node on graph self.dag_builder.add_node( - uid=f"node_{id(op)}", + uid=node_uid, label=str(meas), cluster_uid=self._cluster_uid_stack[-1], ) @@ -187,7 +199,9 @@ def _while_op(self, operation: scf.WhileOp) -> None: @_visit_operation.register def _if_op(self, operation: scf.IfOp): """Handles the scf.IfOp operation.""" - flattened_if_op: list[tuple[SSAValue | None, Region]] = _flatten_if_op(operation) + flattened_if_op: list[tuple[SSAValue | None, Region]] = _flatten_if_op( + operation + ) uid = f"cluster_{id(operation)}" self.dag_builder.add_cluster( From 818448c1003728c04956a5837c3133909bd34c05 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 2 Dec 2025 17:55:39 -0500 Subject: [PATCH 52/75] clean up --- .../visualization/construct_circuit_dag.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 92064d9a11..edbfe44ff3 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -115,9 +115,11 @@ def _unitary( def _state_op(self, op: quantum.StateOp) -> None: """Handler for the terminal state measurement operation.""" + # Create PennyLane instance meas = xdsl_to_qml_measurement(op) + + # Add node to current cluster node_uid = f"node_{id(op)}" - # Build node on graph self.dag_builder.add_node( uid=node_uid, label=str(meas), @@ -133,10 +135,12 @@ def _statistical_measurement_ops( ) -> None: """Handler for statistical measurement operations.""" + # Create PennyLane instance obs_op = op.obs.owner meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op)) + + # Add node to current cluster node_uid = f"node_{id(op)}" - # Build node on graph self.dag_builder.add_node( uid=node_uid, label=str(meas), @@ -149,9 +153,11 @@ def _statistical_measurement_ops( def _projective_measure_op(self, op: quantum.MeasureOp) -> None: """Handler for the single-qubit projective measurement operation.""" + # Create PennyLane instance meas = xdsl_to_qml_measurement(op) + + # Add node to current cluster node_uid = f"node_{id(op)}" - # Build node on graph self.dag_builder.add_node( uid=node_uid, label=str(meas), From b924c43f6b04ea14eb86a225405328851793fbde Mon Sep 17 00:00:00 2001 From: andrijapau Date: Tue, 2 Dec 2025 17:56:08 -0500 Subject: [PATCH 53/75] format --- .../visualization/construct_circuit_dag.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index edbfe44ff3..481bd4d662 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -16,14 +16,15 @@ from functools import singledispatchmethod +from xdsl.dialects import builtin, func, scf +from xdsl.ir import Block, Operation, Region, SSAValue + from catalyst.python_interface.dialects import catalyst, quantum from catalyst.python_interface.inspection.xdsl_conversion import ( xdsl_to_qml_measurement, xdsl_to_qml_op, ) from catalyst.python_interface.visualization.dag_builder import DAGBuilder -from xdsl.dialects import builtin, func, scf -from xdsl.ir import Block, Operation, Region, SSAValue class ConstructCircuitDAG: @@ -89,10 +90,7 @@ def _visit_block(self, block: Block) -> None: @_visit_operation.register def _unitary( self, - op: quantum.CustomOp - | quantum.GlobalPhaseOp - | quantum.QubitUnitaryOp - | quantum.MultiRZOp, + op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp | quantum.MultiRZOp, ) -> None: """Generic handler for unitary gates.""" @@ -118,7 +116,7 @@ def _state_op(self, op: quantum.StateOp) -> None: # Create PennyLane instance meas = xdsl_to_qml_measurement(op) - # Add node to current cluster + # Add node to current cluster node_uid = f"node_{id(op)}" self.dag_builder.add_node( uid=node_uid, @@ -139,7 +137,7 @@ def _statistical_measurement_ops( obs_op = op.obs.owner meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op)) - # Add node to current cluster + # Add node to current cluster node_uid = f"node_{id(op)}" self.dag_builder.add_node( uid=node_uid, @@ -156,7 +154,7 @@ def _projective_measure_op(self, op: quantum.MeasureOp) -> None: # Create PennyLane instance meas = xdsl_to_qml_measurement(op) - # Add node to current cluster + # Add node to current cluster node_uid = f"node_{id(op)}" self.dag_builder.add_node( uid=node_uid, @@ -205,9 +203,7 @@ def _while_op(self, operation: scf.WhileOp) -> None: @_visit_operation.register def _if_op(self, operation: scf.IfOp): """Handles the scf.IfOp operation.""" - flattened_if_op: list[tuple[SSAValue | None, Region]] = _flatten_if_op( - operation - ) + flattened_if_op: list[tuple[SSAValue | None, Region]] = _flatten_if_op(operation) uid = f"cluster_{id(operation)}" self.dag_builder.add_cluster( From 9c74b6a3e74ef09f790189b9df1d6b90670fa0f3 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 4 Dec 2025 12:29:09 -0500 Subject: [PATCH 54/75] update id --- .../visualization/construct_circuit_dag.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 4888f4770d..0397a95c84 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -104,12 +104,13 @@ def _unitary( qml_op = xdsl_to_qml_op(op) # Add node to current cluster - node_uid = f"node_{id(op)}" + node_uid = f"node{self._node_uid_counter}" self.dag_builder.add_node( uid=node_uid, label=str(qml_op), cluster_uid=self._cluster_uid_stack[-1], ) + self._node_uid_counter += 1 # ===================== # QUANTUM MEASUREMENTS @@ -123,7 +124,7 @@ def _state_op(self, op: quantum.StateOp) -> None: meas = xdsl_to_qml_measurement(op) # Add node to current cluster - node_uid = f"node_{id(op)}" + node_uid = f"node{self._node_uid_counter}" self.dag_builder.add_node( uid=node_uid, label=str(meas), @@ -131,6 +132,7 @@ def _state_op(self, op: quantum.StateOp) -> None: fillcolor="lightpink", color="lightpink3", ) + self._node_uid_counter += 1 @_visit_operation.register def _statistical_measurement_ops( @@ -144,7 +146,7 @@ def _statistical_measurement_ops( meas = xdsl_to_qml_measurement(op, xdsl_to_qml_measurement(obs_op)) # Add node to current cluster - node_uid = f"node_{id(op)}" + node_uid = f"node{self._node_uid_counter}" self.dag_builder.add_node( uid=node_uid, label=str(meas), @@ -152,6 +154,7 @@ def _statistical_measurement_ops( fillcolor="lightpink", color="lightpink3", ) + self._node_uid_counter += 1 @_visit_operation.register def _projective_measure_op(self, op: quantum.MeasureOp) -> None: @@ -161,12 +164,13 @@ def _projective_measure_op(self, op: quantum.MeasureOp) -> None: meas = xdsl_to_qml_measurement(op) # Add node to current cluster - node_uid = f"node_{id(op)}" + node_uid = f"node{self._node_uid_counter}" self.dag_builder.add_node( uid=node_uid, label=str(meas), cluster_uid=self._cluster_uid_stack[-1], ) + self._node_uid_counter += 1 # ============= # CONTROL FLOW From dd1df44f06549443a0d99e8455fdcd6c7c99cdf1 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 4 Dec 2025 13:04:05 -0500 Subject: [PATCH 55/75] adjust test --- .../visualization/test_construct_circuit_dag.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index d563224e5e..84c34fa1eb 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -283,10 +283,11 @@ def my_workflow(): # Assert null qubit device node is inside my_qnode2 cluster assert graph_clusters["cluster2"]["cluster_label"] == "my_qnode2" - assert graph_nodes["node1"]["parent_cluster_uid"] == "cluster2" + # NOTE: node1 is the qml.H(0) in my_qnode1 + assert graph_nodes["node2"]["parent_cluster_uid"] == "cluster2" # Assert label is as expected - assert graph_nodes["node1"]["label"] == "LightningSimulator" + assert graph_nodes["node2"]["label"] == "LightningSimulator" class TestForOp: From 3175efa4473bfc6bbc37762157207b39bc12c0c1 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 4 Dec 2025 14:59:58 -0500 Subject: [PATCH 56/75] add tests --- .../test_construct_circuit_dag.py | 228 +++++++++++++++++- 1 file changed, 227 insertions(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 3a41b75f78..15cdad5c8c 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -539,5 +539,231 @@ def my_workflow(x, y): assert clusters["cluster6"]["parent_cluster_uid"] == "cluster4" # Check nested if / else is within the first if cluster - assert clusters["cluster7"]["node_label"] == "else" assert clusters["cluster7"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster7"]["node_label"] == "else" + + +class TestCreateStaticOperatorNodes: + """Tests that operators with static parameters can be created and visualized as nodes.""" + + @pytest.mark.unit + @pytest.mark.parametrize("op", [qml.H(0), qml.X(0), qml.SWAP([0, 1])]) + def test_custom_op(self, op): + """Tests that the CustomOp operation node can be created and visualized.""" + + # Build module with only a CustomOp + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.apply(op) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == str(op) + + @pytest.mark.unit + @pytest.mark.parametrize( + "op", + [ + qml.GlobalPhase(0.5), + qml.GlobalPhase(0.5, wires=0), + qml.GlobalPhase(0.5, wires=[0, 1]), + ], + ) + def test_global_phase_op(self, op): + """Test that GlobalPhase can be handled.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.apply(op) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == str(op) + + @pytest.mark.unit + def test_qubit_unitary_op(self): + """Test that QubitUnitary operations can be handled.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.QubitUnitary([[0, 1], [1, 0]], wires=0) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == str(qml.QubitUnitary([[0, 1], [1, 0]], wires=0)) + + @pytest.mark.unit + def test_multi_rz_op(self): + """Test that MultiRZ operations can be handled.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + qml.MultiRZ(0.5, wires=[0]) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == str(qml.MultiRZ(0.5, wires=[0])) + + +class TestCreateStaticMeasurementNodes: + """Tests that measurements with static parameters can be created and visualized as nodes.""" + + @pytest.mark.unit + def test_state_op(self): + """Test that qml.state can be handled.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return qml.state() + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == str(qml.state()) + + @pytest.mark.unit + @pytest.mark.parametrize("meas_fn", [qml.expval, qml.var]) + def test_expval_var_measurement_op(self, meas_fn): + """Test that statistical measurement operators can be captured as nodes.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return meas_fn(qml.Z(0)) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + # Ensure DAG only has one node + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == str(meas_fn(qml.Z(0))) + + @pytest.mark.unit + def test_probs_measurement_op(self): + """Tests that the probs measurement function can be captured as a node.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + return qml.probs() + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == str(qml.probs()) + + @pytest.mark.unit + def test_sample_measurement_op(self): + """Tests that the sample measurement function can be captured as a node.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.set_shots(10) + @qml.qnode(dev) + def my_circuit(): + return qml.sample() + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == str(qml.sample()) + + @pytest.mark.unit + def test_projective_measurement_op(self): + """Test that projective measurements can be captured as nodes.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + measure(0) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + nodes = utility.dag_builder.get_nodes() + assert len(nodes) == 2 # Device node + operator + + assert "MidMeasure" in nodes["node1"]["label"] From 194b3579303d4703bdd437a45d2225e7ed18853e Mon Sep 17 00:00:00 2001 From: andrijapau Date: Thu, 4 Dec 2025 15:40:18 -0500 Subject: [PATCH 57/75] fix get_nodes -> nodes --- .../test_construct_circuit_dag.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 15cdad5c8c..5438f9604d 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -567,7 +567,7 @@ def my_circuit(): utility.construct(module) # Ensure DAG only has one node - nodes = utility.dag_builder.get_nodes() + nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator assert nodes["node1"]["label"] == str(op) @@ -599,7 +599,7 @@ def my_circuit(): utility.construct(module) # Ensure DAG only has one node - nodes = utility.dag_builder.get_nodes() + nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator assert nodes["node1"]["label"] == str(op) @@ -622,7 +622,7 @@ def my_circuit(): utility.construct(module) # Ensure DAG only has one node - nodes = utility.dag_builder.get_nodes() + nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator assert nodes["node1"]["label"] == str(qml.QubitUnitary([[0, 1], [1, 0]], wires=0)) @@ -645,7 +645,7 @@ def my_circuit(): utility.construct(module) # Ensure DAG only has one node - nodes = utility.dag_builder.get_nodes() + nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator assert nodes["node1"]["label"] == str(qml.MultiRZ(0.5, wires=[0])) @@ -672,7 +672,7 @@ def my_circuit(): utility.construct(module) # Ensure DAG only has one node - nodes = utility.dag_builder.get_nodes() + nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator assert nodes["node1"]["label"] == str(qml.state()) @@ -696,7 +696,7 @@ def my_circuit(): utility.construct(module) # Ensure DAG only has one node - nodes = utility.dag_builder.get_nodes() + nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator assert nodes["node1"]["label"] == str(meas_fn(qml.Z(0))) @@ -718,7 +718,7 @@ def my_circuit(): utility = ConstructCircuitDAG(FakeDAGBuilder()) utility.construct(module) - nodes = utility.dag_builder.get_nodes() + nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator assert nodes["node1"]["label"] == str(qml.probs()) @@ -741,7 +741,7 @@ def my_circuit(): utility = ConstructCircuitDAG(FakeDAGBuilder()) utility.construct(module) - nodes = utility.dag_builder.get_nodes() + nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator assert nodes["node1"]["label"] == str(qml.sample()) @@ -763,7 +763,7 @@ def my_circuit(): utility = ConstructCircuitDAG(FakeDAGBuilder()) utility.construct(module) - nodes = utility.dag_builder.get_nodes() + nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator assert "MidMeasure" in nodes["node1"]["label"] From fac615090dcdd1a05761f4174924a9ae8c6712f4 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 09:11:07 -0500 Subject: [PATCH 58/75] fix measure --- .../visualization/test_construct_circuit_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index eae06e0420..146465d177 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -755,7 +755,7 @@ def test_projective_measurement_op(self): @qml.qjit(autograph=True, target="mlir") @qml.qnode(dev) def my_circuit(): - measure(0) + qml.measure(0) module = my_circuit() From 9bdf8da73ae92f606725ace9f2759213b6643cd8 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 09:19:32 -0500 Subject: [PATCH 59/75] fix globalphase test --- .../visualization/test_construct_circuit_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 146465d177..4a554ad5d2 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -602,7 +602,7 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == str(op) + assert nodes["node1"]["label"] == "GlobalPhase(0.5, wires=[])" @pytest.mark.unit def test_qubit_unitary_op(self): From b9479b194fe7da71c414e433be17e3472d105022 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 09:21:28 -0500 Subject: [PATCH 60/75] fix sample test --- .../visualization/test_construct_circuit_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 4a554ad5d2..a2261fb49c 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -744,7 +744,7 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == str(qml.sample()) + assert nodes["node1"]["label"] == "sample(wires=[])" @pytest.mark.unit def test_projective_measurement_op(self): From 60e83c4730c80cdb64633be7b92f1685c0f30e47 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 09:46:38 -0500 Subject: [PATCH 61/75] fix sample --- .../visualization/construct_circuit_dag.py | 41 ++++++++++++++++--- .../test_construct_circuit_dag.py | 14 +++++-- 2 files changed, 46 insertions(+), 9 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 7deac56ba5..b0ce7c5c3c 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -16,15 +16,14 @@ from functools import singledispatchmethod -from xdsl.dialects import builtin, func, scf -from xdsl.ir import Block, Operation, Region, SSAValue - from catalyst.python_interface.dialects import catalyst, quantum from catalyst.python_interface.inspection.xdsl_conversion import ( xdsl_to_qml_measurement, xdsl_to_qml_op, ) from catalyst.python_interface.visualization.dag_builder import DAGBuilder +from xdsl.dialects import builtin, func, scf +from xdsl.ir import Block, Operation, Region, SSAValue class ConstructCircuitDAG: @@ -96,7 +95,10 @@ def _visit_block(self, block: Block) -> None: @_visit_operation.register def _unitary( self, - op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp | quantum.MultiRZOp, + op: quantum.CustomOp + | quantum.GlobalPhaseOp + | quantum.QubitUnitaryOp + | quantum.MultiRZOp, ) -> None: """Generic handler for unitary gates.""" @@ -156,6 +158,29 @@ def _statistical_measurement_ops( ) self._node_uid_counter += 1 + @_visit_operation.register + def _sample_op( + self, + op: quantum.SampleOp, + ) -> None: + """Handler for sample operations.""" + + # Create PennyLane instance + obs_op = op.obs.owner + wires = xdsl_to_qml_measurement(obs_op) + meas = xdsl_to_qml_measurement(op, wires=None if wires == [] else wires) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=str(meas), + cluster_uid=self._cluster_uid_stack[-1], + fillcolor="lightpink", + color="lightpink3", + ) + self._node_uid_counter += 1 + @_visit_operation.register def _projective_measure_op(self, op: quantum.MeasureOp) -> None: """Handler for the single-qubit projective measurement operation.""" @@ -216,7 +241,9 @@ def _while_op(self, operation: scf.WhileOp) -> None: @_visit_operation.register def _if_op(self, operation: scf.IfOp): """Handles the scf.IfOp operation.""" - flattened_if_op: list[tuple[SSAValue | None, Region]] = _flatten_if_op(operation) + flattened_if_op: list[tuple[SSAValue | None, Region]] = _flatten_if_op( + operation + ) uid = f"cluster{self._cluster_uid_counter}" self.dag_builder.add_cluster( @@ -295,7 +322,9 @@ def _func_op(self, operation: func.FuncOp) -> None: label = "qjit" uid = f"cluster{self._cluster_uid_counter}" - parent_cluster_uid = None if self._cluster_uid_stack == [] else self._cluster_uid_stack[-1] + parent_cluster_uid = ( + None if self._cluster_uid_stack == [] else self._cluster_uid_stack[-1] + ) self.dag_builder.add_cluster( uid, label=label, diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index a2261fb49c..4ab45a75e0 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -724,7 +724,15 @@ def my_circuit(): assert nodes["node1"]["label"] == str(qml.probs()) @pytest.mark.unit - def test_sample_measurement_op(self): + @pytest.mark.parametrize( + "op", + [ + qml.sample(), + qml.sample(wires=0), + qml.sample(wires=[0, 1]), + ], + ) + def test_sample_measurement_op(self, op): """Tests that the sample measurement function can be captured as a node.""" dev = qml.device("null.qubit", wires=1) @@ -733,7 +741,7 @@ def test_sample_measurement_op(self): @qml.set_shots(10) @qml.qnode(dev) def my_circuit(): - return qml.sample() + return op module = my_circuit() @@ -744,7 +752,7 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == "sample(wires=[])" + assert nodes["node1"]["label"] == str(op) @pytest.mark.unit def test_projective_measurement_op(self): From ebd108e35bf385230c7dd9e680f95da0cd021483 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 09:47:05 -0500 Subject: [PATCH 62/75] format --- .../visualization/construct_circuit_dag.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index b0ce7c5c3c..fad9a4fef9 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -16,14 +16,15 @@ from functools import singledispatchmethod +from xdsl.dialects import builtin, func, scf +from xdsl.ir import Block, Operation, Region, SSAValue + from catalyst.python_interface.dialects import catalyst, quantum from catalyst.python_interface.inspection.xdsl_conversion import ( xdsl_to_qml_measurement, xdsl_to_qml_op, ) from catalyst.python_interface.visualization.dag_builder import DAGBuilder -from xdsl.dialects import builtin, func, scf -from xdsl.ir import Block, Operation, Region, SSAValue class ConstructCircuitDAG: @@ -95,10 +96,7 @@ def _visit_block(self, block: Block) -> None: @_visit_operation.register def _unitary( self, - op: quantum.CustomOp - | quantum.GlobalPhaseOp - | quantum.QubitUnitaryOp - | quantum.MultiRZOp, + op: quantum.CustomOp | quantum.GlobalPhaseOp | quantum.QubitUnitaryOp | quantum.MultiRZOp, ) -> None: """Generic handler for unitary gates.""" @@ -241,9 +239,7 @@ def _while_op(self, operation: scf.WhileOp) -> None: @_visit_operation.register def _if_op(self, operation: scf.IfOp): """Handles the scf.IfOp operation.""" - flattened_if_op: list[tuple[SSAValue | None, Region]] = _flatten_if_op( - operation - ) + flattened_if_op: list[tuple[SSAValue | None, Region]] = _flatten_if_op(operation) uid = f"cluster{self._cluster_uid_counter}" self.dag_builder.add_cluster( @@ -322,9 +318,7 @@ def _func_op(self, operation: func.FuncOp) -> None: label = "qjit" uid = f"cluster{self._cluster_uid_counter}" - parent_cluster_uid = ( - None if self._cluster_uid_stack == [] else self._cluster_uid_stack[-1] - ) + parent_cluster_uid = None if self._cluster_uid_stack == [] else self._cluster_uid_stack[-1] self.dag_builder.add_cluster( uid, label=label, From bfd1d103a1044082f50405828fb81aca57df1a6b Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 09:50:21 -0500 Subject: [PATCH 63/75] add dev comment --- .../python_interface/visualization/construct_circuit_dag.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index fad9a4fef9..4b03ba20ab 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -137,7 +137,7 @@ def _state_op(self, op: quantum.StateOp) -> None: @_visit_operation.register def _statistical_measurement_ops( self, - op: quantum.ExpvalOp | quantum.VarianceOp | quantum.ProbsOp | quantum.SampleOp, + op: quantum.ExpvalOp | quantum.VarianceOp | quantum.ProbsOp, ) -> None: """Handler for statistical measurement operations.""" @@ -165,6 +165,9 @@ def _sample_op( # Create PennyLane instance obs_op = op.obs.owner + + # TODO: This doesn't logically make sense, but quantum.compbasis + # is obs_op and function below just pulls out the static wires wires = xdsl_to_qml_measurement(obs_op) meas = xdsl_to_qml_measurement(op, wires=None if wires == [] else wires) From aec9f83aae22e8378836ea137dc15ce82fada766 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 10:23:11 -0500 Subject: [PATCH 64/75] fix code --- .../python_interface/visualization/construct_circuit_dag.py | 4 ++-- .../visualization/test_construct_circuit_dag.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 4b03ba20ab..1f3f5cb75f 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -137,7 +137,7 @@ def _state_op(self, op: quantum.StateOp) -> None: @_visit_operation.register def _statistical_measurement_ops( self, - op: quantum.ExpvalOp | quantum.VarianceOp | quantum.ProbsOp, + op: quantum.ExpvalOp | quantum.VarianceOp, ) -> None: """Handler for statistical measurement operations.""" @@ -159,7 +159,7 @@ def _statistical_measurement_ops( @_visit_operation.register def _sample_op( self, - op: quantum.SampleOp, + op: quantum.SampleOp | quantum.ProbsOp, ) -> None: """Handler for sample operations.""" diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 4ab45a75e0..509163bb51 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -27,6 +27,7 @@ from xdsl.dialects.builtin import ModuleOp from xdsl.ir.core import Block, Region +from catalyst import measure from catalyst.python_interface.conversion import xdsl_from_qjit from catalyst.python_interface.visualization.construct_circuit_dag import ( ConstructCircuitDAG, @@ -763,7 +764,7 @@ def test_projective_measurement_op(self): @qml.qjit(autograph=True, target="mlir") @qml.qnode(dev) def my_circuit(): - qml.measure(0) + measure(0) module = my_circuit() From c36a43ebab41c7f5654a83455e277151e10a617d Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 10:25:03 -0500 Subject: [PATCH 65/75] fix naming --- .../python_interface/visualization/construct_circuit_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 1f3f5cb75f..3f65659898 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -157,7 +157,7 @@ def _statistical_measurement_ops( self._node_uid_counter += 1 @_visit_operation.register - def _sample_op( + def _visit_sample_and_probs_ops( self, op: quantum.SampleOp | quantum.ProbsOp, ) -> None: From 6c1d825d7cba21f2f0c49c04047aaaf72edde24f Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 10:31:32 -0500 Subject: [PATCH 66/75] add expected error messages --- .../test_construct_circuit_dag.py | 37 ++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 509163bb51..ae550ea3f2 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -733,7 +733,7 @@ def my_circuit(): qml.sample(wires=[0, 1]), ], ) - def test_sample_measurement_op(self, op): + def test_valid_sample_measurement_op(self, op): """Tests that the sample measurement function can be captured as a node.""" dev = qml.device("null.qubit", wires=1) @@ -755,6 +755,41 @@ def my_circuit(): assert nodes["node1"]["label"] == str(op) + @pytest.mark.unit + @pytest.mark.parametrize( + "op, error_type, error_message", + [ + [ + qml.sample(op=qml.Z(0)), + CompileError, + "Only expectation value and variance measurements can accept observables with Catalyst", + ], + [ + qml.sample(op=qml.Z(0), wires=[0]), + ValueError, + "Cannot specify the wires to sample if an observable is provided. The wires to sample will be determined directly from the observable.", + ], + ], + ) + def test_invalid_sample_measurement_op(self, op, error_type, error_message): + """Makes sure that invalid sample operations hold true.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.set_shots(10) + @qml.qnode(dev) + def my_circuit(): + return op + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + with pytest.raises(error_type, match=error_message): + utility.construct(module) + @pytest.mark.unit def test_projective_measurement_op(self): """Test that projective measurements can be captured as nodes.""" From 7b56f898fbb5eb8827938abb294afcd602aa8acc Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 10:32:34 -0500 Subject: [PATCH 67/75] fix test for probs --- .../visualization/test_construct_circuit_dag.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index ae550ea3f2..06df24473a 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -703,7 +703,15 @@ def my_circuit(): assert nodes["node1"]["label"] == str(meas_fn(qml.Z(0))) @pytest.mark.unit - def test_probs_measurement_op(self): + @pytest.mark.parametrize( + "op", + [ + qml.probs(), + qml.probs(wires=0), + qml.probs(wires=[0, 1]), + ], + ) + def test_probs_measurement_op(self, op): """Tests that the probs measurement function can be captured as a node.""" dev = qml.device("null.qubit", wires=1) @@ -711,7 +719,7 @@ def test_probs_measurement_op(self): @qml.qjit(autograph=True, target="mlir") @qml.qnode(dev) def my_circuit(): - return qml.probs() + return op module = my_circuit() @@ -722,7 +730,7 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == str(qml.probs()) + assert nodes["node1"]["label"] == str(op) @pytest.mark.unit @pytest.mark.parametrize( From 4f8335589d1d99b8ab5c9a5e23ab4bc28db87814 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 10:32:54 -0500 Subject: [PATCH 68/75] format --- .../visualization/test_construct_circuit_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 06df24473a..f3f9f289a8 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -730,7 +730,7 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == str(op) + assert nodes["node1"]["label"] == str(op) @pytest.mark.unit @pytest.mark.parametrize( From 9f6a1311521f15d19d43babc2fd90313c2208f6b Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 10:56:08 -0500 Subject: [PATCH 69/75] add compileerror --- .../python_interface/visualization/test_construct_circuit_dag.py | 1 + 1 file changed, 1 insertion(+) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index f3f9f289a8..bb70d4bfdd 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -33,6 +33,7 @@ ConstructCircuitDAG, ) from catalyst.python_interface.visualization.dag_builder import DAGBuilder +from catalyst.utils.exceptions import CompileError class FakeDAGBuilder(DAGBuilder): From 9855985257b0da3eca0f7717f1cff816cb45bfb3 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 11:42:06 -0500 Subject: [PATCH 70/75] update --- .../test_construct_circuit_dag.py | 35 ------------------- 1 file changed, 35 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index bb70d4bfdd..493cfffbe1 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -764,41 +764,6 @@ def my_circuit(): assert nodes["node1"]["label"] == str(op) - @pytest.mark.unit - @pytest.mark.parametrize( - "op, error_type, error_message", - [ - [ - qml.sample(op=qml.Z(0)), - CompileError, - "Only expectation value and variance measurements can accept observables with Catalyst", - ], - [ - qml.sample(op=qml.Z(0), wires=[0]), - ValueError, - "Cannot specify the wires to sample if an observable is provided. The wires to sample will be determined directly from the observable.", - ], - ], - ) - def test_invalid_sample_measurement_op(self, op, error_type, error_message): - """Makes sure that invalid sample operations hold true.""" - - dev = qml.device("null.qubit", wires=1) - - @xdsl_from_qjit - @qml.qjit(autograph=True, target="mlir") - @qml.set_shots(10) - @qml.qnode(dev) - def my_circuit(): - return op - - module = my_circuit() - - # Construct DAG - utility = ConstructCircuitDAG(FakeDAGBuilder()) - with pytest.raises(error_type, match=error_message): - utility.construct(module) - @pytest.mark.unit def test_projective_measurement_op(self): """Test that projective measurements can be captured as nodes.""" From 5e3a6d2fb9fe1d47012cb8b36121f3288a17a5b8 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 16:16:46 -0500 Subject: [PATCH 71/75] re-work node labels to be records with ports --- .../visualization/construct_circuit_dag.py | 39 +++++++++-- .../test_construct_circuit_dag.py | 69 +++++++++++++++---- 2 files changed, 90 insertions(+), 18 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index bc6ec4dada..09dcee720f 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -14,8 +14,10 @@ """Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module.""" -from functools import singledispatchmethod +from functools import singledispatch, singledispatchmethod +from pennylane.measurements import MeasurementProcess +from pennylane.operation import Operator from xdsl.dialects import builtin, func, scf from xdsl.ir import Block, Operation, Region, SSAValue @@ -107,8 +109,10 @@ def _unitary( node_uid = f"node{self._node_uid_counter}" self.dag_builder.add_node( uid=node_uid, - label=str(qml_op), + label=get_label(qml_op), cluster_uid=self._cluster_uid_stack[-1], + # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) + shape="record", ) self._node_uid_counter += 1 @@ -127,7 +131,7 @@ def _state_op(self, op: quantum.StateOp) -> None: node_uid = f"node{self._node_uid_counter}" self.dag_builder.add_node( uid=node_uid, - label=str(meas), + label=get_label(meas), cluster_uid=self._cluster_uid_stack[-1], fillcolor="lightpink", color="lightpink3", @@ -149,7 +153,7 @@ def _statistical_measurement_ops( node_uid = f"node{self._node_uid_counter}" self.dag_builder.add_node( uid=node_uid, - label=str(meas), + label=get_label(meas), cluster_uid=self._cluster_uid_stack[-1], fillcolor="lightpink", color="lightpink3", @@ -175,7 +179,7 @@ def _visit_sample_and_probs_ops( node_uid = f"node{self._node_uid_counter}" self.dag_builder.add_node( uid=node_uid, - label=str(meas), + label=get_label(meas), cluster_uid=self._cluster_uid_stack[-1], fillcolor="lightpink", color="lightpink3", @@ -193,7 +197,7 @@ def _projective_measure_op(self, op: quantum.MeasureOp) -> None: node_uid = f"node{self._node_uid_counter}" self.dag_builder.add_node( uid=node_uid, - label=str(meas), + label=get_label(meas), cluster_uid=self._cluster_uid_stack[-1], ) self._node_uid_counter += 1 @@ -367,3 +371,26 @@ def _flatten_if_op(op: scf.IfOp) -> list[tuple[SSAValue | None, Region]]: # with no SSAValue flattened_op.extend([(None, else_region)]) return flattened_op + + +@singledispatch +def get_label(op: Operator | MeasurementProcess) -> str: + """Gets the appropriate label for a PennyLane object.""" + return str(op) + + +@get_label.register +def _operator(op: Operator) -> str: + """Returns the appropriate label for an xDSL operation.""" + wires = list(op.wires.labels) + if wires == []: + wires_str = "all" + else: + wires_str = f"[{', '.join(map(str, wires))}]" + return f" {op.name}| {wires_str}" + + +@get_label.register +def _mp(mp: MeasurementProcess) -> str: + """Returns the appropriate label for an xDSL operation.""" + return str(mp) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 50bb1fe3d4..48868a1d26 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -22,17 +22,17 @@ # pylint: disable=wrong-import-position # This import needs to be after pytest in order to prevent ImportErrors import pennylane as qml -from xdsl.dialects import test -from xdsl.dialects.builtin import ModuleOp -from xdsl.ir.core import Block, Region - from catalyst import measure from catalyst.python_interface.conversion import xdsl_from_qjit from catalyst.python_interface.visualization.construct_circuit_dag import ( ConstructCircuitDAG, + get_label, ) from catalyst.python_interface.visualization.dag_builder import DAGBuilder from catalyst.utils.exceptions import CompileError +from xdsl.dialects import test +from xdsl.dialects.builtin import ModuleOp +from xdsl.ir.core import Block, Region class FakeDAGBuilder(DAGBuilder): @@ -545,6 +545,48 @@ def my_workflow(x, y): assert clusters["cluster7"]["node_label"] == "else" +class TestGetLabel: + """Tests the get_label utility.""" + + @pytest.mark.unit + @pytest.mark.parametrize( + "op", [qml.H(0), qml.QubitUnitary([[0, 1], [1, 0]], 0), qml.SWAP([0, 1])] + ) + def test_standard_operator(self, op): + """Tests against an operator instance.""" + wires = list(op.wires.labels) + if wires == []: + wires_str = "all" + else: + wires_str = f"[{', '.join(map(str, wires))}]" + + assert get_label(op) == f" {op.name}| {wires_str}" + + def test_global_phase_operator(self, op): + """Tests against a GlobalPhase operator instance.""" + assert get_label(qml.GlobalPhase(0.5)) == f" {op.name}| all" + + @pytest.mark.unit + @pytest.mark.parametrize( + "meas", + [ + qml.state(), + qml.expval(qml.Z(0)), + qml.var(qml.Z(0)), + qml.probs(), + qml.probs(wires=0), + qml.probs(wires=[0, 1]), + qml.sample(), + qml.sample(wires=0), + qml.sample(wires=[0, 1]), + ], + ) + def test_standard_measurement(self, meas): + """Tests against an operator instance.""" + + assert get_label(meas) == str(meas) + + class TestCreateStaticOperatorNodes: """Tests that operators with static parameters can be created and visualized as nodes.""" @@ -572,7 +614,8 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == str(op) + # Make sure label has relevant info + assert nodes["node1"]["label"] == get_label(op) @pytest.mark.unit @pytest.mark.parametrize( @@ -604,7 +647,7 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == "GlobalPhase(0.5, wires=[])" + assert nodes["node1"]["label"] == get_label(op) @pytest.mark.unit def test_qubit_unitary_op(self): @@ -627,7 +670,9 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == str(qml.QubitUnitary([[0, 1], [1, 0]], wires=0)) + assert nodes["node1"]["label"] == get_label( + qml.QubitUnitary([[0, 1], [1, 0]], wires=0) + ) @pytest.mark.unit def test_multi_rz_op(self): @@ -650,7 +695,7 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == str(qml.MultiRZ(0.5, wires=[0])) + assert nodes["node1"]["label"] == get_label(qml.MultiRZ(0.5, wires=[0])) class TestCreateStaticMeasurementNodes: @@ -677,7 +722,7 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == str(qml.state()) + assert nodes["node1"]["label"] == get_label(qml.state()) @pytest.mark.unit @pytest.mark.parametrize("meas_fn", [qml.expval, qml.var]) @@ -701,7 +746,7 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == str(meas_fn(qml.Z(0))) + assert nodes["node1"]["label"] == get_label(meas_fn(qml.Z(0))) @pytest.mark.unit @pytest.mark.parametrize( @@ -731,7 +776,7 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == str(op) + assert nodes["node1"]["label"] == get_label(op) @pytest.mark.unit @pytest.mark.parametrize( @@ -762,7 +807,7 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == str(op) + assert nodes["node1"]["label"] == get_label(op) @pytest.mark.unit def test_projective_measurement_op(self): From a25f7b40e702d2cf7249af462241e4fabcfc3d32 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 16:19:02 -0500 Subject: [PATCH 72/75] format --- .../visualization/test_construct_circuit_dag.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 48868a1d26..ed4e06f2e0 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -22,6 +22,10 @@ # pylint: disable=wrong-import-position # This import needs to be after pytest in order to prevent ImportErrors import pennylane as qml +from xdsl.dialects import test +from xdsl.dialects.builtin import ModuleOp +from xdsl.ir.core import Block, Region + from catalyst import measure from catalyst.python_interface.conversion import xdsl_from_qjit from catalyst.python_interface.visualization.construct_circuit_dag import ( @@ -30,9 +34,6 @@ ) from catalyst.python_interface.visualization.dag_builder import DAGBuilder from catalyst.utils.exceptions import CompileError -from xdsl.dialects import test -from xdsl.dialects.builtin import ModuleOp -from xdsl.ir.core import Block, Region class FakeDAGBuilder(DAGBuilder): @@ -670,9 +671,7 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == get_label( - qml.QubitUnitary([[0, 1], [1, 0]], wires=0) - ) + assert nodes["node1"]["label"] == get_label(qml.QubitUnitary([[0, 1], [1, 0]], wires=0)) @pytest.mark.unit def test_multi_rz_op(self): From 02d4a55f63f80aada3f208fb5f684dfa2cc48a15 Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 16:42:11 -0500 Subject: [PATCH 73/75] fix tests --- .../visualization/test_construct_circuit_dag.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index ed4e06f2e0..12babe548d 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -563,9 +563,9 @@ def test_standard_operator(self, op): assert get_label(op) == f" {op.name}| {wires_str}" - def test_global_phase_operator(self, op): + def test_global_phase_operator(self): """Tests against a GlobalPhase operator instance.""" - assert get_label(qml.GlobalPhase(0.5)) == f" {op.name}| all" + assert get_label(qml.GlobalPhase(0.5)) == f" GlobalPhase| all" @pytest.mark.unit @pytest.mark.parametrize( @@ -648,7 +648,8 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == get_label(op) + # Compiler throws out the wires and they get converted to wires=[] no matter what + assert nodes["node1"]["label"] == get_label(qml.GlobalPhase(0.5)) @pytest.mark.unit def test_qubit_unitary_op(self): From 7237a3fc9522ec953ce8a5e7bf0d72c283add37c Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 18:05:33 -0500 Subject: [PATCH 74/75] move mid measure to an operator --- .../visualization/construct_circuit_dag.py | 34 +++++++------- .../test_construct_circuit_dag.py | 44 +++++++++---------- 2 files changed, 40 insertions(+), 38 deletions(-) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index 09dcee720f..2311b5e01d 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -116,6 +116,24 @@ def _unitary( ) self._node_uid_counter += 1 + @_visit_operation.register + def _projective_measure_op(self, op: quantum.MeasureOp) -> None: + """Handler for the single-qubit projective measurement operation.""" + + # Create PennyLane instance + meas = xdsl_to_qml_measurement(op) + + # Add node to current cluster + node_uid = f"node{self._node_uid_counter}" + self.dag_builder.add_node( + uid=node_uid, + label=get_label(meas), + cluster_uid=self._cluster_uid_stack[-1], + # NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record) + shape="record", + ) + self._node_uid_counter += 1 + # ===================== # QUANTUM MEASUREMENTS # ===================== @@ -186,22 +204,6 @@ def _visit_sample_and_probs_ops( ) self._node_uid_counter += 1 - @_visit_operation.register - def _projective_measure_op(self, op: quantum.MeasureOp) -> None: - """Handler for the single-qubit projective measurement operation.""" - - # Create PennyLane instance - meas = xdsl_to_qml_measurement(op) - - # Add node to current cluster - node_uid = f"node{self._node_uid_counter}" - self.dag_builder.add_node( - uid=node_uid, - label=get_label(meas), - cluster_uid=self._cluster_uid_stack[-1], - ) - self._node_uid_counter += 1 - # ============= # CONTROL FLOW # ============= diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index 12babe548d..b6c77fb3d7 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -697,6 +697,28 @@ def my_circuit(): assert nodes["node1"]["label"] == get_label(qml.MultiRZ(0.5, wires=[0])) + @pytest.mark.unit + def test_projective_measurement_op(self): + """Test that projective measurements can be captured as nodes.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_circuit(): + measure(0) + + module = my_circuit() + + # Construct DAG + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + nodes = utility.dag_builder.nodes + assert len(nodes) == 2 # Device node + operator + + assert nodes["node1"]["label"] == f" MidMeasure| [0]" + class TestCreateStaticMeasurementNodes: """Tests that measurements with static parameters can be created and visualized as nodes.""" @@ -808,25 +830,3 @@ def my_circuit(): assert len(nodes) == 2 # Device node + operator assert nodes["node1"]["label"] == get_label(op) - - @pytest.mark.unit - def test_projective_measurement_op(self): - """Test that projective measurements can be captured as nodes.""" - dev = qml.device("null.qubit", wires=1) - - @xdsl_from_qjit - @qml.qjit(autograph=True, target="mlir") - @qml.qnode(dev) - def my_circuit(): - measure(0) - - module = my_circuit() - - # Construct DAG - utility = ConstructCircuitDAG(FakeDAGBuilder()) - utility.construct(module) - - nodes = utility.dag_builder.nodes - assert len(nodes) == 2 # Device node + operator - - assert "MidMeasure" in nodes["node1"]["label"] From 6410aaed01bf9ed3ca48ae1622dcbdf7b727305a Mon Sep 17 00:00:00 2001 From: andrijapau Date: Fri, 5 Dec 2025 19:10:23 -0500 Subject: [PATCH 75/75] fix test --- .../visualization/test_construct_circuit_dag.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index b6c77fb3d7..0e90d66fdb 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -717,7 +717,7 @@ def my_circuit(): nodes = utility.dag_builder.nodes assert len(nodes) == 2 # Device node + operator - assert nodes["node1"]["label"] == f" MidMeasure| [0]" + assert nodes["node1"]["label"] == f" MidMeasureMP| [0]" class TestCreateStaticMeasurementNodes: