From 29aa5573ea8f891ed6518d3a7a35e79f14185a22 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Tue, 23 Sep 2025 16:36:58 -0400 Subject: [PATCH 1/8] Quantum execution for structured qnode repr is added to the Catalyst quantum dialect --- mlir/include/Quantum/IR/QuantumOps.td | 84 +++++++++++++++++++++++++++ 1 file changed, 84 insertions(+) diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td index 89498b185a..97f2544d09 100644 --- a/mlir/include/Quantum/IR/QuantumOps.td +++ b/mlir/include/Quantum/IR/QuantumOps.td @@ -17,6 +17,7 @@ include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" +include "mlir/IR/SymbolInterfaces.td" include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td" include "mlir/Interfaces/ControlFlowInterfaces.td" @@ -1274,5 +1275,88 @@ def StateOp : Measurement_Op<"state", [AttrSizedOperandSegments]> { let hasVerifier = 1; } +// ----- + +def ExecutionOp : Quantum_Op<"execution", [ + NoMemoryEffect, + SingleBlockImplicitTerminator<"ExecYieldOp">, + Symbol +]> { + let summary = "Define a quantum execution as a top-level symbol"; + let description = [{ + The `quantum.execution` operation defines a quantum program execution with four + distinct phases as a top-level symbol that can be called by `quantum.call_execution`. + It acts like a function definition that can be invoked multiple times. + + The four phases are: + 1. **Init region**: Device initialization and quantum register allocation + 2. **Evolution region**: Quantum state evolution + 3. **Measurement region**: Observables and measurements + 4. **Teardown region**: Resource cleanup + + Each region can accept unlimited operands through block arguments and pass data + to subsequent regions via `quantum.exec_yield` operations. + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + TypeAttr:$function_type + ); + + let results = (outs); + + let regions = (region + SizedRegion<1>:$init_region, + SizedRegion<1>:$state_evolution_region, + SizedRegion<1>:$measurement_region, + SizedRegion<1>:$teardown_region + ); + + let assemblyFormat = [{ + `(` `)` attr-dict + `(` `init` $init_region `,` `evolution` $state_evolution_region `,` `measurement` $measurement_region `,` `teardown` $teardown_region `)` + }]; +} + +def CallExecutionOp : Quantum_Op<"call_execution", [NoMemoryEffect]> { + let summary = "Call a quantum execution symbol"; + let description = [{ + The `quantum.call_execution` operation invokes a previously defined `quantum.execution` + symbol by name, passing the specified arguments and returning the computed results. + + This operation actually performs the quantum computation defined by the + execution symbol, similar to how func.call invokes func.func. + }]; + + let arguments = (ins + FlatSymbolRefAttr:$callee, + Variadic:$operands + ); + + let results = (outs + Variadic:$results + ); + + let assemblyFormat = [{ + `(` $operands `)` attr-dict `:` functional-type($operands, $results) + }]; +} + +def ExecYieldOp : Quantum_Op<"exec_yield", [Pure, ReturnLike, Terminator, ParentOneOf<["ExecutionOp"]>]> { + let summary = "Return results from quantum execution regions"; + + let arguments = (ins + Variadic:$retvals + ); + + let assemblyFormat = [{ + attr-dict ($retvals ^ `:` type($retvals))? + }]; + + let builders = [ + OpBuilder<(ins), [{ /* nothing to do */ }]> + ]; +} + #endif // QUANTUM_OPS From 3d8529999e1fb9affb29687d88d2e9476ed21498 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 25 Sep 2025 09:41:34 -0400 Subject: [PATCH 2/8] Add outline state evolution xDSL pass --- .../passes/outline_state_evolution.py | 442 ++++++++++++++++++ 1 file changed, 442 insertions(+) create mode 100644 frontend/catalyst/passes/outline_state_evolution.py diff --git a/frontend/catalyst/passes/outline_state_evolution.py b/frontend/catalyst/passes/outline_state_evolution.py new file mode 100644 index 0000000000..bc55659316 --- /dev/null +++ b/frontend/catalyst/passes/outline_state_evolution.py @@ -0,0 +1,442 @@ +# State Evolution Outlining Implementation +from dataclasses import dataclass, field +from itertools import chain +from typing import Type, TypeVar + +import logging +import numpy as np +import pennylane as qml + +from catalyst.ftqc import mbqc_pipeline +from catalyst.passes.xdsl_plugin import getXDSLPluginAbsolutePath +from pennylane.compiler.python_compiler.transforms.convert_to_mbqc_formalism import ( + convert_to_mbqc_formalism_pass, +) +from pennylane.compiler.python_compiler.transforms.decompose_graph_state import ( + decompose_graph_state_pass, +) + +from pennylane.compiler.python_compiler import compiler_transform +from pennylane.compiler.python_compiler.dialects import quantum + +from xdsl.context import Context +from xdsl.dialects import builtin, func +from xdsl.ir import Operation, SSAValue +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import PatternRewriter, RewritePattern, op_type_rewrite_pattern +from xdsl.rewriter import InsertPoint + +T = TypeVar("T") + +logger = logging.getLogger(__name__) +logger.disabled = True + +if not logger.handlers: + handler = logging.StreamHandler() + handler.setLevel(logging.DEBUG) + formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s") + handler.setFormatter(formatter) + logger.addHandler(handler) + + +def get_parent_of_type(op: Operation, kind: Type[T]) -> T | None: + """Walk up the parent tree until an op of the specified type is found.""" + while (op := op.parent_op()) and not isinstance(op, kind): + pass + return op + + +class OutlineStateEvolution(RewritePattern): + """Outline state evolution regions in a quantum function.""" + + def __init__(self): + self.module: builtin.ModuleOp = None + self.original_func_op: func.FuncOp = None + self.state_evolution_segment: StateEvolutionSegment = None + self.alloc_op: quantum.AllocOp = None + self.remaining_ops: list[Operation] = None + + @op_type_rewrite_pattern + def match_and_rewrite(self, func_op: func.FuncOp, rewriter: PatternRewriter): + """Transform a quantum function (qnode) to outline state evolution regions.""" + self.original_func_op = func_op + + if "qnode" not in func_op.attributes: + return + + self.module = get_parent_of_type(func_op, builtin.ModuleOp) + assert self.module is not None, "got orphaned qnode function" + + # Simplify the quantum I/O to use only registers at boundaries + self.simplify_quantum_io(func_op, rewriter) + + # Create a new function for the state evolution region + self.create_state_evolution_function(rewriter) + + # Replace the original region with a call to the state evolution function + self.finalize_transformation(rewriter) + + def get_idx(self, op: Operation) -> int | None: + """Get the index of the operation.""" + return ( + op.idx + if hasattr(op, "idx") and op.idx + else (op.idx_attr if hasattr(op, "idx_attr") else None) + ) + + def simplify_quantum_io(self, func_op: func.FuncOp, rewriter: PatternRewriter) -> func.FuncOp: + """Simplify quantum I/O to use only registers at segment boundaries. + + This ensures that state evolution regions only take registers as input/output, + not individual qubits. + """ + current_reg = None + qubit_to_reg_idx = {} + + for op in func_op.body.ops: + match op: + case quantum.AllocOp(): + current_reg = op.qreg + case quantum.ExtractOp(): + # Update register mapping + extract_idx = self.get_idx(op) + qubit_to_reg_idx[op.qubit] = extract_idx + op.operands = (current_reg, extract_idx) + case quantum.MeasureOp(): + qubit_to_reg_idx[op.out_qubit] = qubit_to_reg_idx[op.in_qubit] + del qubit_to_reg_idx[op.in_qubit] + case quantum.CustomOp(): + # Handle quantum gate operations + for i, qb in enumerate(chain(op.in_qubits, op.in_ctrl_qubits)): + qubit_to_reg_idx[op.out_qubits[i]] = i + qubit_to_reg_idx[op.results[i]] = qubit_to_reg_idx[qb] + del qubit_to_reg_idx[qb] + case quantum.InsertOp(): + assert qubit_to_reg_idx[op.qubit] is op.idx_attr if op.idx_attr else True + del qubit_to_reg_idx[op.qubit] + # update register since it might have changed + op.operands = (current_reg, op.idx, op.qubit) + current_reg = op.out_qreg + + case _ if isinstance( + op, + ( + quantum.ComputationalBasisOp, + quantum.NamedObsOp, + ), + ): + insert_ops = set() + + # create a register boundary before the terminal operation + rewriter.insertion_point = InsertPoint.before(op) + for qb, idx in qubit_to_reg_idx.items(): + insert_op = quantum.InsertOp(current_reg, idx, qb) + rewriter.insert(insert_op) + insert_ops.add(insert_op) + current_reg = insert_op.out_qreg + + list(insert_ops)[-1].attributes["terminal_boundary"] = builtin.UnitAttr() + + # extract ops + rewriter.insertion_point = InsertPoint.before(op) + for qb, idx in list(qubit_to_reg_idx.items()): + extract_op = quantum.ExtractOp(current_reg, idx) + rewriter.insert(extract_op) + qb.replace_by_if( + extract_op.qubit, lambda use: use.operation not in insert_ops + ) + qubit_to_reg_idx[extract_op.qubit] = idx + del qubit_to_reg_idx[qb] + + case _: + # Handle other operations that might has qreg result + if reg := next( + (reg for reg in op.results if isinstance(reg.type, quantum.QuregType)), None + ): + current_reg = reg + + def create_state_evolution_function(self, rewriter: PatternRewriter): + """Create a new function for the state evolution region using clone approach.""" + + alloc_op, terminal_boundary_op = self.find_evolution_range() + if not alloc_op or not terminal_boundary_op: + raise ValueError("Could not find alloc_op or terminal_boundary_op") + + if alloc_op.parent_block() != terminal_boundary_op.parent_block(): + raise ValueError("alloc_op and terminal_boundary_op are not in the same block") + + # collect operation from alloc_op to terminal_boundary_op + ops_to_clone = self.collect_operations_in_range(alloc_op, terminal_boundary_op) + + # analyze missing values for ops + missing_inputs = self.analyze_missing_values_for_ops(ops_to_clone) + + # analyze required outputs for ops + required_outputs = self.analyze_required_outputs(ops_to_clone, terminal_boundary_op) + + register_inputs = [] + other_inputs = [] + for val in missing_inputs: + if isinstance(val.type, quantum.QuregType): + register_inputs.append(val) + else: + other_inputs.append(val) + + register_outputs = [] + other_outputs = [] + for val in required_outputs: + if isinstance(val.type, quantum.QuregType): + register_outputs.append(val) + else: + other_outputs.append(val) + + ordered_inputs = register_inputs + other_inputs + ordered_outputs = register_outputs + other_outputs + + input_types = [val.type for val in ordered_inputs] + output_types = [val.type for val in ordered_outputs] + fun_type = builtin.FunctionType.from_lists(input_types, output_types) + + state_evolution_func = func.FuncOp("state_evolution", fun_type) + rewriter.insert_op(state_evolution_func, InsertPoint.at_end(self.module.body.block)) + + block = state_evolution_func.regions[0].block + value_mapper = {} + for missing_val, block_arg in zip(ordered_inputs, block.args): + value_mapper[missing_val] = block_arg + + self.clone_operations_to_block(ops_to_clone, block, value_mapper) + self.add_return_statement(block, ordered_outputs, value_mapper) + + self.missing_inputs = ordered_inputs + self.required_outputs = ordered_outputs + self.alloc_op = alloc_op + self.terminal_boundary_op = terminal_boundary_op + self.state_evolution_func = state_evolution_func + print("create_state_evolution_function successfully") + print(self.module) + + def find_evolution_range(self): + """find alloc_op and terminal_boundary_op""" + alloc_op = None + terminal_boundary_op = None + + for op in self.original_func_op.body.walk(): + if isinstance(op, quantum.AllocOp): + alloc_op = op + elif hasattr(op, "attributes") and "terminal_boundary" in op.attributes: + terminal_boundary_op = op + + return alloc_op, terminal_boundary_op + + def collect_operations_in_range(self, begin_op, end_op): + """collect top-level operations in range, let op.clone() handle nesting""" + ops_to_clone = [] + + if begin_op.parent_block() != end_op.parent_block(): + raise ValueError("begin_op and end_op are not in the same block") + + block = begin_op.parent_block() + + # skip until begin_op + op_iter = iter(block.ops) + while (op := next(op_iter, None)) != begin_op: + pass + + # collect top-level operations until end_op + while (op := next(op_iter, None)) != end_op: + ops_to_clone.append(op) + + # collect the terminal_boundary_op + if op is not None: + ops_to_clone.append(op) + + return ops_to_clone + + def analyze_missing_values_for_ops(self, ops: list[Operation]) -> list[SSAValue]: + """get missing values for ops + Given a list of operations, return the values that are missing from the operations. + """ + ops_defined_values = set() + + ops_walk = list(chain(*[op.walk() for op in ops])) + + ops_defined_values = set() + all_operands = set() + + for nested_op in ops_walk: + ops_defined_values.update(nested_op.results) + all_operands.update(nested_op.operands) + + return list(all_operands - ops_defined_values) + + def analyze_required_outputs( + self, ops: list[Operation], terminal_op: Operation + ) -> list[SSAValue]: + """get required outputs for ops + Given a list of operations and a terminal operation, return the values that are + required by the operations after the terminal operation. + Noted: It's only consdider the values that are defined in the operations and required by + the operations after the terminal operation! + """ + ops_walk = list(chain(*[op.walk() for op in ops])) + + ops_defined_values = set() + + for nested_op in ops_walk: + ops_defined_values.update(nested_op.results) + + required_outputs = set() + found_terminal = False + for op in self.original_func_op.body.walk(): + if op == terminal_op: + found_terminal = True + continue + + if found_terminal: + for operand in op.operands: + if operand in ops_defined_values: + required_outputs.add(operand) + + return list(required_outputs) + + def add_return_statement(self, target_block, required_outputs, value_mapper): + """add return statement to function""" + return_values = [] + for output_val in required_outputs: + if output_val not in value_mapper: + raise ValueError(f"output_val {output_val} not in value_mapper") + return_values.append(value_mapper[output_val]) + + return_op = func.ReturnOp(*return_values) + target_block.add_op(return_op) + + def clone_operations_to_block(self, ops_to_clone, target_block, value_mapper): + """Clone operations to target block, use value_mapper to update references""" + for op in ops_to_clone: + cloned_op = op.clone(value_mapper) + target_block.add_op(cloned_op) + + self.update_value_mapper_recursively(op, cloned_op, value_mapper) + + def update_value_mapper_recursively(self, orig_op, cloned_op, value_mapper): + """update value_mapper for all operations in operation""" + for orig_result, new_result in zip(orig_op.results, cloned_op.results): + value_mapper[orig_result] = new_result + + for orig_region, cloned_region in zip(orig_op.regions, cloned_op.regions): + self.update_region_value_mapper(orig_region, cloned_region, value_mapper) + + def update_region_value_mapper(self, orig_region, cloned_region, value_mapper): + """update value_mapper for all operations in region""" + for orig_block, cloned_block in zip(orig_region.blocks, cloned_region.blocks): + for orig_arg, cloned_arg in zip(orig_block.args, cloned_block.args): + value_mapper[orig_arg] = cloned_arg + + for orig_nested_op, cloned_nested_op in zip(orig_block.ops, cloned_block.ops): + self.update_value_mapper_recursively(orig_nested_op, cloned_nested_op, value_mapper) + + def finalize_transformation(self, rewriter: PatternRewriter): + """Replace the original function with a call to the state evolution function.""" + original_block = self.original_func_op.body.block + ops_list = list(original_block.ops) + + begin_idx = None + end_idx = None + for i, op in enumerate(ops_list): + if op == self.alloc_op: + begin_idx = i + 1 + elif op == self.terminal_boundary_op: + end_idx = i + 1 + + assert begin_idx is not None, "alloc_op not found in original function" + assert end_idx is not None, "terminal_boundary_op not found in original function" + assert begin_idx <= end_idx, "alloc_op should come before terminal_boundary_op" + + pre_ops = ops_list[:begin_idx] + post_ops = ops_list[end_idx:] + + call_args = list(self.missing_inputs) + result_types = [val.type for val in self.required_outputs] + + call_op = func.CallOp(self.state_evolution_func.sym_name.data, call_args, result_types) + + call_result_mapper = {} + for i, required_output in enumerate(self.required_outputs): + if i < len(call_op.results): + call_result_mapper[required_output] = call_op.results[i] + + # TODO: I just removed all ops and add them again to update with value_mapper. + # It's not efficient, just because it's easy to implement. Should using replace use method + # instead. + for op in reversed(ops_list): + op.detach() + + new_ops = [] + for op in pre_ops: + new_ops.append(op) + + new_ops.append(call_op) + + value_mapper = call_result_mapper.copy() + + for i, op in enumerate(post_ops): + cloned_op = op.clone(value_mapper) + + for orig_result, new_result in zip(op.results, cloned_op.results): + value_mapper[orig_result] = new_result + + new_ops.append(cloned_op) + + for op in new_ops: + original_block.add_op(op) + + +@compiler_transform +class OutlineStateEvolutionPass(ModulePass): + name = "outline-state-evolution" + + def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: + print("before outline-state-evolution") + print(op) + + self.apply_on_qnode(op, OutlineStateEvolution()) + + print("after outline-state-evolution") + print(op) + + def apply_on_qnode(self, module: builtin.ModuleOp, pattern: RewritePattern): + """Apply given pattern once to the QNode function in this module.""" + rewriter = PatternRewriter(module) + qnode = None + for op in module.ops: + if isinstance(op, func.FuncOp) and "qnode" in op.attributes: + qnode = op + break + assert qnode is not None, "expected QNode in module" + pattern.match_and_rewrite(qnode, rewriter) + + +if __name__ == "__main__": + qml.capture.enable() + + @qml.qjit( + target="mlir", + pass_plugins=[getXDSLPluginAbsolutePath()], + pipelines=mbqc_pipeline(), + autograph=True, + keep_intermediate="pass", + ) + @decompose_graph_state_pass + @convert_to_mbqc_formalism_pass + @OutlineStateEvolutionPass + @qml.set_shots(10) + @qml.qnode(qml.device("null.qubit", wires=2)) + def main(p: float): + qml.Hadamard(0) + qml.Hadamard(1) + qml.measure(0) + qml.RX(p, wires=0) + return qml.expval(qml.X(0)) + + main(np.pi / 2) From 6ed6ae908daa5124e214991e9c8da7294d304717 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 25 Sep 2025 09:53:42 -0400 Subject: [PATCH 3/8] revert --- mlir/include/Quantum/IR/QuantumOps.td | 84 --------------------------- 1 file changed, 84 deletions(-) diff --git a/mlir/include/Quantum/IR/QuantumOps.td b/mlir/include/Quantum/IR/QuantumOps.td index 97f2544d09..89498b185a 100644 --- a/mlir/include/Quantum/IR/QuantumOps.td +++ b/mlir/include/Quantum/IR/QuantumOps.td @@ -17,7 +17,6 @@ include "mlir/IR/EnumAttr.td" include "mlir/IR/OpBase.td" -include "mlir/IR/SymbolInterfaces.td" include "mlir/Dialect/Bufferization/IR/AllocationOpInterface.td" include "mlir/Interfaces/ControlFlowInterfaces.td" @@ -1275,88 +1274,5 @@ def StateOp : Measurement_Op<"state", [AttrSizedOperandSegments]> { let hasVerifier = 1; } -// ----- - -def ExecutionOp : Quantum_Op<"execution", [ - NoMemoryEffect, - SingleBlockImplicitTerminator<"ExecYieldOp">, - Symbol -]> { - let summary = "Define a quantum execution as a top-level symbol"; - let description = [{ - The `quantum.execution` operation defines a quantum program execution with four - distinct phases as a top-level symbol that can be called by `quantum.call_execution`. - It acts like a function definition that can be invoked multiple times. - - The four phases are: - 1. **Init region**: Device initialization and quantum register allocation - 2. **Evolution region**: Quantum state evolution - 3. **Measurement region**: Observables and measurements - 4. **Teardown region**: Resource cleanup - - Each region can accept unlimited operands through block arguments and pass data - to subsequent regions via `quantum.exec_yield` operations. - }]; - - let arguments = (ins - SymbolNameAttr:$sym_name, - TypeAttr:$function_type - ); - - let results = (outs); - - let regions = (region - SizedRegion<1>:$init_region, - SizedRegion<1>:$state_evolution_region, - SizedRegion<1>:$measurement_region, - SizedRegion<1>:$teardown_region - ); - - let assemblyFormat = [{ - `(` `)` attr-dict - `(` `init` $init_region `,` `evolution` $state_evolution_region `,` `measurement` $measurement_region `,` `teardown` $teardown_region `)` - }]; -} - -def CallExecutionOp : Quantum_Op<"call_execution", [NoMemoryEffect]> { - let summary = "Call a quantum execution symbol"; - let description = [{ - The `quantum.call_execution` operation invokes a previously defined `quantum.execution` - symbol by name, passing the specified arguments and returning the computed results. - - This operation actually performs the quantum computation defined by the - execution symbol, similar to how func.call invokes func.func. - }]; - - let arguments = (ins - FlatSymbolRefAttr:$callee, - Variadic:$operands - ); - - let results = (outs - Variadic:$results - ); - - let assemblyFormat = [{ - `(` $operands `)` attr-dict `:` functional-type($operands, $results) - }]; -} - -def ExecYieldOp : Quantum_Op<"exec_yield", [Pure, ReturnLike, Terminator, ParentOneOf<["ExecutionOp"]>]> { - let summary = "Return results from quantum execution regions"; - - let arguments = (ins - Variadic:$retvals - ); - - let assemblyFormat = [{ - attr-dict ($retvals ^ `:` type($retvals))? - }]; - - let builders = [ - OpBuilder<(ins), [{ /* nothing to do */ }]> - ]; -} - #endif // QUANTUM_OPS From bd8f5e959922d006e87665bf6803defdb64477ed Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 25 Sep 2025 10:04:21 -0400 Subject: [PATCH 4/8] change function name --- frontend/catalyst/passes/outline_state_evolution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/passes/outline_state_evolution.py b/frontend/catalyst/passes/outline_state_evolution.py index bc55659316..a9320dbdf2 100644 --- a/frontend/catalyst/passes/outline_state_evolution.py +++ b/frontend/catalyst/passes/outline_state_evolution.py @@ -197,7 +197,9 @@ def create_state_evolution_function(self, rewriter: PatternRewriter): output_types = [val.type for val in ordered_outputs] fun_type = builtin.FunctionType.from_lists(input_types, output_types) - state_evolution_func = func.FuncOp("state_evolution", fun_type) + state_evolution_func = func.FuncOp( + self.original_func_op.sym_name.data + ".state_evolution", fun_type + ) rewriter.insert_op(state_evolution_func, InsertPoint.at_end(self.module.body.block)) block = state_evolution_func.regions[0].block From 9a71793330e025a601253dc6fb630718a0d92ce5 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Fri, 26 Sep 2025 09:42:39 -0400 Subject: [PATCH 5/8] update --- .../catalyst/passes/outline_state_evolution.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/frontend/catalyst/passes/outline_state_evolution.py b/frontend/catalyst/passes/outline_state_evolution.py index a9320dbdf2..d614169e35 100644 --- a/frontend/catalyst/passes/outline_state_evolution.py +++ b/frontend/catalyst/passes/outline_state_evolution.py @@ -52,11 +52,14 @@ class OutlineStateEvolution(RewritePattern): def __init__(self): self.module: builtin.ModuleOp = None self.original_func_op: func.FuncOp = None - self.state_evolution_segment: StateEvolutionSegment = None self.alloc_op: quantum.AllocOp = None - self.remaining_ops: list[Operation] = None - @op_type_rewrite_pattern + # state evolution region + self.missing_inputs: list[SSAValue] = None + self.required_outputs: list[SSAValue] = None + self.terminal_boundary_op: Operation = None + self.state_evolution_func: func.FuncOp = None + def match_and_rewrite(self, func_op: func.FuncOp, rewriter: PatternRewriter): """Transform a quantum function (qnode) to outline state evolution regions.""" self.original_func_op = func_op @@ -437,8 +440,11 @@ def apply_on_qnode(self, module: builtin.ModuleOp, pattern: RewritePattern): def main(p: float): qml.Hadamard(0) qml.Hadamard(1) - qml.measure(0) - qml.RX(p, wires=0) + m = qml.measure(0) + @qml.cond(m) + def true_fn(): + qml.RX(p, wires=0) + true_fn() return qml.expval(qml.X(0)) main(np.pi / 2) From c9031ee39cd5dad302cff5ef133132324f2cc6ae Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Fri, 26 Sep 2025 09:42:59 -0400 Subject: [PATCH 6/8] formatting --- frontend/catalyst/passes/outline_state_evolution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/passes/outline_state_evolution.py b/frontend/catalyst/passes/outline_state_evolution.py index d614169e35..1f16e992e7 100644 --- a/frontend/catalyst/passes/outline_state_evolution.py +++ b/frontend/catalyst/passes/outline_state_evolution.py @@ -23,7 +23,7 @@ from xdsl.dialects import builtin, func from xdsl.ir import Operation, SSAValue from xdsl.passes import ModulePass -from xdsl.pattern_rewriter import PatternRewriter, RewritePattern, op_type_rewrite_pattern +from xdsl.pattern_rewriter import PatternRewriter, RewritePattern from xdsl.rewriter import InsertPoint T = TypeVar("T") @@ -441,9 +441,11 @@ def main(p: float): qml.Hadamard(0) qml.Hadamard(1) m = qml.measure(0) + @qml.cond(m) def true_fn(): qml.RX(p, wires=0) + true_fn() return qml.expval(qml.X(0)) From 75d2a4ae5af9b2e7512da970c3117ada71496aed Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 2 Oct 2025 09:36:45 -0400 Subject: [PATCH 7/8] Fix for loop --- frontend/catalyst/passes/outline_state_evolution.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/passes/outline_state_evolution.py b/frontend/catalyst/passes/outline_state_evolution.py index 1f16e992e7..9c0a62dc83 100644 --- a/frontend/catalyst/passes/outline_state_evolution.py +++ b/frontend/catalyst/passes/outline_state_evolution.py @@ -262,8 +262,6 @@ def analyze_missing_values_for_ops(self, ops: list[Operation]) -> list[SSAValue] """get missing values for ops Given a list of operations, return the values that are missing from the operations. """ - ops_defined_values = set() - ops_walk = list(chain(*[op.walk() for op in ops])) ops_defined_values = set() @@ -273,7 +271,15 @@ def analyze_missing_values_for_ops(self, ops: list[Operation]) -> list[SSAValue] ops_defined_values.update(nested_op.results) all_operands.update(nested_op.operands) - return list(all_operands - ops_defined_values) + if hasattr(nested_op, "regions") and nested_op.regions: + for region in nested_op.regions: + for block in region.blocks: + ops_defined_values.update(block.args) + + missing_values = list(all_operands - ops_defined_values) + missing_values = [v for v in missing_values if v is not None] + + return missing_values def analyze_required_outputs( self, ops: list[Operation], terminal_op: Operation From 9d1e47048a7f83ab3d4567a178c33e14aa789185 Mon Sep 17 00:00:00 2001 From: Hong-Sheng Zheng Date: Thu, 2 Oct 2025 17:22:28 -0400 Subject: [PATCH 8/8] add obs --- frontend/catalyst/passes/outline_state_evolution.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/passes/outline_state_evolution.py b/frontend/catalyst/passes/outline_state_evolution.py index 9c0a62dc83..1db08007a9 100644 --- a/frontend/catalyst/passes/outline_state_evolution.py +++ b/frontend/catalyst/passes/outline_state_evolution.py @@ -95,6 +95,7 @@ def simplify_quantum_io(self, func_op: func.FuncOp, rewriter: PatternRewriter) - """ current_reg = None qubit_to_reg_idx = {} + terminal_boundary_op = None for op in func_op.body.ops: match op: @@ -126,8 +127,10 @@ def simplify_quantum_io(self, func_op: func.FuncOp, rewriter: PatternRewriter) - ( quantum.ComputationalBasisOp, quantum.NamedObsOp, + quantum.HamiltonianOp, + quantum.TensorOp, ), - ): + ) and not terminal_boundary_op: insert_ops = set() # create a register boundary before the terminal operation @@ -139,6 +142,7 @@ def simplify_quantum_io(self, func_op: func.FuncOp, rewriter: PatternRewriter) - current_reg = insert_op.out_qreg list(insert_ops)[-1].attributes["terminal_boundary"] = builtin.UnitAttr() + terminal_boundary_op = list(insert_ops)[-1] # extract ops rewriter.insertion_point = InsertPoint.before(op)