From 892de30a82359996158374473368555ad0aabd9f Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Wed, 19 Nov 2025 12:02:50 -0500 Subject: [PATCH 01/10] [PROTOTYPE] Custom `PatternRewriter` for quantum transforms --- .../pass_api/cancel_inverses_prototype.py | 17 +- .../python_interface/pass_api/passes.py | 16 +- .../pass_api/pattern_rewriter.py | 503 ++++++++++++++++++ 3 files changed, 518 insertions(+), 18 deletions(-) create mode 100644 frontend/catalyst/python_interface/pass_api/pattern_rewriter.py diff --git a/frontend/catalyst/python_interface/pass_api/cancel_inverses_prototype.py b/frontend/catalyst/python_interface/pass_api/cancel_inverses_prototype.py index 1983fd384c..38cbee766e 100644 --- a/frontend/catalyst/python_interface/pass_api/cancel_inverses_prototype.py +++ b/frontend/catalyst/python_interface/pass_api/cancel_inverses_prototype.py @@ -20,6 +20,7 @@ from .compiler_transform import compiler_transform from .passes import PLModulePass +from .pattern_rewriter import PLPatternRewriter class CancelInverses(PLModulePass): @@ -58,7 +59,7 @@ def can_cancel(op: quantum.CustomOp, next_op: Operation) -> bool: @CancelInverses.rewrite_rule(quantum.CustomOp) -def rewrite_custom_op(self, op, rewriter): +def rewrite_custom_op(self, op: quantum.CustomOp, rewriter: PLPatternRewriter): """Rewrite rule for CustomOp.""" while isinstance(op, quantum.CustomOp) and op.gate_name.data in self.self_inverses: next_user = None @@ -71,31 +72,27 @@ def rewrite_custom_op(self, op, rewriter): if next_user is None: break - for q1, q2 in zip(op.in_qubits, next_user.out_qubits, strict=True): - rewriter.replace_all_uses_with(q2, q1) - for cq1, cq2 in zip(op.in_ctrl_qubits, next_user.out_ctrl_qubits, strict=True): - rewriter.replace_all_uses_with(cq2, cq1) - rewriter.erase_op(next_user) - rewriter.erase_op(op) + rewriter.erase_gate(next_user) + rewriter.erase_gate(op) op = op.in_qubits[0].owner # We can register more rewrite rules as needed. Here are some # dummy rewrite rules to illustrate: @CancelInverses.rewrite_rule(quantum.InsertOp) -def rewrite_insert_op(self, op, rewriter): +def rewrite_insert_op(self, op: quantum.InsertOp, rewriter: PLPatternRewriter): """Rewrite rule for InsertOp.""" return @CancelInverses.rewrite_rule(quantum.ExtractOp) -def rewrite_extract_op(self, op, rewriter): +def rewrite_extract_op(self, op: quantum.ExtractOp, rewriter: PLPatternRewriter): """Rewrite rule for ExtractOp.""" return @CancelInverses.rewrite_rule(quantum.MeasureOp) -def rewrite_mid_measure_op(self, op, rewriter): +def rewrite_mid_measure_op(self, op: quantum.MeasureOp, rewriter: PLPatternRewriter): """Rewrite rule for MeasureOp.""" return diff --git a/frontend/catalyst/python_interface/pass_api/passes.py b/frontend/catalyst/python_interface/pass_api/passes.py index 27919d2453..a84326c9b0 100644 --- a/frontend/catalyst/python_interface/pass_api/passes.py +++ b/frontend/catalyst/python_interface/pass_api/passes.py @@ -24,12 +24,12 @@ from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, - PatternRewriter, - PatternRewriteWalker, RewritePattern, op_type_rewrite_pattern, ) +from .pattern_rewriter import PLPatternRewriter, PLPatternRewriteWalker + def _update_type_hints(hint: type[Operation] | type[Operation]) -> Callable: r"""Update the signature of a ``match_and_rewrite`` method to use the provided operation @@ -77,7 +77,7 @@ def __init__(self, _pass): @op_type_rewrite_pattern @_update_type_hints(hint) - def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None: + def match_and_rewrite(self, op: Operation, rewriter: PLPatternRewriter) -> None: rewrite_rule(self._pass, op, rewriter) return _RewritePattern @@ -98,7 +98,7 @@ def __init__(self, recursive: bool = True, greedy: bool = False): @classmethod def rewrite_rule( cls, hint: type[Operation] | type[Operation] - ) -> Callable[[Operation, PatternRewriter], Callable]: + ) -> Callable[[Operation, PLPatternRewriter], Callable]: r"""Decorator to register a rewrite rule. The rewrite rule must have the following signature: @@ -106,7 +106,7 @@ def rewrite_rule( .. code-block:: python @PLModulePass.rewrite_rule(MyOperation) - def rewrite_myop(self, op: MyOperation, rewriter: PatternRewriter) -> None: + def rewrite_myop(self, op: MyOperation, rewriter: PLPatternRewriter) -> None: ... .. note:: @@ -121,7 +121,7 @@ def rewrite_myop(self, op: MyOperation, rewriter: PatternRewriter) -> None: Callable: a decorator to register the rewrite rule with the ModulePass """ - def decorator(rule: Callable[[Operation, PatternRewriter], None]) -> Callable: + def decorator(rule: Callable[[Operation, PLPatternRewriter], None]) -> Callable: rewrite_pattern = _create_rewrite_pattern(hint, rule) cls._rewrite_patterns[hint] = rewrite_pattern return rule @@ -156,10 +156,10 @@ def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: # pylint: disable= pattern = GreedyRewritePatternApplier( rewrite_patterns=[rp(self) for rp in self._rewrite_patterns.values()] ) - walker = PatternRewriteWalker(pattern=pattern, apply_recursively=self.recursive) + walker = PLPatternRewriteWalker(pattern=pattern, apply_recursively=self.recursive) walker.rewrite_module(op) else: for rp in self._rewrite_patterns.values(): - walker = PatternRewriteWalker(pattern=rp(self), apply_recursively=self.recursive) + walker = PLPatternRewriteWalker(pattern=rp(self), apply_recursively=self.recursive) walker.rewrite_module(op) diff --git a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py new file mode 100644 index 0000000000..026a154486 --- /dev/null +++ b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py @@ -0,0 +1,503 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pattern rewriter API for quantum compilation passes.""" + +from collections.abc import Sequence +from numbers import Number + +from pennylane import math, measurements, ops +from pennylane.exceptions import TransformError +from pennylane.operation import Operator +from xdsl.builder import ImplicitBuilder +from xdsl.dialects import arith, builtin, func, scf, tensor +from xdsl.ir import BlockArgument +from xdsl.ir import Operation as xOperation +from xdsl.ir import Region, SSAValue +from xdsl.pattern_rewriter import PatternRewriter, PatternRewriterListener, PatternRewriteWalker +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.dialects import quantum +from catalyst.python_interface.utils import get_constant_from_ssa + +_named_observables = (ops.PauliX, ops.PauliY, ops.PauliZ, ops.Identity, ops.Hadamard) +_gate_like_ops = ( + quantum.CustomOp, + quantum.GlobalPhaseOp, + quantum.MultiRZOp, + # TODO: Uncomment once PCPhaseOp is added to Quantum dialect + # quantum.PCPhaseOp, + quantum.QubitUnitaryOp, +) + + +class PLPatternRewriter(PatternRewriter): + """A ``PatternRewriter`` with abstractions for quantum compilation passes. + + This is a subclass of ``xdsl.pattern_rewriter.PatternRewriter`` that exposes + methods to abstract away low-level pattern-rewriting details relevant to + quantum compilation passes. + """ + + def __init__(self, current_operation: xOperation): + super().__init__(current_operation) + + def get_qnode( + self, start_op: xOperation | None = None, get_func: bool = False + ) -> builtin.ModuleOp | func.FuncOp: + """Get the module corresponding to the QNode containing the given operation. + + The input operation, or the operation used to initialize the rewriter if an operation + is not provided, are used to search outer scopes until the module corresponding + to a QNode is found. + + .. note:: + + This method assumes that the module corresponding to a QNode will not contain any + modules in its body. + + Args: + start_op (xdsl.ir.Operation): The operation used to begin the search. If ``None``, + the operation used to initialize the rewriter will be used for the search. + get_func (bool): If ``True``, the FuncOp corresponding to the QNode will be returned + instead of the module. ``False`` by default. + + Returns: + ModuleOp | FuncOp: The QNode module surrounding the current operation, or the QNode + function if ``get_func`` is ``True``. + """ + current_op: xOperation = start_op or self.current_operation + while not isinstance(current_op, builtin.ModuleOp): + current_op = current_op.parent_op() + + qnode_func = None + for op in current_op.body.ops(): + if isinstance(op, func.FuncOp) and op.attributes.get("qnode", None): + qnode_func = op + break + + if qnode_func is None: + raise TransformError(f"{current_op} is not inside a QNode's scope.") + + return qnode_func if get_func else current_op + + def insert_constant(self, cst: Number, insertion_point: InsertPoint) -> xOperation: + """Create a scalar ConstantOp and insert it into the IR. + + Args: + cst (Number): The scalar to insert into the IR + insertion_point (InsertPoint): The point in the IR where the ``ConstantOp`` + that creates the constant SSAValue should be inserted + + Returns: + SSAValue: The SSA value corresponding to the constant + """ + data = [cst] + match cst: + case int(): + elem_type = builtin.IntegerType(64) + case float(): + elem_type = builtin.Float64Type() + case cst, bool(): + elem_type = builtin.IntegerType(1) + case complex(): + elem_type = builtin.ComplexType() + data = [[cst.real, cst.imag]] + case _: + raise TypeError(f"{cst} is not a valid type to insert as a constant.") + + type_ = builtin.TensorType(element_type=elem_type, shape=[]) + constAttr = builtin.DenseIntOrFPElementsAttr.from_list(type_, data) + constantOp = arith.ConstantOp(constAttr) + extractOp = tensor.ExtractOp(tensor=constantOp.result, indices=[], result_type=elem_type) + + self.insert_op(constantOp, insertion_point) + self.insert_op(extractOp, InsertPoint.after(constantOp)) + + return extractOp + + def get_num_qubits(self, insertion_point: InsertPoint) -> SSAValue[builtin.I64]: + """Get the number of qubits. + + All qubits available in the QNode, whether statically or dynamically allocated, + at a given point in the QNode are returned as a 64-bit integer SSAValue. + + Args: + insertion_point (InsertPoint): The point in the IR where the instruction + that gets the number of qubits should be inserted. This is necessary + because the number of qubits in a program can be different at different + points in the program when dynamically allocated qubits are present. + + Returns: + SSAValue[I64]: A 64-bit integer SSAValue corresponding to the number of allocated + qubits. + """ + numQubitsOp = quantum.NumQubitsOp() + self.insert_op(numQubitsOp, insertion_point=insertion_point) + return numQubitsOp.results[0] + + def get_num_shots(self, as_literal: bool = False) -> SSAValue[builtin.I64] | int | None: + """Get the number of shots. + + Args: + as_literal (bool): If ``True``, the shots will be returned as a Python + integer. If the shots are dynamic, the returned value will be -1. + If ``False``, an int64 ``SSAValue`` corresponding to the number of shots + will be returned. False by default + + Returns: + SSAValue[I64] | int | None: ``int`` if ``is_literal`` is ``True``, else an int64 + ``SSAValue``. If the execution is analytic, ``None`` will be returned. + """ + try: + qnode: func.FuncOp = self.get_qnode(get_func=True) + except TransformError as e: + raise TransformError( + "Cannot get the number of shots when rewriting an operation outside the " + "scope of a QNode." + ) from e + + # The qnode function always initializes a quantum device using the quantum.DeviceInitOp + # operation. + device_init = None + for op in qnode.body.ops: + if isinstance(op, quantum.DeviceInitOp): + device_init = op + break + + assert device_init is not None + + # If the device is **known** to be analytic, it will not have any operands. Note that even + # if the DeviceInitOp has shots as its operand, it may be analytic if the shots operand is + # a constant == 0. + if len(device_init.operands) == 0: + return None + + shots = device_init.operands[0] + if not as_literal: + return shots + + # If shots are dynamic, they will **always** be an argument to the QNode, else + # they will be static, and created using a constant-like operation. + if isinstance(shots, BlockArgument): + return -1 + + shots = get_constant_from_ssa(shots) + return None if shots == 0 else shots + + def erase_gate(self, op: xOperation) -> None: + """Erase a quantum gate. + + Safely erase a quantum gate from the module being transformed. This method automatically + handles and pre-processing required before safely erasing an operation. To erase quantum + gates, which include ``CustomOp``, ``GlobalPhaseOp``, ``MultiRZOp``, ``PCPhaseOp``, and + ``QubitUnitaryOp``, it is recommended to use this method instead of ``erase_op``. + + Args: + op (xdsl.ir.Operation): The operation to erase + """ + if not isinstance(op, _gate_like_ops): + raise TypeError( + f"Cannot erase {op}. 'PLPatternRewriter.erase_op' can only erase " + "gate-like operations." + ) + + # GlobalPhaseOp does not have any target qubits + in_qubits = ( + op.in_ctrl_qubits + if isinstance(op, quantum.GlobalPhaseOp) + else (op.in_qubits + op.in_ctrl_qubits) + ) + self.replace_op(op, (), in_qubits) + + def insert_mid_measure( + self, mcm: measurements.MidMeasureMP, insertion_point: InsertPoint + ) -> SSAValue[quantum.QubitType]: + """Insert a PL mid-circuit measurement into the IR at the provided insertion point. + + Args: + mcm (pennylane.ops.MidMeasureMP): The mid-circuit measurement to insert into + the IR. Note that the measurement qubit must be an SSAValue. + insertion_point (InsertPoint): The point in the IR where the operation must + be inserted. + + Returns: + xdsl.ir.SSAValue[quantum.QubitType]: The qubit returned by the mid-circuit measurement. + """ + in_qubit: SSAValue[quantum.QubitType] = mcm.wires[0] + midMeasureOp = quantum.MeasureOp(in_qubit=in_qubit, postselect=mcm.postselect) + self.insert_op(midMeasureOp, insertion_point=insertion_point) + out_qubit = midMeasureOp.out_qubit + + # If resetting, we need to insert a conditional statement that applies a PauliX + # if we measured |1>. The else block just yields a qubit. + if mcm.reset: + true_region = Region() + with ImplicitBuilder(true_region): + gate = quantum.CustomOp(gate_name="PauliX", in_qubits=(midMeasureOp.out_qubit,)) + _ = scf.YieldOp(gate.out_qubits[0]) + + false_region = Region() + with ImplicitBuilder(false_region): + _ = scf.YieldOp(midMeasureOp.out_qubit) + + ifOp = scf.IfOp( + cond=midMeasureOp.mres, + return_types=(quantum.QubitType(),), + true_region=true_region, + false_region=false_region, + ) + self.insert_op(ifOp, InsertPoint.after(midMeasureOp)) + out_qubit = ifOp.results[0] + + mcm_qubit_uses = [use for use in in_qubit.uses if use.operation != midMeasureOp] + in_qubit.replace_by_if(out_qubit, lambda use: use.operation != midMeasureOp) + for use in mcm_qubit_uses: + self.notify_op_modified(use.operation) + + return out_qubit + + @staticmethod + def _get_gate_base(gate: Operator): + """Get the base op of a gate.""" + + if isinstance(gate, ops.Controlled): + base_gate, ctrl_wires, ctrl_vals, adjoint = PLPatternRewriter._get_gate_base(gate.base) + ctrl_wires = tuple(gate.control_wires) + tuple(ctrl_wires) + ctrl_vals = tuple(gate.control_values) + tuple(ctrl_vals) + return base_gate, ctrl_wires, ctrl_vals, adjoint + + if isinstance(gate, ops.Adjoint): + base_gate, ctrl_wires, ctrl_vals, adjoint = PLPatternRewriter._get_gate_base(gate.base) + adjoint = adjoint ^ True + return base_gate, tuple(ctrl_wires), tuple(ctrl_vals), adjoint + + return gate, (), (), False + + # TODO: fix too-many-statements warning + # pylint: disable=too-many-statements, too-many-branches + def insert_gate( + self, gate: Operator, insertion_point: InsertPoint, params: Sequence[SSAValue] | None = None + ) -> xOperation: + r"""Insert a PL gate into the IR at the provided insertion point. + + .. note:: + + Inserting state-preparation operations is currently not supported. + + Args: + gate (~pennylane.operation.Operator): The gate to insert. The wires of the gate must be + ``QubitType`` ``SSAValue``\ s. + insertion_point (InsertPoint): The point where the operation should be inserted. + params (Sequence[SSAValue] | None): For parametric gates, the list of ``SSAValue``\ s that + should be used as the gate's operands. If not provided, the parameters to ``gate`` will + be inserted as constants into the program. + + Returns: + xdsl.ir.Operation: The xDSL operation corresponding to the gate being inserted. + """ + # TODO: Add support for StatePrep, BasisState + gate, ctrl_wires, ctrl_vals, adjoint = self._get_gate_base(gate) + op_args = {} + + # If the gate is a QubitUnitary and an SSA tensor is not provided as its matrix, then + # we need to create a constant matrix SSAValue using the gate's matrix. + if isinstance(gate, ops.QubitUnitary) and not params: + mat = gate.matrix() + mat_attr = builtin.DenseIntOrFPElementsAttr.from_list( + builtin.TensorType( + builtin.ComplexType(builtin.Float64Type()), shape=math.shape(mat) + ), + mat, + ) + constantOp = arith.ConstantOp(value=mat_attr) + self.insert_op(constantOp, insertion_point=insertion_point) + insertion_point = InsertPoint.after(constantOp) + params = [constantOp.results[0]] + + # Create static parameters + elif not params: + params = [] + for d in gate.data: + try: + d = float(d) + except ValueError as e: + raise TransformError( + "Only values that can be cast into floats can be used as gate " + f"parameters. Got {d}." + ) from e + constOp = self.insert_constant(d, insertion_point) + params.append(constOp.results[0]) + insertion_point = InsertPoint.after(constOp) + + # TODO: Uncomment after PCPhaseOp is added to Quantum dialect + # # PCPhase has a `dim` hyperparameter which also needs to be inserted into the IR. + # if isinstance(gate, ops.PCPhase): + # constOp = self.insert_constant(float(gate.hyperparameters["dimension"][0]), insertion_point) + # params.append(constOp.results[0]) + # insertion_point = InsertPoint.after(constOp) + + # Different gate types may be represented in MLIR by different operations, which may + # take slightly different arguments + match type(gate): + case ops.GlobalPhase: + op_class = quantum.GlobalPhaseOp + assert len(params) == 1 + op_args["params"] = params[0] + case ops.MultiRZ: + op_class = quantum.MultiRZOp + assert len(params) == 1 + op_args["theta"] = params[0] + case ops.QubitUnitary: + op_class = quantum.QubitUnitaryOp + assert len(params) == 1 + op_args["matrix"] = params[0] + # TODO: Uncomment after PCPhaseOp is added to Quantum dialect + # case ops.PCPhase: + # op_class = quantum.PCPhaseOp + # op_args["theta"] = params[0] + # op_args["dim"] = params[1] + case _: + op_class = quantum.CustomOp + op_args["gate_name"] = gate.name + op_args["params"] = params + + # Add qubits/control qubits to args. GlobalPhaseOp does not take qubits, only + # control qubits + if not isinstance(gate, ops.GlobalPhase): + op_args["in_qubits"] = tuple(gate.wires) + op_args["in_ctrl_qubits"] = tuple(ctrl_wires) if ctrl_wires else None + in_ctrl_values = None + + # Add ctrl values to args + if ctrl_vals: + true_cst = None + false_cst = None + if any(ctrl_vals): + true_cst = self.insert_constant(True, insertion_point=insertion_point) + insertion_point = InsertPoint.after(true_cst) + if not all(ctrl_vals): + false_cst = self.insert_constant(False, insertion_point=insertion_point) + insertion_point = InsertPoint.after(false_cst) + in_ctrl_values = tuple( + true_cst.results[0] if v else false_cst.results[0] for v in ctrl_vals + ) + op_args["in_ctrl_values"] = in_ctrl_values + op_args["adjoint"] = adjoint + + gateOp = op_class(**op_args) + self.insert_op(gateOp, insertion_point=insertion_point) + + # Use getattr for in/out_qubits because GlobalPhaseOp does not have in/out_qubits + for iq, oq in zip( + getattr(gateOp, "in_qubits", ()) + tuple(gateOp.in_ctrl_qubits), + getattr(gateOp, "out_qubits", ()) + tuple(gateOp.out_ctrl_qubits), + strict=True, + ): + in_qubit_uses = [use for use in iq.uses if use.operation != gateOp] + iq.replace_by_if(oq, lambda use: use.operation != gateOp) + for use in in_qubit_uses: + self.notify_op_modified(use.operation) + # self.ctx.update_qubit(iq, oq) + + return gateOp + + # TODO: Finish implementation + def swap_gates(self, op1: xOperation, op2: xOperation) -> None: + """Swap two operations in the IR. + + Args: + op1 (xdsl.ir.Operation): First operation for the swap + op2 (xdsl.ir.Operation): Second operation for the swap + """ + if not (op1 in _gate_like_ops and op2 in _gate_like_ops): + raise TransformError(f"Can only swap gates. Got {op1}, {op2}") + + if set(op1.results) | set(op2.operands): + pass + elif set(op1.operands) | set(op2.results): + op1, op2 = op2, op1 + else: + raise TransformError("Cannot swap operations that are not SSA neighbours.") + + # Walk the IR forwards or backwards from the operation with less wires to check if there are + # any ops that use the same wires as op1 or op2 + n_vals1 = [r for r in op1.results if isinstance(r.type, quantum.QubitType)] + n_vals2 = [o for o in op2.operands if isinstance(o.type, quantum.QubitType)] + # If op1 has less results than op2 has operands, do forward traversal from op1 + # Else, do backward traversal from op2. + if n_vals1 <= n_vals2: + pass + else: + pass + + # Update uses for the swap + new_op2 = type(op2).create( + operands=(), + result_types=(), + properties=op2.properties, + attributes=op2.attributes, + successors=op2.successors, + regions=op2.regions, + ) + self.insert_op(new_op2, insertion_point=InsertPoint.before(op1)) + + # Reorder the block so that the linear order is valid + + +# pylint: disable=too-few-public-methods +class PLPatternRewriteWalker(PatternRewriteWalker): + """A ``PatternRewriteWalker`` for traversing and rewriting modules. + + This is a subclass of ``xdsl.pattern_rewriter.PatternRewriteWalker that uses a custom + rewriter that contains abstractions for quantum compilation passes.""" + + def _process_worklist(self, listener: PatternRewriterListener) -> bool: + """ + Process the worklist until it is empty. + Returns true if any modification was done. + """ + rewriter_has_done_action = False + + # Handle empty worklist + op = self._worklist.pop() + if op is None: + return rewriter_has_done_action + + # Create a rewriter on the first operation + # Here, we use our custom rewriter instead of the default PatternRewriter. + rewriter = PLPatternRewriter(op) + rewriter.extend_from_listener(listener) + + # do/while loop + while True: + # Reset the rewriter on `op` + rewriter.has_done_action = False + rewriter.current_operation = op + rewriter.insertion_point = InsertPoint.before(op) + rewriter.name_hint = None + + # Apply the pattern on the operation + try: + self.pattern.match_and_rewrite(op, rewriter) + except Exception as err: # pylint: disable=broad-exception-caught + op.emit_error( + f"Error while applying pattern: {err}", + underlying_error=err, + ) + rewriter_has_done_action |= rewriter.has_done_action + + # If the worklist is empty, we are done + op = self._worklist.pop() + if op is None: + return rewriter_has_done_action From 7df0af8b51657043acba8e16c309393bb24656eb Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 27 Nov 2025 15:01:44 -0500 Subject: [PATCH 02/10] Add get_active_qubits method --- .../pass_api/pattern_rewriter.py | 148 +++++++++++++++++- 1 file changed, 147 insertions(+), 1 deletion(-) diff --git a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py index 026a154486..1788b1fe7c 100644 --- a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py +++ b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py @@ -14,14 +14,17 @@ """Pattern rewriter API for quantum compilation passes.""" from collections.abc import Sequence +from dataclasses import dataclass, field from numbers import Number +from typing import TypeAlias +from uuid import UUID, uuid4 from pennylane import math, measurements, ops from pennylane.exceptions import TransformError from pennylane.operation import Operator from xdsl.builder import ImplicitBuilder from xdsl.dialects import arith, builtin, func, scf, tensor -from xdsl.ir import BlockArgument +from xdsl.ir import Block, BlockArgument from xdsl.ir import Operation as xOperation from xdsl.ir import Region, SSAValue from xdsl.pattern_rewriter import PatternRewriter, PatternRewriterListener, PatternRewriteWalker @@ -41,6 +44,33 @@ ) +@dataclass +class DynamicWire: + """Dynamic wire label wrapper.""" + + var: SSAValue[builtin.i64] | None = None + """SSA value representing the wire's dynamic index into the quantum register.""" + + _id: UUID = field(default_factory=uuid4, init=False) + """Unique ID representing the qubit. This is only used if the wire was allocated dynamically + (``quantum.AllocQubitOp``) or the qubit originated as a ``BlockArgument``.""" + + def __hash__(self) -> int: + if self.var: + return hash(self.var) + + return hash(self._id) + + def __eq__(self, other: "DynamicWire") -> bool: + if self.var and other.var: + return self.var == other.var + + return self._id == other._id + + +WireLabelLike: TypeAlias = int | DynamicWire + + class PLPatternRewriter(PatternRewriter): """A ``PatternRewriter`` with abstractions for quantum compilation passes. @@ -91,6 +121,122 @@ def get_qnode( return qnode_func if get_func else current_op + def get_qreg(self, insertion_point: InsertPoint) -> quantum.QuregSSAValue | None: + """Get the active quantum register at a given point in the program. Returns ``None`` + If there are no active quantum registers. Currently, we assume that there can only be + one active quantum register at any given point in the program.""" + cur_op = insertion_point.insert_before + qreg = None + + while cur_op.prev_op is not None: + if isinstance(cur_op, quantum.DeallocOp): + break + + cur_op = cur_op.prev_op + for r in cur_op.results: + if isinstance(r.type, quantum.QuregType): + qreg = r + break + + return qreg + + def _find_qubit_root( + self, qubit: quantum.QubitSSAValue + ) -> tuple[quantum.ExtractOp | quantum.AllocQubitOp | Block, list[quantum.QubitSSAValue]]: + """Find the root operation that created a qubit.""" + root = None + owner = qubit.owner + seen_qubits = set() + + while True: + seen_qubits.add(qubit) + + if isinstance(owner, (Block, quantum.AllocQubitOp, quantum.ExtractOp)): + root = owner + break + + if isinstance(owner, _gate_like_ops): + qb_idx = qubit.index + n_classical_operands = ( + len(owner.operands) + - len(owner.in_ctrl_qubits) + - len(getattr(owner, "in_qubits", [])) + ) + qubit = owner.operands[n_classical_operands + qb_idx] + owner = qubit.owner + + raise ValueError(f"Cannot handle {owner} type yet for finding a qubit's root.") + + return root, seen_qubits + + def get_qreg_idx_from_qubit(self, qubit: quantum.QubitSSAValue) -> int | DynamicWire | None: + """Get the index to the quantum register to which a qubit corresponds. + + If the index is unknown at compile-time, a ``DynamicWire`` will be returned with + a handle to the SSA value corresponding to the index. If the qubit was dynamically + allocated, or originated from an argument, ``None`` will be returned. + """ + if not isinstance(qubit.type, quantum.QubitType): + raise TypeError(f"The input must be a QubitType SSAValue, got {qubit}.") + + root, _ = self._find_qubit_root(qubit) + + if isinstance(root, quantum.ExtractOp): + if (idx_attr := getattr(root, "idx_attr", None)) is not None: + return idx_attr.data + elif (idx := get_constant_from_ssa(root.idx)) is not None: + return idx + + return DynamicWire(var=root.idx) + + # If the qubit's root is an AllocQubitOp, the qubit does not correspond to a + # qreg index. Else, it originates from a BlockArgument, in which case, its + # corresponding index, if there is one, cannot be inferred. + return None + + def get_active_qubits( + self, insertion_point: InsertPoint + ) -> dict[WireLabelLike | quantum.QubitSSAValue, quantum.QubitSSAValue | WireLabelLike]: + """Get a bidirectional map from wire indices to active qubits corresponding to + those wire labels and vice versa.""" + # Starting constraint: deal with flat regions + qubit_map: dict[int | DynamicWire, quantum.QubitSSAValue] = {} + seen_qubits: set[quantum.QubitSSAValue] = set() + cur_op: xOperation = insertion_point.insert_before + + while not (isinstance(cur_op.prev_op, quantum.AllocOp) or cur_op.prev_op is None): + cur_op = cur_op.prev_op + + if isinstance(cur_op, quantum.InsertOp): + seen_qubits.add(cur_op.qubit) + continue + + for r in cur_op.results: + if not isinstance(r.type, quantum.QubitType) or r in seen_qubits: + continue + + root, cur_seen_qubits = self._find_qubit_root(r) + seen_qubits |= cur_seen_qubits + insert_idx = None + + if isinstance(root, quantum.ExtractOp): + if (idx_attr := getattr(root, "idx_attr", None)) is not None: + insert_idx = idx_attr + elif (idx := get_constant_from_ssa(root.idx)) is not None: + insert_idx = idx + else: + insert_idx = DynamicWire(var=root.idx) + + # Qubit's root is either AllocQubitOp or it is a block argument. + # Either way, the qubit does not correspond to an index + else: + insert_idx = DynamicWire() + + qubit_map[insert_idx] = r + qubit_map[r] = insert_idx + + return qubit_map + def insert_constant(self, cst: Number, insertion_point: InsertPoint) -> xOperation: """Create a scalar ConstantOp and insert it into the IR. From f6ae456ea232d5c54301b2ff4d223c05f25b5cf1 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 27 Nov 2025 16:59:38 -0500 Subject: [PATCH 03/10] Fix get_active_qubits --- .../catalyst/python_interface/pass_api/pattern_rewriter.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py index 1788b1fe7c..dfdb6a04ba 100644 --- a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py +++ b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py @@ -165,7 +165,8 @@ def _find_qubit_root( qubit = owner.operands[n_classical_operands + qb_idx] owner = qubit.owner - raise ValueError(f"Cannot handle {owner} type yet for finding a qubit's root.") + else: + raise ValueError(f"Cannot handle {owner} type yet for finding a qubit's root.") return root, seen_qubits @@ -208,7 +209,7 @@ def get_active_qubits( cur_op = cur_op.prev_op if isinstance(cur_op, quantum.InsertOp): - seen_qubits.add(cur_op.qubit) + seen_qubits |= self._find_qubit_root(cur_op.qubit)[1] continue for r in cur_op.results: @@ -221,7 +222,7 @@ def get_active_qubits( if isinstance(root, quantum.ExtractOp): if (idx_attr := getattr(root, "idx_attr", None)) is not None: - insert_idx = idx_attr + insert_idx = idx_attr.value.data elif (idx := get_constant_from_ssa(root.idx)) is not None: insert_idx = idx else: From 2db52431dd96316e851f124a0966f361ce646e53 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 27 Nov 2025 17:00:40 -0500 Subject: [PATCH 04/10] Fix dev comment --- .../catalyst/python_interface/pass_api/pattern_rewriter.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py index dfdb6a04ba..9cc25d2fc5 100644 --- a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py +++ b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py @@ -228,8 +228,8 @@ def get_active_qubits( else: insert_idx = DynamicWire(var=root.idx) - # Qubit's root is either AllocQubitOp or it is a block argument. - # Either way, the qubit does not correspond to an index + # Qubit's root is either AllocQubitOp or a Block. Either way, + # the qubit does not correspond to an index. else: insert_idx = DynamicWire() From b6c379e20320d7b69de92db8bec1e5414d1d1836 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 27 Nov 2025 17:04:34 -0500 Subject: [PATCH 05/10] Uncomment PCPhaseOp --- .../catalyst/python_interface/pass_api/pattern_rewriter.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py index 9cc25d2fc5..7547d987c9 100644 --- a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py +++ b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py @@ -38,8 +38,7 @@ quantum.CustomOp, quantum.GlobalPhaseOp, quantum.MultiRZOp, - # TODO: Uncomment once PCPhaseOp is added to Quantum dialect - # quantum.PCPhaseOp, + quantum.PCPhaseOp, quantum.QubitUnitaryOp, ) From 721f9914cffdc9c0b520dda4edc793464eb6c066 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 27 Nov 2025 17:05:51 -0500 Subject: [PATCH 06/10] Uncomment PCPhaseOp --- .../pass_api/pattern_rewriter.py | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py index 7547d987c9..10e3ad6506 100644 --- a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py +++ b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py @@ -486,12 +486,13 @@ def insert_gate( params.append(constOp.results[0]) insertion_point = InsertPoint.after(constOp) - # TODO: Uncomment after PCPhaseOp is added to Quantum dialect - # # PCPhase has a `dim` hyperparameter which also needs to be inserted into the IR. - # if isinstance(gate, ops.PCPhase): - # constOp = self.insert_constant(float(gate.hyperparameters["dimension"][0]), insertion_point) - # params.append(constOp.results[0]) - # insertion_point = InsertPoint.after(constOp) + # PCPhase has a `dim` hyperparameter which also needs to be inserted into the IR. + if isinstance(gate, ops.PCPhase): + constOp = self.insert_constant( + float(gate.hyperparameters["dimension"][0]), insertion_point + ) + params.append(constOp.results[0]) + insertion_point = InsertPoint.after(constOp) # Different gate types may be represented in MLIR by different operations, which may # take slightly different arguments @@ -508,11 +509,10 @@ def insert_gate( op_class = quantum.QubitUnitaryOp assert len(params) == 1 op_args["matrix"] = params[0] - # TODO: Uncomment after PCPhaseOp is added to Quantum dialect - # case ops.PCPhase: - # op_class = quantum.PCPhaseOp - # op_args["theta"] = params[0] - # op_args["dim"] = params[1] + case ops.PCPhase: + op_class = quantum.PCPhaseOp + op_args["theta"] = params[0] + op_args["dim"] = params[1] case _: op_class = quantum.CustomOp op_args["gate_name"] = gate.name From 3807b4ea86f140918129bb00ea3d16b45152e40c Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Thu, 27 Nov 2025 17:06:56 -0500 Subject: [PATCH 07/10] Raise NotImplementedError in swap_gates --- .../pass_api/pattern_rewriter.py | 68 ++++++++++--------- 1 file changed, 35 insertions(+), 33 deletions(-) diff --git a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py index 10e3ad6506..c86c31644e 100644 --- a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py +++ b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py @@ -566,39 +566,41 @@ def swap_gates(self, op1: xOperation, op2: xOperation) -> None: op1 (xdsl.ir.Operation): First operation for the swap op2 (xdsl.ir.Operation): Second operation for the swap """ - if not (op1 in _gate_like_ops and op2 in _gate_like_ops): - raise TransformError(f"Can only swap gates. Got {op1}, {op2}") - - if set(op1.results) | set(op2.operands): - pass - elif set(op1.operands) | set(op2.results): - op1, op2 = op2, op1 - else: - raise TransformError("Cannot swap operations that are not SSA neighbours.") - - # Walk the IR forwards or backwards from the operation with less wires to check if there are - # any ops that use the same wires as op1 or op2 - n_vals1 = [r for r in op1.results if isinstance(r.type, quantum.QubitType)] - n_vals2 = [o for o in op2.operands if isinstance(o.type, quantum.QubitType)] - # If op1 has less results than op2 has operands, do forward traversal from op1 - # Else, do backward traversal from op2. - if n_vals1 <= n_vals2: - pass - else: - pass - - # Update uses for the swap - new_op2 = type(op2).create( - operands=(), - result_types=(), - properties=op2.properties, - attributes=op2.attributes, - successors=op2.successors, - regions=op2.regions, - ) - self.insert_op(new_op2, insertion_point=InsertPoint.before(op1)) - - # Reorder the block so that the linear order is valid + raise NotImplementedError + + # if not (op1 in _gate_like_ops and op2 in _gate_like_ops): + # raise TransformError(f"Can only swap gates. Got {op1}, {op2}") + + # if set(op1.results) | set(op2.operands): + # pass + # elif set(op1.operands) | set(op2.results): + # op1, op2 = op2, op1 + # else: + # raise TransformError("Cannot swap operations that are not SSA neighbours.") + + # # Walk the IR forwards or backwards from the operation with less wires to check if there are + # # any ops that use the same wires as op1 or op2 + # n_vals1 = [r for r in op1.results if isinstance(r.type, quantum.QubitType)] + # n_vals2 = [o for o in op2.operands if isinstance(o.type, quantum.QubitType)] + # # If op1 has less results than op2 has operands, do forward traversal from op1 + # # Else, do backward traversal from op2. + # if n_vals1 <= n_vals2: + # pass + # else: + # pass + + # # Update uses for the swap + # new_op2 = type(op2).create( + # operands=(), + # result_types=(), + # properties=op2.properties, + # attributes=op2.attributes, + # successors=op2.successors, + # regions=op2.regions, + # ) + # self.insert_op(new_op2, insertion_point=InsertPoint.before(op1)) + + # # Reorder the block so that the linear order is valid # pylint: disable=too-few-public-methods From 5a47713555e4756ef79a354b6edfd9508a0eb982 Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Fri, 28 Nov 2025 11:15:32 -0500 Subject: [PATCH 08/10] Remove artifact --- frontend/catalyst/python_interface/pass_api/pattern_rewriter.py | 1 - 1 file changed, 1 deletion(-) diff --git a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py index c86c31644e..83cb48ed8b 100644 --- a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py +++ b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py @@ -554,7 +554,6 @@ def insert_gate( iq.replace_by_if(oq, lambda use: use.operation != gateOp) for use in in_qubit_uses: self.notify_op_modified(use.operation) - # self.ctx.update_qubit(iq, oq) return gateOp From 9921ae17b7b5e1d7608389f1205f93e119521dfb Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 1 Dec 2025 16:10:39 -0500 Subject: [PATCH 09/10] convert insert_observable into singledispatchmethod; insert_measurement in progress --- .../pass_api/pattern_rewriter.py | 280 +++++++++++++++++- 1 file changed, 265 insertions(+), 15 deletions(-) diff --git a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py index 83cb48ed8b..520d21c263 100644 --- a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py +++ b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py @@ -15,12 +15,14 @@ from collections.abc import Sequence from dataclasses import dataclass, field +from functools import singledispatchmethod from numbers import Number from typing import TypeAlias from uuid import UUID, uuid4 from pennylane import math, measurements, ops from pennylane.exceptions import TransformError +from pennylane.measurements import MeasurementProcess from pennylane.operation import Operator from xdsl.builder import ImplicitBuilder from xdsl.dialects import arith, builtin, func, scf, tensor @@ -123,15 +125,11 @@ def get_qnode( def get_qreg(self, insertion_point: InsertPoint) -> quantum.QuregSSAValue | None: """Get the active quantum register at a given point in the program. Returns ``None`` If there are no active quantum registers. Currently, we assume that there can only be - one active quantum register at any given point in the program.""" - cur_op = insertion_point.insert_before + at most one active quantum register at any given point in the program.""" + cur_op = insertion_point.insert_before.prev_op qreg = None - while cur_op.prev_op is not None: - if isinstance(cur_op, quantum.DeallocOp): - break - - cur_op = cur_op.prev_op + while not (cur_op is None or isinstance(cur_op, quantum.DeallocOp)): for r in cur_op.results: if isinstance(r.type, quantum.QuregType): qreg = r @@ -169,12 +167,13 @@ def _find_qubit_root( return root, seen_qubits - def get_qreg_idx_from_qubit(self, qubit: quantum.QubitSSAValue) -> int | DynamicWire | None: + # TODO: Improve docstring + def get_idx_from_qubit(self, qubit: quantum.QubitSSAValue) -> int | SSAValue | None: """Get the index to the quantum register to which a qubit corresponds. - If the index is unknown at compile-time, a ``DynamicWire`` will be returned with - a handle to the SSA value corresponding to the index. If the qubit was dynamically - allocated, or originated from an argument, ``None`` will be returned. + If the index is unknown at compile-time, an SSA value corresponding to the index + will be returned. If the qubit was dynamically allocated, or originated from a + block argument, ``None`` will be returned. """ if not isinstance(qubit.type, quantum.QubitType): raise TypeError(f"The input must be a QubitType SSAValue, got {qubit}.") @@ -183,17 +182,18 @@ def get_qreg_idx_from_qubit(self, qubit: quantum.QubitSSAValue) -> int | Dynamic if isinstance(root, quantum.ExtractOp): if (idx_attr := getattr(root, "idx_attr", None)) is not None: - return idx_attr.data + return idx_attr.value.data elif (idx := get_constant_from_ssa(root.idx)) is not None: return idx - return DynamicWire(var=root.idx) + return root.idx - # If the qubit's root is an AllocQubitOp, the qubit does not correspond to a - # qreg index. Else, it originates from a BlockArgument, in which case, its + # If the qubit's root is an ``AllocQubitOp``, the qubit does not correspond to a + # qreg index. Else, it originates from a ``BlockArgument``, in which case, its # corresponding index, if there is one, cannot be inferred. return None + # TODO: Improve docstring def get_active_qubits( self, insertion_point: InsertPoint ) -> dict[WireLabelLike | quantum.QubitSSAValue, quantum.QubitSSAValue | WireLabelLike]: @@ -237,6 +237,30 @@ def get_active_qubits( return qubit_map + def insert_active_qubits( + self, + insertion_point: InsertPoint, + qreg: quantum.QuregSSAValue | None, + active_qubits: dict | None = None, + ) -> quantum.QuregSSAValue: + """Insert all currently active qubits into the quantum register.""" + if active_qubits is None: + active_qubits = self.get_active_qubits(insertion_point) + if qreg is None: + qreg = self.get_qreg(insertion_point) + + for idx, qubit in active_qubits.items(): + if isinstance(idx, (int, DynamicWire)): + if isinstance(idx, DynamicWire) and idx.var is not None: + idx = idx.var + + insertOp = quantum.InsertOp(in_qreg=qreg, idx=idx, qubit=qubit) + self.insert_op(insertOp, insertion_point=insertion_point) + insertion_point = InsertPoint.after(insertOp) + qreg = insertOp.results[0] + + return qreg + def insert_constant(self, cst: Number, insertion_point: InsertPoint) -> xOperation: """Create a scalar ConstantOp and insert it into the IR. @@ -557,6 +581,232 @@ def insert_gate( return gateOp + @singledispatchmethod + def insert_observable( + self, obs: Operator, insertion_point: InsertPoint, params: Sequence[SSAValue] | None = None + ) -> xOperation: + """Insert a PL observable into the IR at the provided insertion point.""" + if isinstance(obs, _named_observables): + in_qubit = obs.wires[0] + namedObsOp = quantum.NamedObsOp( + in_qubit, quantum.NamedObservableAttr(getattr(quantum.NamedObservable, obs.name)) + ) + self.insert_op(namedObsOp, insertion_point=insertion_point) + return namedObsOp + + raise TransformError( + f"The observable {type(obs).__name__} cannot be inserted into the module." + ) + + @insert_observable.register(ops.Hermitian) + def _insert_hermitian( + self, + obs: ops.Hermitian, + insertion_point: InsertPoint, + params: Sequence[SSAValue] | None = None, + ) -> xOperation: + """Insert a qml.Hermitian observable into the IR at the provided insertion point.""" + if params: + mat_op = params[0] + else: + mat = obs.matrix() + mat_attr = builtin.DenseIntOrFPElementsAttr.from_list( + builtin.TensorType( + builtin.ComplexType(builtin.Float64Type()), shape=math.shape(mat) + ), + mat, + ) + mat_op = arith.ConstantOp(value=mat_attr) + self.insert_op(mat_op, insertion_point=insertion_point) + insertion_point = InsertPoint.after(mat_op) + + in_qubits = tuple(obs.wires) + hermitianOp = quantum.HermitianOp( + operands=(mat_op.results[0], in_qubits), result_types=(quantum.ObservableType(),) + ) + self.insert_op(hermitianOp, insertion_point=insertion_point) + return hermitianOp + + @insert_observable.register(ops.Prod) + def _insert_prod( + self, obs: ops.Prod, insertion_point: InsertPoint, params: Sequence[SSAValue] | None = None + ) -> xOperation: + """Insert a qml.ops.Prod observable into the IR at the provided insertion point.""" + operands = [] + for o in obs.operands: + cur_obs = self.insert_observable(o, insertion_point, params=params) + operands.append(cur_obs.results[0]) + insertion_point = InsertPoint.after(cur_obs) + prodOp = quantum.TensorOp(operands=operands, result_types=(quantum.ObservableType(),)) + self.insert_op(prodOp, insertion_point=insertion_point) + return prodOp + + @insert_observable.register(ops.Sum) + def _insert_sum( + self, obs: ops.Sum, insertion_point: InsertPoint, params: Sequence[SSAValue] | None = None + ) -> xOperation: + """Insert a qml.ops.Sum observable into the IR at the provided insertion point.""" + # Create static tensor for coefficients + if params: + coeffs = params[0] + ops = obs.operands + else: + coeffs, ops = obs.terms() + coeffs_attr = builtin.DenseIntOrFPElementsAttr.from_list( + builtin.TensorType(builtin.Float64Type(), shape=(len(coeffs),)), coeffs + ) + constantOp = arith.ConstantOp(value=coeffs_attr) + self.insert_op(constantOp, insertion_point=insertion_point) + insertion_point = InsertPoint.after(constantOp) + coeffs = constantOp.results[0] + + _ops = [] + for o in ops: + cur_obs = self.insert_observable(o, insertion_point, params=params[1:]) + _ops.append(cur_obs.results[0]) + insertion_point = InsertPoint.after(cur_obs) + + hamiltonianOp = quantum.HamiltonianOp( + operands=(coeffs, _ops), result_types=(quantum.ObservableType(),) + ) + return hamiltonianOp + + @insert_observable.register(ops.SProd) + def _insert_sprod( + self, obs: ops.SProd, insertion_point: InsertPoint, params: Sequence[SSAValue] | None = None + ) -> xOperation: + """Insert a qml.ops.SProd observable into the IR at the provided insertion point.""" + if params: + coeff = params[0] + else: + coeffs_attr = builtin.DenseIntOrFPElementsAttr.from_list( + builtin.TensorType(builtin.Float64Type(), shape=(1,)), [obs.scalar] + ) + constantOp = arith.ConstantOp(value=coeffs_attr) + self.insert_op(constantOp, insertion_point=insertion_point) + insertion_point = InsertPoint.after(constantOp) + coeff = constantOp.results[0] + + base_obs = self.insert_observable(obs.base, insertion_point, params=params[1:]) + hamiltonianOp = quantum.HamiltonianOp( + operands=(coeff, (base_obs.results[0],)), result_types=(quantum.ObservableType(),) + ) + self.insert_op(hamiltonianOp, insertion_point=insertion_point) + return hamiltonianOp + + # TODO: Improve docstring + def insert_measurement( + self, mp: MeasurementProcess, insertion_point: InsertPoint + ) -> xOperation: + """Insert a measurement operation into the program.""" + if mp.mv: + raise TransformError( + "Inserting measurements that collect statistics on mid-circuit measurements " + "is currently not supported." + ) + + meas_op = None + qreg = None + + if mp.obs: + # Measurement on observable + obs = self.insert_observable(mp.obs, insertion_point) + n_qubits = len(mp.obs.wires) + + insertion_point = InsertPoint.after(obs) + + if mp.wires: + # Measurement on some wires + obs = quantum.ComputationalBasisOp( + operands=(tuple(mp.wires), None), + result_types=(quantum.ObservableType()), + ) + n_qubits = len(mp.wires) + self.insert_op(obs, insertion_point=insertion_point) + + insertion_point = InsertPoint.after(obs) + + else: + # Measurement on all wires + qreg = self.get_qreg(insertion_point) + obs = quantum.ComputationalBasisOp( + operands=((), qreg), result_types=(quantum.ObservableType(),) + ) + self.insert_op(obs, insertion_point=insertion_point) + insertion_point = InsertPoint.after(obs) + n_qubits = self.get_num_qubits(InsertPoint.before(obs)) + + # # Create the measurement xDSL operation + # match type(mp): + # case measurements.ExpectationMP: + # measurementOp = quantum.ExpvalOp(obs=obs) + + # case measurements.VarianceMP: + # measurementOp = quantum.VarianceOp(obs=obs) + + # case measurements.StateMP: + # # For now, we assume that there is no input MemRefType + # tensor_shape, dynamic_shape, insertion_point = self._resolve_dynamic_shape( + # n_qubits, insertion_point + # ) + # measurementOp = quantum.StateOp( + # operands=(obs, dynamic_shape, None), + # result_types=( + # builtin.TensorType( + # element_type=builtin.ComplexType(builtin.Float64Type()), + # shape=tensor_shape, + # ) + # ), + # ) + + # case measurements.ProbabilityMP: + # # For now, we assume that there is no input MemRefType + # tensor_shape, dynamic_shape, insertion_point = self._resolve_dynamic_shape( + # n_qubits, insertion_point + # ) + # measurementOp = quantum.ProbsOp( + # operands=(obs, dynamic_shape, None, None), + # result_types=( + # builtin.TensorType(element_type=builtin.Float64Type(), shape=tensor_shape), + # ), + # ) + + # case measurements.SampleMP: + # # n_qubits or shots may not be known at compile time + # _iter = (self.wire_manager.shots, n_qubits) + # tensor_shape = tuple(-1 if isinstance(i, SSAValue) else i for i in _iter) + # # We only insert values into dynamic_shape that are unknown at compile time + # dynamic_shape = tuple(i for i in _iter if isinstance(i, SSAValue)) + # if isinstance(n_qubits, int) and n_qubits == 1: + # tensor_shape = (tensor_shape[0],) + + # measurementOp = quantum.SampleOp( + # operands=(obs, dynamic_shape, None), + # result_types=( + # builtin.TensorType(element_type=builtin.Float64Type(), shape=tensor_shape), + # ), + # ) + + # case measurements.CountsMP: + # # For now, we assume that there are no input MemRefTypes + # tensor_shape, dynamic_shape, insertion_point = self._resolve_dynamic_shape( + # n_qubits, insertion_point + # ) + # measurementOp = quantum.CountsOp( + # operands=(obs, dynamic_shape, None, None), + # result_types=( + # builtin.TensorType(element_type=builtin.Float64Type(), shape=tensor_shape), + # builtin.TensorType(element_type=builtin.i64, shape=tensor_shape), + # ), + # ) + + # case _: + # raise TransformError( + # f"The measurement {type(mp).__name__} cannot be supported into the module." + # ) + + self.insert_op(meas_op, insertion_point=insertion_point) + # TODO: Finish implementation def swap_gates(self, op1: xOperation, op2: xOperation) -> None: """Swap two operations in the IR. From 50d36d3ed121da717cca98c028450603aa8a11cf Mon Sep 17 00:00:00 2001 From: Mudit Pandey Date: Mon, 1 Dec 2025 17:14:46 -0500 Subject: [PATCH 10/10] Finish insert_measurement --- .../pass_api/pattern_rewriter.py | 182 +++++++++++------- 1 file changed, 109 insertions(+), 73 deletions(-) diff --git a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py index 520d21c263..f3f0da0513 100644 --- a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py +++ b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py @@ -694,9 +694,36 @@ def _insert_sprod( self.insert_op(hamiltonianOp, insertion_point=insertion_point) return hamiltonianOp + def _resolve_dynamic_shape( + self, n_qubits: SSAValue | int, insertion_point: InsertPoint + ) -> tuple[tuple[int], tuple[SSAValue] | None, InsertPoint]: + """Get dynamic shape and output tensor shape for dynamic shape for observables + when number of qubits is not known at compile time.""" + + if isinstance(n_qubits, SSAValue): + # If number of qubits is not known, we indicate the dynamic shape, and + # set the shape of the resulting tensor to (-1,), indicating that the shape is + # not known at compile time. + tensor_shape = (-1,) + # Create dynamic shape, which is (2**num_qubits,) or (1 << num_qubits,) + const1Op = self.create_scalar_constant(1, insertion_point=insertion_point) + leftShiftOp = arith.ShLIOp(const1Op, n_qubits) + self.insert_op(leftShiftOp, insertion_point=InsertPoint.after(const1Op)) + insertion_point = InsertPoint.after(leftShiftOp) + dynamic_shape = (leftShiftOp.results[0],) + else: + tensor_shape = (2**n_qubits,) + dynamic_shape = None + + return tensor_shape, dynamic_shape, insertion_point + # TODO: Improve docstring def insert_measurement( - self, mp: MeasurementProcess, insertion_point: InsertPoint + self, + mp: MeasurementProcess, + insertion_point: InsertPoint, + params: Sequence[SSAValue] | None = None, + insert_active_qubits: bool = False, ) -> xOperation: """Insert a measurement operation into the program.""" if mp.mv: @@ -705,12 +732,12 @@ def insert_measurement( "is currently not supported." ) - meas_op = None + measurementOp = None qreg = None if mp.obs: # Measurement on observable - obs = self.insert_observable(mp.obs, insertion_point) + obs = self.insert_observable(mp.obs, insertion_point, params=params) n_qubits = len(mp.obs.wires) insertion_point = InsertPoint.after(obs) @@ -729,6 +756,14 @@ def insert_measurement( else: # Measurement on all wires qreg = self.get_qreg(insertion_point) + if insert_active_qubits: + new_qreg = self.insert_active_qubits(insertion_point, qreg=qreg) + # If a new qreg is returned, that means that quantum.InsertOps were inserted + # into the program. So, we need to update the insertion point. + if new_qreg != qreg: + insertion_point = InsertPoint.after(new_qreg.owner) + qreg = new_qreg + obs = quantum.ComputationalBasisOp( operands=((), qreg), result_types=(quantum.ObservableType(),) ) @@ -736,76 +771,77 @@ def insert_measurement( insertion_point = InsertPoint.after(obs) n_qubits = self.get_num_qubits(InsertPoint.before(obs)) - # # Create the measurement xDSL operation - # match type(mp): - # case measurements.ExpectationMP: - # measurementOp = quantum.ExpvalOp(obs=obs) - - # case measurements.VarianceMP: - # measurementOp = quantum.VarianceOp(obs=obs) - - # case measurements.StateMP: - # # For now, we assume that there is no input MemRefType - # tensor_shape, dynamic_shape, insertion_point = self._resolve_dynamic_shape( - # n_qubits, insertion_point - # ) - # measurementOp = quantum.StateOp( - # operands=(obs, dynamic_shape, None), - # result_types=( - # builtin.TensorType( - # element_type=builtin.ComplexType(builtin.Float64Type()), - # shape=tensor_shape, - # ) - # ), - # ) - - # case measurements.ProbabilityMP: - # # For now, we assume that there is no input MemRefType - # tensor_shape, dynamic_shape, insertion_point = self._resolve_dynamic_shape( - # n_qubits, insertion_point - # ) - # measurementOp = quantum.ProbsOp( - # operands=(obs, dynamic_shape, None, None), - # result_types=( - # builtin.TensorType(element_type=builtin.Float64Type(), shape=tensor_shape), - # ), - # ) - - # case measurements.SampleMP: - # # n_qubits or shots may not be known at compile time - # _iter = (self.wire_manager.shots, n_qubits) - # tensor_shape = tuple(-1 if isinstance(i, SSAValue) else i for i in _iter) - # # We only insert values into dynamic_shape that are unknown at compile time - # dynamic_shape = tuple(i for i in _iter if isinstance(i, SSAValue)) - # if isinstance(n_qubits, int) and n_qubits == 1: - # tensor_shape = (tensor_shape[0],) - - # measurementOp = quantum.SampleOp( - # operands=(obs, dynamic_shape, None), - # result_types=( - # builtin.TensorType(element_type=builtin.Float64Type(), shape=tensor_shape), - # ), - # ) - - # case measurements.CountsMP: - # # For now, we assume that there are no input MemRefTypes - # tensor_shape, dynamic_shape, insertion_point = self._resolve_dynamic_shape( - # n_qubits, insertion_point - # ) - # measurementOp = quantum.CountsOp( - # operands=(obs, dynamic_shape, None, None), - # result_types=( - # builtin.TensorType(element_type=builtin.Float64Type(), shape=tensor_shape), - # builtin.TensorType(element_type=builtin.i64, shape=tensor_shape), - # ), - # ) - - # case _: - # raise TransformError( - # f"The measurement {type(mp).__name__} cannot be supported into the module." - # ) - - self.insert_op(meas_op, insertion_point=insertion_point) + # Create the measurement xDSL operation + match type(mp): + case measurements.ExpectationMP: + measurementOp = quantum.ExpvalOp(obs=obs) + + case measurements.VarianceMP: + measurementOp = quantum.VarianceOp(obs=obs) + + case measurements.StateMP: + # For now, we assume that there is no input MemRefType + tensor_shape, dynamic_shape, insertion_point = self._resolve_dynamic_shape( + n_qubits, insertion_point + ) + measurementOp = quantum.StateOp( + operands=(obs, dynamic_shape, None), + result_types=( + builtin.TensorType( + element_type=builtin.ComplexType(builtin.Float64Type()), + shape=tensor_shape, + ) + ), + ) + + case measurements.ProbabilityMP: + # For now, we assume that there is no input MemRefType + tensor_shape, dynamic_shape, insertion_point = self._resolve_dynamic_shape( + n_qubits, insertion_point + ) + measurementOp = quantum.ProbsOp( + operands=(obs, dynamic_shape, None, None), + result_types=( + builtin.TensorType(element_type=builtin.Float64Type(), shape=tensor_shape), + ), + ) + + case measurements.SampleMP: + # n_qubits or shots may not be known at compile time + _iter = (self.wire_manager.shots, n_qubits) + tensor_shape = tuple(-1 if isinstance(i, SSAValue) else i for i in _iter) + # We only insert values into dynamic_shape that are unknown at compile time + dynamic_shape = tuple(i for i in _iter if isinstance(i, SSAValue)) + if isinstance(n_qubits, int) and n_qubits == 1: + tensor_shape = (tensor_shape[0],) + + measurementOp = quantum.SampleOp( + operands=(obs, dynamic_shape, None), + result_types=( + builtin.TensorType(element_type=builtin.Float64Type(), shape=tensor_shape), + ), + ) + + case measurements.CountsMP: + # For now, we assume that there are no input MemRefTypes + tensor_shape, dynamic_shape, insertion_point = self._resolve_dynamic_shape( + n_qubits, insertion_point + ) + measurementOp = quantum.CountsOp( + operands=(obs, dynamic_shape, None, None), + result_types=( + builtin.TensorType(element_type=builtin.Float64Type(), shape=tensor_shape), + builtin.TensorType(element_type=builtin.i64, shape=tensor_shape), + ), + ) + + case _: + raise TransformError( + f"The measurement {type(mp).__name__} cannot be supported into the module." + ) + + self.insert_op(measurementOp, insertion_point=insertion_point) + return measurementOp # TODO: Finish implementation def swap_gates(self, op1: xOperation, op2: xOperation) -> None: