Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
25aa04b
cl
andrijapau Dec 2, 2025
d613c1b
Apply suggestion from @andrijapau
andrijapau Dec 2, 2025
73c8a0f
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 2, 2025
98c9175
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 2, 2025
3d80d1a
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 2, 2025
0cd2f7d
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 2, 2025
858f40e
add good static connectivity
andrijapau Dec 2, 2025
3d0af6a
rename
andrijapau Dec 2, 2025
046196f
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 2, 2025
08082d9
clean-up
andrijapau Dec 2, 2025
f23b550
refactor
andrijapau Dec 2, 2025
437c7d8
make sure data is not carried over between qnodes
andrijapau Dec 3, 2025
39d6ea6
format
andrijapau Dec 3, 2025
9c907c4
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 3, 2025
68c6ec0
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 3, 2025
f25dd2a
Apply suggestion from @andrijapau
andrijapau Dec 4, 2025
6d46b94
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 4, 2025
56c70c9
update logic
andrijapau Dec 4, 2025
a544727
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 4, 2025
336d2e5
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 4, 2025
d66ff68
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 4, 2025
a9125f5
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 4, 2025
aea60d0
add test skeeltons
andrijapau Dec 4, 2025
75bdd32
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 4, 2025
6d3c256
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 5, 2025
df222e2
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 5, 2025
e38737c
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 5, 2025
aa7dd72
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 5, 2025
4dddd74
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 5, 2025
1712a27
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 5, 2025
67759dd
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 5, 2025
4f0a8f4
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 5, 2025
5d9f21c
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 5, 2025
19b326a
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 5, 2025
a9a89a9
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 5, 2025
b58602d
format
andrijapau Dec 5, 2025
68c2488
whoops
andrijapau Dec 5, 2025
d28eb63
Merge branch 'feature/visualize-nodes' into feature/connect-nodes
andrijapau Dec 5, 2025
fc3ea2c
add connectivity to MCMs
andrijapau Dec 5, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
[(#2231)](https://github.com/PennyLaneAI/catalyst/pull/2231)
[(#2234)](https://github.com/PennyLaneAI/catalyst/pull/2234)
[(#2118)](https://github.com/PennyLaneAI/catalyst/pull/2218)
[(#2260)](https://github.com/PennyLaneAI/catalyst/pull/2260)

* Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator.
[(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Contains the ConstructCircuitDAG tool for constructing a DAG from an xDSL module."""

from collections import defaultdict
from functools import singledispatch, singledispatchmethod

from pennylane.measurements import MeasurementProcess
Expand Down Expand Up @@ -50,6 +51,11 @@ def __init__(self, dag_builder: DAGBuilder) -> None:
# Keep track of nesting clusters using a stack
self._cluster_uid_stack: list[str] = []

# Create a map of wire to node uid
# Keys represent static (int) or dynamic wires (str)
# Values represent the set of all node uids that are on that wire.
self._wire_to_node_uids: dict[str | int, set[str]] = defaultdict(set)

# Use counter internally for UID
self._node_uid_counter: int = 0
self._cluster_uid_counter: int = 0
Expand All @@ -59,6 +65,7 @@ def _reset(self) -> None:
self._cluster_uid_stack: list[str] = []
self._node_uid_counter: int = 0
self._cluster_uid_counter: int = 0
self._wire_to_node_uids: dict[str | int, set[str]] = defaultdict(set)

def construct(self, module: builtin.ModuleOp) -> None:
"""Constructs the DAG from the module.
Expand Down Expand Up @@ -116,24 +123,42 @@ def _unitary(
)
self._node_uid_counter += 1

# Search through previous ops found on current wires and connect
prev_ops = set.union(*(self._wire_to_node_uids[wire] for wire in qml_op.wires))
for prev_op in prev_ops:
self.dag_builder.add_edge(prev_op, node_uid)

# Update affected wires to source from this node UID
for wire in qml_op.wires:
self._wire_to_node_uids[wire] = {node_uid}

@_visit_operation.register
def _projective_measure_op(self, op: quantum.MeasureOp) -> None:
"""Handler for the single-qubit projective measurement operation."""

# Create PennyLane instance
meas = xdsl_to_qml_measurement(op)
qml_op = xdsl_to_qml_measurement(op)

# Add node to current cluster
node_uid = f"node{self._node_uid_counter}"
self.dag_builder.add_node(
uid=node_uid,
label=get_label(meas),
label=get_label(qml_op),
cluster_uid=self._cluster_uid_stack[-1],
# NOTE: "record" allows us to use ports (https://graphviz.org/doc/info/shapes.html#record)
shape="record",
)
self._node_uid_counter += 1

# Search through previous ops found on current wires and connect
prev_ops = set.union(*(self._wire_to_node_uids[wire] for wire in qml_op.wires))
for prev_op in prev_ops:
self.dag_builder.add_edge(prev_op, node_uid)

# Update affected wires to source from this node UID
for wire in qml_op.wires:
self._wire_to_node_uids[wire] = {node_uid}

# =====================
# QUANTUM MEASUREMENTS
# =====================
Expand All @@ -156,6 +181,10 @@ def _state_op(self, op: quantum.StateOp) -> None:
)
self._node_uid_counter += 1

for seen_wire, seen_nodes in self._wire_to_node_uids.items():
for seen_node in seen_nodes:
self.dag_builder.add_edge(seen_node, node_uid, color="lightpink3")

@_visit_operation.register
def _statistical_measurement_ops(
self,
Expand All @@ -178,6 +207,10 @@ def _statistical_measurement_ops(
)
self._node_uid_counter += 1

for wire in meas.wires:
for seen_node in self._wire_to_node_uids[wire]:
self.dag_builder.add_edge(seen_node, node_uid, color="lightpink3")

@_visit_operation.register
def _visit_sample_and_probs_ops(
self,
Expand All @@ -204,6 +237,10 @@ def _visit_sample_and_probs_ops(
)
self._node_uid_counter += 1

for wire in meas.wires:
for seen_node in self._wire_to_node_uids[wire]:
self.dag_builder.add_edge(seen_node, node_uid, color="lightpink3")

# =============
# CONTROL FLOW
# =============
Expand Down Expand Up @@ -261,10 +298,15 @@ def _if_op(self, operation: scf.IfOp):
self._cluster_uid_stack.append(uid)
self._cluster_uid_counter += 1

# Save wires state before all of the branches
wire_map_before = self._wire_to_node_uids.copy()
region_wire_maps: list[dict[int | str, set[str]]] = []

# Loop through each branch and visualize as a cluster
num_regions = len(flattened_if_op)
for i, (condition_ssa, region) in enumerate(flattened_if_op):

# Visualize with a cluster
def _get_conditional_branch_label(i):
if i == 0:
return "if"
Expand All @@ -285,16 +327,42 @@ def _get_conditional_branch_label(i):
self._cluster_uid_stack.append(uid)
self._cluster_uid_counter += 1

# Make fresh wire map before going into region
self._wire_to_node_uids = wire_map_before.copy()

# Go recursively into the branch to process internals
self._visit_region(region)

# Update branch wire maps
if self._wire_to_node_uids != wire_map_before:
region_wire_maps.append(self._wire_to_node_uids)

# Pop branch cluster after processing to ensure
# logical branches are treated as 'parallel'
self._cluster_uid_stack.pop()

# Pop IfOp cluster before leaving this handler
self._cluster_uid_stack.pop()

# Check what wires were affected
affected_wires: set[str | int] = set(wire_map_before.keys())
for region_wire_map in region_wire_maps:
affected_wires.update(region_wire_map.keys())

# Update state to be the union of all branch wire maps
final_wire_map = defaultdict(set)
for wire in affected_wires:
all_nodes: set = set()
for region_wire_map in region_wire_maps:
if not wire in region_wire_map:
# IfOp region didn't apply anything on this wire
# so default to node before the IfOp
all_nodes.update(wire_map_before.get(wire, set()))
else:
all_nodes.update(region_wire_map.get(wire, set()))
final_wire_map[wire] = all_nodes
self._wire_to_node_uids = final_wire_map

# ============
# DEVICE NODE
# ============
Expand Down Expand Up @@ -349,6 +417,9 @@ def _func_return(self, operation: func.ReturnOp) -> None:
# the FuncOp's scope and so we can pop the ID off the stack.
self._cluster_uid_stack.pop()

# Clear seen wires as we are exiting a FuncOp (qnode)
self._wire_to_node_uids = defaultdict(set)


def _flatten_if_op(op: scf.IfOp) -> list[tuple[SSAValue | None, Region]]:
"""Recursively flattens a nested IfOp (if/elif/else chains)."""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -830,3 +830,28 @@ def my_circuit():
assert len(nodes) == 2 # Device node + operator

assert nodes["node1"]["label"] == get_label(op)


class TestOperatorConnectivity:
"""Tests that operators are properly connected."""

@pytest.mark.unit
def test_static_connection_within_cluster(self):
"""Tests that connections can be made within the same cluster."""
pass

@pytest.mark.unit
def test_static_connection_through_clusters(self):
"""Tests that connections can be made through nested clusters."""
pass

@pytest.mark.unit
def test_static_connection_through_conditional(self):
"""Tests that connections through conditionals make sense."""
pass


class TestTerminalMeasurementConnectivity:
"""Test that terminal measurements connect properly."""

pass
Loading