diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f9e0be4105..fcdb9f2ed5 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -8,6 +8,7 @@ [(#2214)](https://github.com/PennyLaneAI/catalyst/pull/2214) [(#2246)](https://github.com/PennyLaneAI/catalyst/pull/2246) [(#2231)](https://github.com/PennyLaneAI/catalyst/pull/2231) + [(#2234)](https://github.com/PennyLaneAI/catalyst/pull/2234) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) diff --git a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py index ef13948b18..6a548b5d51 100644 --- a/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py +++ b/frontend/catalyst/python_interface/visualization/construct_circuit_dag.py @@ -16,10 +16,10 @@ from functools import singledispatchmethod -from xdsl.dialects import builtin, func -from xdsl.ir import Block, Operation, Region +from xdsl.dialects import builtin, func, scf +from xdsl.ir import Block, Operation, Region, SSAValue -from catalyst.python_interface.dialects import catalyst, quantum +from catalyst.python_interface.dialects import quantum from catalyst.python_interface.visualization.dag_builder import DAGBuilder @@ -85,6 +85,97 @@ def _visit_block(self, block: Block) -> None: for op in block.ops: self._visit_operation(op) + # ============= + # CONTROL FLOW + # ============= + + @_visit_operation.register + def _for_op(self, operation: scf.ForOp) -> None: + """Handle an xDSL ForOp operation.""" + + uid = f"cluster{self._cluster_uid_counter}" + self.dag_builder.add_cluster( + uid, + node_label="for loop", + label="", + cluster_uid=self._cluster_uid_stack[-1], + ) + self._cluster_uid_stack.append(uid) + self._cluster_uid_counter += 1 + + for region in operation.regions: + self._visit_region(region) + + self._cluster_uid_stack.pop() + + @_visit_operation.register + def _while_op(self, operation: scf.WhileOp) -> None: + """Handle an xDSL WhileOp operation.""" + uid = f"cluster{self._cluster_uid_counter}" + self.dag_builder.add_cluster( + uid, + node_label="while loop", + label="", + cluster_uid=self._cluster_uid_stack[-1], + ) + self._cluster_uid_stack.append(uid) + self._cluster_uid_counter += 1 + + for region in operation.regions: + self._visit_region(region) + + self._cluster_uid_stack.pop() + + @_visit_operation.register + def _if_op(self, operation: scf.IfOp): + """Handles the scf.IfOp operation.""" + flattened_if_op: list[tuple[SSAValue | None, Region]] = _flatten_if_op(operation) + + uid = f"cluster{self._cluster_uid_counter}" + self.dag_builder.add_cluster( + uid, + node_label="", + label="conditional", + labeljust="l", + cluster_uid=self._cluster_uid_stack[-1], + ) + self._cluster_uid_stack.append(uid) + self._cluster_uid_counter += 1 + + # Loop through each branch and visualize as a cluster + num_regions = len(flattened_if_op) + for i, (condition_ssa, region) in enumerate(flattened_if_op): + + def _get_conditional_branch_label(i): + if i == 0: + return "if" + elif i == num_regions - 1: + return "else" + else: + return "elif" + + uid = f"cluster{self._cluster_uid_counter}" + self.dag_builder.add_cluster( + uid, + node_label=_get_conditional_branch_label(i), + label="", + style="dashed", + penwidth=1, + cluster_uid=self._cluster_uid_stack[-1], + ) + self._cluster_uid_stack.append(uid) + self._cluster_uid_counter += 1 + + # Go recursively into the branch to process internals + self._visit_region(region) + + # Pop branch cluster after processing to ensure + # logical branches are treated as 'parallel' + self._cluster_uid_stack.pop() + + # Pop IfOp cluster before leaving this handler + self._cluster_uid_stack.pop() + # ============ # DEVICE NODE # ============ @@ -138,3 +229,28 @@ def _func_return(self, operation: func.ReturnOp) -> None: # If we hit a func.return operation we know we are leaving # the FuncOp's scope and so we can pop the ID off the stack. self._cluster_uid_stack.pop() + + +def _flatten_if_op(op: scf.IfOp) -> list[tuple[SSAValue | None, Region]]: + """Recursively flattens a nested IfOp (if/elif/else chains).""" + + condition_ssa: SSAValue = op.operands[0] + then_region, else_region = op.regions + + # Save condition SSA in case we want to visualize it eventually + flattened_op: list[tuple[SSAValue | None, Region]] = [(condition_ssa, then_region)] + + # Peak into else region to see if there's another IfOp + else_block: Block = else_region.block + # Completely relies on the structure that the second last operation + # will be an IfOp (seems to hold true) + if isinstance(else_block.ops.last.prev_op, scf.IfOp): + # Recursively flatten any IfOps found in said block + nested_flattened_op = _flatten_if_op(else_block.ops.last.prev_op) + flattened_op.extend(nested_flattened_op) + return flattened_op + + # No more nested IfOps, therefore append final region + # with no SSAValue + flattened_op.extend([(None, else_region)]) + return flattened_op diff --git a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py index d4cac366a2..82f50140fa 100644 --- a/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py +++ b/frontend/test/pytest/python_interface/visualization/test_construct_circuit_dag.py @@ -287,3 +287,256 @@ def my_workflow(): # Assert label is as expected assert graph_nodes["node1"]["label"] == "LightningSimulator" + + +class TestForOp: + """Tests that the for loop control flow can be visualized correctly.""" + + @pytest.mark.unit + def test_basic_example(self): + """Tests that the for loop cluster can be visualized correctly.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + for i in range(3): + qml.H(0) + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + assert clusters["cluster2"]["node_label"] == "for loop" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + @pytest.mark.unit + def test_nested_loop(self): + """Tests that nested for loops are visualized correctly.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + for i in range(0, 5, 2): + for j in range(1, 6, 2): + qml.H(0) + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + assert clusters["cluster2"]["node_label"] == "for loop" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + assert clusters["cluster3"]["node_label"] == "for loop" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + + +class TestWhileOp: + """Tests that the while loop control flow can be visualized correctly.""" + + @pytest.mark.unit + def test_basic_example(self): + """Test that the while loop is visualized correctly.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + counter = 0 + while counter < 5: + qml.H(0) + counter += 1 + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + assert clusters["cluster2"]["node_label"] == "while loop" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + @pytest.mark.unit + def test_nested_loop(self): + """Tests that nested while loops are visualized correctly.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(): + outer_counter = 0 + inner_counter = 0 + while outer_counter < 5: + while inner_counter < 6: + qml.H(0) + inner_counter += 1 + outer_counter += 1 + + module = my_workflow() + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + assert clusters["cluster2"]["node_label"] == "while loop" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + assert clusters["cluster3"]["node_label"] == "while loop" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + + +class TestIfOp: + """Tests that the conditional control flow can be visualized correctly.""" + + @pytest.mark.unit + def test_basic_example(self): + """Test that the conditional operation is visualized correctly.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(x): + if x == 2: + qml.X(0) + else: + qml.Y(0) + + args = (1,) + module = my_workflow(*args) + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + # Check conditional is a cluster within cluster1 (my_workflow) + assert clusters["cluster2"]["cluster_label"] == "conditional" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + # Check three clusters live within cluster2 (conditional) + assert clusters["cluster3"]["node_label"] == "if" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster4"]["node_label"] == "else" + assert clusters["cluster4"]["parent_cluster_uid"] == "cluster2" + + @pytest.mark.unit + def test_if_elif_else_conditional(self): + """Test that the conditional operation is visualized correctly.""" + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(x): + if x == 1: + qml.X(0) + elif x == 2: + qml.Y(0) + else: + qml.Z(0) + + args = (1,) + module = my_workflow(*args) + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + # Check conditional is a cluster within my_workflow + assert clusters["cluster2"]["cluster_label"] == "conditional" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + # Check three clusters live within conditional + assert clusters["cluster3"]["node_label"] == "if" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster4"]["node_label"] == "elif" + assert clusters["cluster4"]["parent_cluster_uid"] == "cluster2" + assert clusters["cluster5"]["node_label"] == "else" + assert clusters["cluster5"]["parent_cluster_uid"] == "cluster2" + + @pytest.mark.unit + def test_nested_conditionals(self): + """Tests that nested conditionals are visualized correctly.""" + + dev = qml.device("null.qubit", wires=1) + + @xdsl_from_qjit + @qml.qjit(autograph=True, target="mlir") + @qml.qnode(dev) + def my_workflow(x, y): + if x == 1: + if y == 2: + qml.H(0) + else: + qml.Z(0) + qml.X(0) + else: + qml.Z(0) + + args = (1, 2) + module = my_workflow(*args) + + utility = ConstructCircuitDAG(FakeDAGBuilder()) + utility.construct(module) + + clusters = utility.dag_builder.clusters + clusters = utility.dag_builder.clusters + + # cluster0 -> qjit + # cluster1 -> my_workflow + # cluster2 -> conditional (1) + # cluster3 -> if + # cluster4 -> conditional () + # cluster5 -> if + # cluster6 -> else + # cluster7 -> else + + # Check first conditional is a cluster within my_workflow + assert clusters["cluster2"]["cluster_label"] == "conditional" + assert clusters["cluster2"]["parent_cluster_uid"] == "cluster1" + + # Check 'if' cluster of first conditional has another conditional + assert clusters["cluster3"]["node_label"] == "if" + assert clusters["cluster3"]["parent_cluster_uid"] == "cluster2" + + # Second conditional + assert clusters["cluster4"]["cluster_label"] == "conditional" + assert clusters["cluster4"]["parent_cluster_uid"] == "cluster3" + # Check 'if' and 'else' in second conditional + assert clusters["cluster5"]["node_label"] == "if" + assert clusters["cluster5"]["parent_cluster_uid"] == "cluster4" + assert clusters["cluster6"]["node_label"] == "else" + assert clusters["cluster6"]["parent_cluster_uid"] == "cluster4" + + # Check nested if / else is within the first if cluster + assert clusters["cluster7"]["node_label"] == "else" + assert clusters["cluster7"]["parent_cluster_uid"] == "cluster2"