diff --git a/frontend/catalyst/python_interface/pass_api/cancel_inverses_prototype.py b/frontend/catalyst/python_interface/pass_api/cancel_inverses_prototype.py index 1983fd384c..38cbee766e 100644 --- a/frontend/catalyst/python_interface/pass_api/cancel_inverses_prototype.py +++ b/frontend/catalyst/python_interface/pass_api/cancel_inverses_prototype.py @@ -20,6 +20,7 @@ from .compiler_transform import compiler_transform from .passes import PLModulePass +from .pattern_rewriter import PLPatternRewriter class CancelInverses(PLModulePass): @@ -58,7 +59,7 @@ def can_cancel(op: quantum.CustomOp, next_op: Operation) -> bool: @CancelInverses.rewrite_rule(quantum.CustomOp) -def rewrite_custom_op(self, op, rewriter): +def rewrite_custom_op(self, op: quantum.CustomOp, rewriter: PLPatternRewriter): """Rewrite rule for CustomOp.""" while isinstance(op, quantum.CustomOp) and op.gate_name.data in self.self_inverses: next_user = None @@ -71,31 +72,27 @@ def rewrite_custom_op(self, op, rewriter): if next_user is None: break - for q1, q2 in zip(op.in_qubits, next_user.out_qubits, strict=True): - rewriter.replace_all_uses_with(q2, q1) - for cq1, cq2 in zip(op.in_ctrl_qubits, next_user.out_ctrl_qubits, strict=True): - rewriter.replace_all_uses_with(cq2, cq1) - rewriter.erase_op(next_user) - rewriter.erase_op(op) + rewriter.erase_gate(next_user) + rewriter.erase_gate(op) op = op.in_qubits[0].owner # We can register more rewrite rules as needed. Here are some # dummy rewrite rules to illustrate: @CancelInverses.rewrite_rule(quantum.InsertOp) -def rewrite_insert_op(self, op, rewriter): +def rewrite_insert_op(self, op: quantum.InsertOp, rewriter: PLPatternRewriter): """Rewrite rule for InsertOp.""" return @CancelInverses.rewrite_rule(quantum.ExtractOp) -def rewrite_extract_op(self, op, rewriter): +def rewrite_extract_op(self, op: quantum.ExtractOp, rewriter: PLPatternRewriter): """Rewrite rule for ExtractOp.""" return @CancelInverses.rewrite_rule(quantum.MeasureOp) -def rewrite_mid_measure_op(self, op, rewriter): +def rewrite_mid_measure_op(self, op: quantum.MeasureOp, rewriter: PLPatternRewriter): """Rewrite rule for MeasureOp.""" return diff --git a/frontend/catalyst/python_interface/pass_api/passes.py b/frontend/catalyst/python_interface/pass_api/passes.py index dee22e3022..4fd0c1e73f 100644 --- a/frontend/catalyst/python_interface/pass_api/passes.py +++ b/frontend/catalyst/python_interface/pass_api/passes.py @@ -24,12 +24,12 @@ from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, - PatternRewriter, - PatternRewriteWalker, RewritePattern, op_type_rewrite_pattern, ) +from .pattern_rewriter import PLPatternRewriter, PLPatternRewriteWalker + def _update_type_hints(hint: type[Operation] | type[Operation]) -> Callable: """Update the signature of a ``match_and_rewrite`` method to use the provided operation @@ -77,7 +77,7 @@ def __init__(self, _pass): @op_type_rewrite_pattern @_update_type_hints(hint) - def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter) -> None: + def match_and_rewrite(self, op: Operation, rewriter: PLPatternRewriter) -> None: rewrite_rule(self._pass, op, rewriter) return _RewritePattern @@ -120,7 +120,7 @@ def rewrite_rule( .. code-block:: python @PLModulePass.rewrite_rule(MyOperation) - def rewrite_myop(self, op: MyOperation, rewriter: PatternRewriter) -> None: + def rewrite_myop(self, op: MyOperation, rewriter: PLPatternRewriter) -> None: ... .. note:: @@ -135,7 +135,7 @@ def rewrite_myop(self, op: MyOperation, rewriter: PatternRewriter) -> None: Callable: a decorator to register the rewrite rule with the ModulePass """ - def decorator(rule: Callable[[Operation, PatternRewriter], None]) -> Callable: + def decorator(rule: Callable[[Operation, PLPatternRewriter], None]) -> Callable: rewrite_pattern = _create_rewrite_pattern(hint, rule) cls._rewrite_patterns[hint] = rewrite_pattern return rule @@ -170,10 +170,10 @@ def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: # pylint: disable= pattern = GreedyRewritePatternApplier( rewrite_patterns=[rp(self) for rp in self._rewrite_patterns.values()] ) - walker = PatternRewriteWalker(pattern=pattern, apply_recursively=self.recursive) + walker = PLPatternRewriteWalker(pattern=pattern, apply_recursively=self.recursive) walker.rewrite_module(op) else: for rp in self._rewrite_patterns.values(): - walker = PatternRewriteWalker(pattern=rp(self), apply_recursively=self.recursive) + walker = PLPatternRewriteWalker(pattern=rp(self), apply_recursively=self.recursive) walker.rewrite_module(op) diff --git a/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py new file mode 100644 index 0000000000..f3f0da0513 --- /dev/null +++ b/frontend/catalyst/python_interface/pass_api/pattern_rewriter.py @@ -0,0 +1,936 @@ +# Copyright 2025 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Pattern rewriter API for quantum compilation passes.""" + +from collections.abc import Sequence +from dataclasses import dataclass, field +from functools import singledispatchmethod +from numbers import Number +from typing import TypeAlias +from uuid import UUID, uuid4 + +from pennylane import math, measurements, ops +from pennylane.exceptions import TransformError +from pennylane.measurements import MeasurementProcess +from pennylane.operation import Operator +from xdsl.builder import ImplicitBuilder +from xdsl.dialects import arith, builtin, func, scf, tensor +from xdsl.ir import Block, BlockArgument +from xdsl.ir import Operation as xOperation +from xdsl.ir import Region, SSAValue +from xdsl.pattern_rewriter import PatternRewriter, PatternRewriterListener, PatternRewriteWalker +from xdsl.rewriter import InsertPoint + +from catalyst.python_interface.dialects import quantum +from catalyst.python_interface.utils import get_constant_from_ssa + +_named_observables = (ops.PauliX, ops.PauliY, ops.PauliZ, ops.Identity, ops.Hadamard) +_gate_like_ops = ( + quantum.CustomOp, + quantum.GlobalPhaseOp, + quantum.MultiRZOp, + quantum.PCPhaseOp, + quantum.QubitUnitaryOp, +) + + +@dataclass +class DynamicWire: + """Dynamic wire label wrapper.""" + + var: SSAValue[builtin.i64] | None = None + """SSA value representing the wire's dynamic index into the quantum register.""" + + _id: UUID = field(default_factory=uuid4, init=False) + """Unique ID representing the qubit. This is only used if the wire was allocated dynamically + (``quantum.AllocQubitOp``) or the qubit originated as a ``BlockArgument``.""" + + def __hash__(self) -> int: + if self.var: + return hash(self.var) + + return hash(self._id) + + def __eq__(self, other: "DynamicWire") -> bool: + if self.var and other.var: + return self.var == other.var + + return self._id == other._id + + +WireLabelLike: TypeAlias = int | DynamicWire + + +class PLPatternRewriter(PatternRewriter): + """A ``PatternRewriter`` with abstractions for quantum compilation passes. + + This is a subclass of ``xdsl.pattern_rewriter.PatternRewriter`` that exposes + methods to abstract away low-level pattern-rewriting details relevant to + quantum compilation passes. + """ + + def __init__(self, current_operation: xOperation): + super().__init__(current_operation) + + def get_qnode( + self, start_op: xOperation | None = None, get_func: bool = False + ) -> builtin.ModuleOp | func.FuncOp: + """Get the module corresponding to the QNode containing the given operation. + + The input operation, or the operation used to initialize the rewriter if an operation + is not provided, are used to search outer scopes until the module corresponding + to a QNode is found. + + .. note:: + + This method assumes that the module corresponding to a QNode will not contain any + modules in its body. + + Args: + start_op (xdsl.ir.Operation): The operation used to begin the search. If ``None``, + the operation used to initialize the rewriter will be used for the search. + get_func (bool): If ``True``, the FuncOp corresponding to the QNode will be returned + instead of the module. ``False`` by default. + + Returns: + ModuleOp | FuncOp: The QNode module surrounding the current operation, or the QNode + function if ``get_func`` is ``True``. + """ + current_op: xOperation = start_op or self.current_operation + while not isinstance(current_op, builtin.ModuleOp): + current_op = current_op.parent_op() + + qnode_func = None + for op in current_op.body.ops(): + if isinstance(op, func.FuncOp) and op.attributes.get("qnode", None): + qnode_func = op + break + + if qnode_func is None: + raise TransformError(f"{current_op} is not inside a QNode's scope.") + + return qnode_func if get_func else current_op + + def get_qreg(self, insertion_point: InsertPoint) -> quantum.QuregSSAValue | None: + """Get the active quantum register at a given point in the program. Returns ``None`` + If there are no active quantum registers. Currently, we assume that there can only be + at most one active quantum register at any given point in the program.""" + cur_op = insertion_point.insert_before.prev_op + qreg = None + + while not (cur_op is None or isinstance(cur_op, quantum.DeallocOp)): + for r in cur_op.results: + if isinstance(r.type, quantum.QuregType): + qreg = r + break + + return qreg + + def _find_qubit_root( + self, qubit: quantum.QubitSSAValue + ) -> tuple[quantum.ExtractOp | quantum.AllocQubitOp | Block, list[quantum.QubitSSAValue]]: + """Find the root operation that created a qubit.""" + root = None + owner = qubit.owner + seen_qubits = set() + + while True: + seen_qubits.add(qubit) + + if isinstance(owner, (Block, quantum.AllocQubitOp, quantum.ExtractOp)): + root = owner + break + + if isinstance(owner, _gate_like_ops): + qb_idx = qubit.index + n_classical_operands = ( + len(owner.operands) + - len(owner.in_ctrl_qubits) + - len(getattr(owner, "in_qubits", [])) + ) + qubit = owner.operands[n_classical_operands + qb_idx] + owner = qubit.owner + + else: + raise ValueError(f"Cannot handle {owner} type yet for finding a qubit's root.") + + return root, seen_qubits + + # TODO: Improve docstring + def get_idx_from_qubit(self, qubit: quantum.QubitSSAValue) -> int | SSAValue | None: + """Get the index to the quantum register to which a qubit corresponds. + + If the index is unknown at compile-time, an SSA value corresponding to the index + will be returned. If the qubit was dynamically allocated, or originated from a + block argument, ``None`` will be returned. + """ + if not isinstance(qubit.type, quantum.QubitType): + raise TypeError(f"The input must be a QubitType SSAValue, got {qubit}.") + + root, _ = self._find_qubit_root(qubit) + + if isinstance(root, quantum.ExtractOp): + if (idx_attr := getattr(root, "idx_attr", None)) is not None: + return idx_attr.value.data + elif (idx := get_constant_from_ssa(root.idx)) is not None: + return idx + + return root.idx + + # If the qubit's root is an ``AllocQubitOp``, the qubit does not correspond to a + # qreg index. Else, it originates from a ``BlockArgument``, in which case, its + # corresponding index, if there is one, cannot be inferred. + return None + + # TODO: Improve docstring + def get_active_qubits( + self, insertion_point: InsertPoint + ) -> dict[WireLabelLike | quantum.QubitSSAValue, quantum.QubitSSAValue | WireLabelLike]: + """Get a bidirectional map from wire indices to active qubits corresponding to + those wire labels and vice versa.""" + # Starting constraint: deal with flat regions + qubit_map: dict[int | DynamicWire, quantum.QubitSSAValue] = {} + seen_qubits: set[quantum.QubitSSAValue] = set() + cur_op: xOperation = insertion_point.insert_before + + while not (isinstance(cur_op.prev_op, quantum.AllocOp) or cur_op.prev_op is None): + cur_op = cur_op.prev_op + + if isinstance(cur_op, quantum.InsertOp): + seen_qubits |= self._find_qubit_root(cur_op.qubit)[1] + continue + + for r in cur_op.results: + if not isinstance(r.type, quantum.QubitType) or r in seen_qubits: + continue + + root, cur_seen_qubits = self._find_qubit_root(r) + seen_qubits |= cur_seen_qubits + insert_idx = None + + if isinstance(root, quantum.ExtractOp): + if (idx_attr := getattr(root, "idx_attr", None)) is not None: + insert_idx = idx_attr.value.data + elif (idx := get_constant_from_ssa(root.idx)) is not None: + insert_idx = idx + else: + insert_idx = DynamicWire(var=root.idx) + + # Qubit's root is either AllocQubitOp or a Block. Either way, + # the qubit does not correspond to an index. + else: + insert_idx = DynamicWire() + + qubit_map[insert_idx] = r + qubit_map[r] = insert_idx + + return qubit_map + + def insert_active_qubits( + self, + insertion_point: InsertPoint, + qreg: quantum.QuregSSAValue | None, + active_qubits: dict | None = None, + ) -> quantum.QuregSSAValue: + """Insert all currently active qubits into the quantum register.""" + if active_qubits is None: + active_qubits = self.get_active_qubits(insertion_point) + if qreg is None: + qreg = self.get_qreg(insertion_point) + + for idx, qubit in active_qubits.items(): + if isinstance(idx, (int, DynamicWire)): + if isinstance(idx, DynamicWire) and idx.var is not None: + idx = idx.var + + insertOp = quantum.InsertOp(in_qreg=qreg, idx=idx, qubit=qubit) + self.insert_op(insertOp, insertion_point=insertion_point) + insertion_point = InsertPoint.after(insertOp) + qreg = insertOp.results[0] + + return qreg + + def insert_constant(self, cst: Number, insertion_point: InsertPoint) -> xOperation: + """Create a scalar ConstantOp and insert it into the IR. + + Args: + cst (Number): The scalar to insert into the IR + insertion_point (InsertPoint): The point in the IR where the ``ConstantOp`` + that creates the constant SSAValue should be inserted + + Returns: + SSAValue: The SSA value corresponding to the constant + """ + data = [cst] + match cst: + case int(): + elem_type = builtin.IntegerType(64) + case float(): + elem_type = builtin.Float64Type() + case cst, bool(): + elem_type = builtin.IntegerType(1) + case complex(): + elem_type = builtin.ComplexType() + data = [[cst.real, cst.imag]] + case _: + raise TypeError(f"{cst} is not a valid type to insert as a constant.") + + type_ = builtin.TensorType(element_type=elem_type, shape=[]) + constAttr = builtin.DenseIntOrFPElementsAttr.from_list(type_, data) + constantOp = arith.ConstantOp(constAttr) + extractOp = tensor.ExtractOp(tensor=constantOp.result, indices=[], result_type=elem_type) + + self.insert_op(constantOp, insertion_point) + self.insert_op(extractOp, InsertPoint.after(constantOp)) + + return extractOp + + def get_num_qubits(self, insertion_point: InsertPoint) -> SSAValue[builtin.I64]: + """Get the number of qubits. + + All qubits available in the QNode, whether statically or dynamically allocated, + at a given point in the QNode are returned as a 64-bit integer SSAValue. + + Args: + insertion_point (InsertPoint): The point in the IR where the instruction + that gets the number of qubits should be inserted. This is necessary + because the number of qubits in a program can be different at different + points in the program when dynamically allocated qubits are present. + + Returns: + SSAValue[I64]: A 64-bit integer SSAValue corresponding to the number of allocated + qubits. + """ + numQubitsOp = quantum.NumQubitsOp() + self.insert_op(numQubitsOp, insertion_point=insertion_point) + return numQubitsOp.results[0] + + def get_num_shots(self, as_literal: bool = False) -> SSAValue[builtin.I64] | int | None: + """Get the number of shots. + + Args: + as_literal (bool): If ``True``, the shots will be returned as a Python + integer. If the shots are dynamic, the returned value will be -1. + If ``False``, an int64 ``SSAValue`` corresponding to the number of shots + will be returned. False by default + + Returns: + SSAValue[I64] | int | None: ``int`` if ``is_literal`` is ``True``, else an int64 + ``SSAValue``. If the execution is analytic, ``None`` will be returned. + """ + try: + qnode: func.FuncOp = self.get_qnode(get_func=True) + except TransformError as e: + raise TransformError( + "Cannot get the number of shots when rewriting an operation outside the " + "scope of a QNode." + ) from e + + # The qnode function always initializes a quantum device using the quantum.DeviceInitOp + # operation. + device_init = None + for op in qnode.body.ops: + if isinstance(op, quantum.DeviceInitOp): + device_init = op + break + + assert device_init is not None + + # If the device is **known** to be analytic, it will not have any operands. Note that even + # if the DeviceInitOp has shots as its operand, it may be analytic if the shots operand is + # a constant == 0. + if len(device_init.operands) == 0: + return None + + shots = device_init.operands[0] + if not as_literal: + return shots + + # If shots are dynamic, they will **always** be an argument to the QNode, else + # they will be static, and created using a constant-like operation. + if isinstance(shots, BlockArgument): + return -1 + + shots = get_constant_from_ssa(shots) + return None if shots == 0 else shots + + def erase_gate(self, op: xOperation) -> None: + """Erase a quantum gate. + + Safely erase a quantum gate from the module being transformed. This method automatically + handles and pre-processing required before safely erasing an operation. To erase quantum + gates, which include ``CustomOp``, ``GlobalPhaseOp``, ``MultiRZOp``, ``PCPhaseOp``, and + ``QubitUnitaryOp``, it is recommended to use this method instead of ``erase_op``. + + Args: + op (xdsl.ir.Operation): The operation to erase + """ + if not isinstance(op, _gate_like_ops): + raise TypeError( + f"Cannot erase {op}. 'PLPatternRewriter.erase_op' can only erase " + "gate-like operations." + ) + + # GlobalPhaseOp does not have any target qubits + in_qubits = ( + op.in_ctrl_qubits + if isinstance(op, quantum.GlobalPhaseOp) + else (op.in_qubits + op.in_ctrl_qubits) + ) + self.replace_op(op, (), in_qubits) + + def insert_mid_measure( + self, mcm: measurements.MidMeasureMP, insertion_point: InsertPoint + ) -> SSAValue[quantum.QubitType]: + """Insert a PL mid-circuit measurement into the IR at the provided insertion point. + + Args: + mcm (pennylane.ops.MidMeasureMP): The mid-circuit measurement to insert into + the IR. Note that the measurement qubit must be an SSAValue. + insertion_point (InsertPoint): The point in the IR where the operation must + be inserted. + + Returns: + xdsl.ir.SSAValue[quantum.QubitType]: The qubit returned by the mid-circuit measurement. + """ + in_qubit: SSAValue[quantum.QubitType] = mcm.wires[0] + midMeasureOp = quantum.MeasureOp(in_qubit=in_qubit, postselect=mcm.postselect) + self.insert_op(midMeasureOp, insertion_point=insertion_point) + out_qubit = midMeasureOp.out_qubit + + # If resetting, we need to insert a conditional statement that applies a PauliX + # if we measured |1>. The else block just yields a qubit. + if mcm.reset: + true_region = Region() + with ImplicitBuilder(true_region): + gate = quantum.CustomOp(gate_name="PauliX", in_qubits=(midMeasureOp.out_qubit,)) + _ = scf.YieldOp(gate.out_qubits[0]) + + false_region = Region() + with ImplicitBuilder(false_region): + _ = scf.YieldOp(midMeasureOp.out_qubit) + + ifOp = scf.IfOp( + cond=midMeasureOp.mres, + return_types=(quantum.QubitType(),), + true_region=true_region, + false_region=false_region, + ) + self.insert_op(ifOp, InsertPoint.after(midMeasureOp)) + out_qubit = ifOp.results[0] + + mcm_qubit_uses = [use for use in in_qubit.uses if use.operation != midMeasureOp] + in_qubit.replace_by_if(out_qubit, lambda use: use.operation != midMeasureOp) + for use in mcm_qubit_uses: + self.notify_op_modified(use.operation) + + return out_qubit + + @staticmethod + def _get_gate_base(gate: Operator): + """Get the base op of a gate.""" + + if isinstance(gate, ops.Controlled): + base_gate, ctrl_wires, ctrl_vals, adjoint = PLPatternRewriter._get_gate_base(gate.base) + ctrl_wires = tuple(gate.control_wires) + tuple(ctrl_wires) + ctrl_vals = tuple(gate.control_values) + tuple(ctrl_vals) + return base_gate, ctrl_wires, ctrl_vals, adjoint + + if isinstance(gate, ops.Adjoint): + base_gate, ctrl_wires, ctrl_vals, adjoint = PLPatternRewriter._get_gate_base(gate.base) + adjoint = adjoint ^ True + return base_gate, tuple(ctrl_wires), tuple(ctrl_vals), adjoint + + return gate, (), (), False + + # TODO: fix too-many-statements warning + # pylint: disable=too-many-statements, too-many-branches + def insert_gate( + self, gate: Operator, insertion_point: InsertPoint, params: Sequence[SSAValue] | None = None + ) -> xOperation: + r"""Insert a PL gate into the IR at the provided insertion point. + + .. note:: + + Inserting state-preparation operations is currently not supported. + + Args: + gate (~pennylane.operation.Operator): The gate to insert. The wires of the gate must be + ``QubitType`` ``SSAValue``\ s. + insertion_point (InsertPoint): The point where the operation should be inserted. + params (Sequence[SSAValue] | None): For parametric gates, the list of ``SSAValue``\ s that + should be used as the gate's operands. If not provided, the parameters to ``gate`` will + be inserted as constants into the program. + + Returns: + xdsl.ir.Operation: The xDSL operation corresponding to the gate being inserted. + """ + # TODO: Add support for StatePrep, BasisState + gate, ctrl_wires, ctrl_vals, adjoint = self._get_gate_base(gate) + op_args = {} + + # If the gate is a QubitUnitary and an SSA tensor is not provided as its matrix, then + # we need to create a constant matrix SSAValue using the gate's matrix. + if isinstance(gate, ops.QubitUnitary) and not params: + mat = gate.matrix() + mat_attr = builtin.DenseIntOrFPElementsAttr.from_list( + builtin.TensorType( + builtin.ComplexType(builtin.Float64Type()), shape=math.shape(mat) + ), + mat, + ) + constantOp = arith.ConstantOp(value=mat_attr) + self.insert_op(constantOp, insertion_point=insertion_point) + insertion_point = InsertPoint.after(constantOp) + params = [constantOp.results[0]] + + # Create static parameters + elif not params: + params = [] + for d in gate.data: + try: + d = float(d) + except ValueError as e: + raise TransformError( + "Only values that can be cast into floats can be used as gate " + f"parameters. Got {d}." + ) from e + constOp = self.insert_constant(d, insertion_point) + params.append(constOp.results[0]) + insertion_point = InsertPoint.after(constOp) + + # PCPhase has a `dim` hyperparameter which also needs to be inserted into the IR. + if isinstance(gate, ops.PCPhase): + constOp = self.insert_constant( + float(gate.hyperparameters["dimension"][0]), insertion_point + ) + params.append(constOp.results[0]) + insertion_point = InsertPoint.after(constOp) + + # Different gate types may be represented in MLIR by different operations, which may + # take slightly different arguments + match type(gate): + case ops.GlobalPhase: + op_class = quantum.GlobalPhaseOp + assert len(params) == 1 + op_args["params"] = params[0] + case ops.MultiRZ: + op_class = quantum.MultiRZOp + assert len(params) == 1 + op_args["theta"] = params[0] + case ops.QubitUnitary: + op_class = quantum.QubitUnitaryOp + assert len(params) == 1 + op_args["matrix"] = params[0] + case ops.PCPhase: + op_class = quantum.PCPhaseOp + op_args["theta"] = params[0] + op_args["dim"] = params[1] + case _: + op_class = quantum.CustomOp + op_args["gate_name"] = gate.name + op_args["params"] = params + + # Add qubits/control qubits to args. GlobalPhaseOp does not take qubits, only + # control qubits + if not isinstance(gate, ops.GlobalPhase): + op_args["in_qubits"] = tuple(gate.wires) + op_args["in_ctrl_qubits"] = tuple(ctrl_wires) if ctrl_wires else None + in_ctrl_values = None + + # Add ctrl values to args + if ctrl_vals: + true_cst = None + false_cst = None + if any(ctrl_vals): + true_cst = self.insert_constant(True, insertion_point=insertion_point) + insertion_point = InsertPoint.after(true_cst) + if not all(ctrl_vals): + false_cst = self.insert_constant(False, insertion_point=insertion_point) + insertion_point = InsertPoint.after(false_cst) + in_ctrl_values = tuple( + true_cst.results[0] if v else false_cst.results[0] for v in ctrl_vals + ) + op_args["in_ctrl_values"] = in_ctrl_values + op_args["adjoint"] = adjoint + + gateOp = op_class(**op_args) + self.insert_op(gateOp, insertion_point=insertion_point) + + # Use getattr for in/out_qubits because GlobalPhaseOp does not have in/out_qubits + for iq, oq in zip( + getattr(gateOp, "in_qubits", ()) + tuple(gateOp.in_ctrl_qubits), + getattr(gateOp, "out_qubits", ()) + tuple(gateOp.out_ctrl_qubits), + strict=True, + ): + in_qubit_uses = [use for use in iq.uses if use.operation != gateOp] + iq.replace_by_if(oq, lambda use: use.operation != gateOp) + for use in in_qubit_uses: + self.notify_op_modified(use.operation) + + return gateOp + + @singledispatchmethod + def insert_observable( + self, obs: Operator, insertion_point: InsertPoint, params: Sequence[SSAValue] | None = None + ) -> xOperation: + """Insert a PL observable into the IR at the provided insertion point.""" + if isinstance(obs, _named_observables): + in_qubit = obs.wires[0] + namedObsOp = quantum.NamedObsOp( + in_qubit, quantum.NamedObservableAttr(getattr(quantum.NamedObservable, obs.name)) + ) + self.insert_op(namedObsOp, insertion_point=insertion_point) + return namedObsOp + + raise TransformError( + f"The observable {type(obs).__name__} cannot be inserted into the module." + ) + + @insert_observable.register(ops.Hermitian) + def _insert_hermitian( + self, + obs: ops.Hermitian, + insertion_point: InsertPoint, + params: Sequence[SSAValue] | None = None, + ) -> xOperation: + """Insert a qml.Hermitian observable into the IR at the provided insertion point.""" + if params: + mat_op = params[0] + else: + mat = obs.matrix() + mat_attr = builtin.DenseIntOrFPElementsAttr.from_list( + builtin.TensorType( + builtin.ComplexType(builtin.Float64Type()), shape=math.shape(mat) + ), + mat, + ) + mat_op = arith.ConstantOp(value=mat_attr) + self.insert_op(mat_op, insertion_point=insertion_point) + insertion_point = InsertPoint.after(mat_op) + + in_qubits = tuple(obs.wires) + hermitianOp = quantum.HermitianOp( + operands=(mat_op.results[0], in_qubits), result_types=(quantum.ObservableType(),) + ) + self.insert_op(hermitianOp, insertion_point=insertion_point) + return hermitianOp + + @insert_observable.register(ops.Prod) + def _insert_prod( + self, obs: ops.Prod, insertion_point: InsertPoint, params: Sequence[SSAValue] | None = None + ) -> xOperation: + """Insert a qml.ops.Prod observable into the IR at the provided insertion point.""" + operands = [] + for o in obs.operands: + cur_obs = self.insert_observable(o, insertion_point, params=params) + operands.append(cur_obs.results[0]) + insertion_point = InsertPoint.after(cur_obs) + prodOp = quantum.TensorOp(operands=operands, result_types=(quantum.ObservableType(),)) + self.insert_op(prodOp, insertion_point=insertion_point) + return prodOp + + @insert_observable.register(ops.Sum) + def _insert_sum( + self, obs: ops.Sum, insertion_point: InsertPoint, params: Sequence[SSAValue] | None = None + ) -> xOperation: + """Insert a qml.ops.Sum observable into the IR at the provided insertion point.""" + # Create static tensor for coefficients + if params: + coeffs = params[0] + ops = obs.operands + else: + coeffs, ops = obs.terms() + coeffs_attr = builtin.DenseIntOrFPElementsAttr.from_list( + builtin.TensorType(builtin.Float64Type(), shape=(len(coeffs),)), coeffs + ) + constantOp = arith.ConstantOp(value=coeffs_attr) + self.insert_op(constantOp, insertion_point=insertion_point) + insertion_point = InsertPoint.after(constantOp) + coeffs = constantOp.results[0] + + _ops = [] + for o in ops: + cur_obs = self.insert_observable(o, insertion_point, params=params[1:]) + _ops.append(cur_obs.results[0]) + insertion_point = InsertPoint.after(cur_obs) + + hamiltonianOp = quantum.HamiltonianOp( + operands=(coeffs, _ops), result_types=(quantum.ObservableType(),) + ) + return hamiltonianOp + + @insert_observable.register(ops.SProd) + def _insert_sprod( + self, obs: ops.SProd, insertion_point: InsertPoint, params: Sequence[SSAValue] | None = None + ) -> xOperation: + """Insert a qml.ops.SProd observable into the IR at the provided insertion point.""" + if params: + coeff = params[0] + else: + coeffs_attr = builtin.DenseIntOrFPElementsAttr.from_list( + builtin.TensorType(builtin.Float64Type(), shape=(1,)), [obs.scalar] + ) + constantOp = arith.ConstantOp(value=coeffs_attr) + self.insert_op(constantOp, insertion_point=insertion_point) + insertion_point = InsertPoint.after(constantOp) + coeff = constantOp.results[0] + + base_obs = self.insert_observable(obs.base, insertion_point, params=params[1:]) + hamiltonianOp = quantum.HamiltonianOp( + operands=(coeff, (base_obs.results[0],)), result_types=(quantum.ObservableType(),) + ) + self.insert_op(hamiltonianOp, insertion_point=insertion_point) + return hamiltonianOp + + def _resolve_dynamic_shape( + self, n_qubits: SSAValue | int, insertion_point: InsertPoint + ) -> tuple[tuple[int], tuple[SSAValue] | None, InsertPoint]: + """Get dynamic shape and output tensor shape for dynamic shape for observables + when number of qubits is not known at compile time.""" + + if isinstance(n_qubits, SSAValue): + # If number of qubits is not known, we indicate the dynamic shape, and + # set the shape of the resulting tensor to (-1,), indicating that the shape is + # not known at compile time. + tensor_shape = (-1,) + # Create dynamic shape, which is (2**num_qubits,) or (1 << num_qubits,) + const1Op = self.create_scalar_constant(1, insertion_point=insertion_point) + leftShiftOp = arith.ShLIOp(const1Op, n_qubits) + self.insert_op(leftShiftOp, insertion_point=InsertPoint.after(const1Op)) + insertion_point = InsertPoint.after(leftShiftOp) + dynamic_shape = (leftShiftOp.results[0],) + else: + tensor_shape = (2**n_qubits,) + dynamic_shape = None + + return tensor_shape, dynamic_shape, insertion_point + + # TODO: Improve docstring + def insert_measurement( + self, + mp: MeasurementProcess, + insertion_point: InsertPoint, + params: Sequence[SSAValue] | None = None, + insert_active_qubits: bool = False, + ) -> xOperation: + """Insert a measurement operation into the program.""" + if mp.mv: + raise TransformError( + "Inserting measurements that collect statistics on mid-circuit measurements " + "is currently not supported." + ) + + measurementOp = None + qreg = None + + if mp.obs: + # Measurement on observable + obs = self.insert_observable(mp.obs, insertion_point, params=params) + n_qubits = len(mp.obs.wires) + + insertion_point = InsertPoint.after(obs) + + if mp.wires: + # Measurement on some wires + obs = quantum.ComputationalBasisOp( + operands=(tuple(mp.wires), None), + result_types=(quantum.ObservableType()), + ) + n_qubits = len(mp.wires) + self.insert_op(obs, insertion_point=insertion_point) + + insertion_point = InsertPoint.after(obs) + + else: + # Measurement on all wires + qreg = self.get_qreg(insertion_point) + if insert_active_qubits: + new_qreg = self.insert_active_qubits(insertion_point, qreg=qreg) + # If a new qreg is returned, that means that quantum.InsertOps were inserted + # into the program. So, we need to update the insertion point. + if new_qreg != qreg: + insertion_point = InsertPoint.after(new_qreg.owner) + qreg = new_qreg + + obs = quantum.ComputationalBasisOp( + operands=((), qreg), result_types=(quantum.ObservableType(),) + ) + self.insert_op(obs, insertion_point=insertion_point) + insertion_point = InsertPoint.after(obs) + n_qubits = self.get_num_qubits(InsertPoint.before(obs)) + + # Create the measurement xDSL operation + match type(mp): + case measurements.ExpectationMP: + measurementOp = quantum.ExpvalOp(obs=obs) + + case measurements.VarianceMP: + measurementOp = quantum.VarianceOp(obs=obs) + + case measurements.StateMP: + # For now, we assume that there is no input MemRefType + tensor_shape, dynamic_shape, insertion_point = self._resolve_dynamic_shape( + n_qubits, insertion_point + ) + measurementOp = quantum.StateOp( + operands=(obs, dynamic_shape, None), + result_types=( + builtin.TensorType( + element_type=builtin.ComplexType(builtin.Float64Type()), + shape=tensor_shape, + ) + ), + ) + + case measurements.ProbabilityMP: + # For now, we assume that there is no input MemRefType + tensor_shape, dynamic_shape, insertion_point = self._resolve_dynamic_shape( + n_qubits, insertion_point + ) + measurementOp = quantum.ProbsOp( + operands=(obs, dynamic_shape, None, None), + result_types=( + builtin.TensorType(element_type=builtin.Float64Type(), shape=tensor_shape), + ), + ) + + case measurements.SampleMP: + # n_qubits or shots may not be known at compile time + _iter = (self.wire_manager.shots, n_qubits) + tensor_shape = tuple(-1 if isinstance(i, SSAValue) else i for i in _iter) + # We only insert values into dynamic_shape that are unknown at compile time + dynamic_shape = tuple(i for i in _iter if isinstance(i, SSAValue)) + if isinstance(n_qubits, int) and n_qubits == 1: + tensor_shape = (tensor_shape[0],) + + measurementOp = quantum.SampleOp( + operands=(obs, dynamic_shape, None), + result_types=( + builtin.TensorType(element_type=builtin.Float64Type(), shape=tensor_shape), + ), + ) + + case measurements.CountsMP: + # For now, we assume that there are no input MemRefTypes + tensor_shape, dynamic_shape, insertion_point = self._resolve_dynamic_shape( + n_qubits, insertion_point + ) + measurementOp = quantum.CountsOp( + operands=(obs, dynamic_shape, None, None), + result_types=( + builtin.TensorType(element_type=builtin.Float64Type(), shape=tensor_shape), + builtin.TensorType(element_type=builtin.i64, shape=tensor_shape), + ), + ) + + case _: + raise TransformError( + f"The measurement {type(mp).__name__} cannot be supported into the module." + ) + + self.insert_op(measurementOp, insertion_point=insertion_point) + return measurementOp + + # TODO: Finish implementation + def swap_gates(self, op1: xOperation, op2: xOperation) -> None: + """Swap two operations in the IR. + + Args: + op1 (xdsl.ir.Operation): First operation for the swap + op2 (xdsl.ir.Operation): Second operation for the swap + """ + raise NotImplementedError + + # if not (op1 in _gate_like_ops and op2 in _gate_like_ops): + # raise TransformError(f"Can only swap gates. Got {op1}, {op2}") + + # if set(op1.results) | set(op2.operands): + # pass + # elif set(op1.operands) | set(op2.results): + # op1, op2 = op2, op1 + # else: + # raise TransformError("Cannot swap operations that are not SSA neighbours.") + + # # Walk the IR forwards or backwards from the operation with less wires to check if there are + # # any ops that use the same wires as op1 or op2 + # n_vals1 = [r for r in op1.results if isinstance(r.type, quantum.QubitType)] + # n_vals2 = [o for o in op2.operands if isinstance(o.type, quantum.QubitType)] + # # If op1 has less results than op2 has operands, do forward traversal from op1 + # # Else, do backward traversal from op2. + # if n_vals1 <= n_vals2: + # pass + # else: + # pass + + # # Update uses for the swap + # new_op2 = type(op2).create( + # operands=(), + # result_types=(), + # properties=op2.properties, + # attributes=op2.attributes, + # successors=op2.successors, + # regions=op2.regions, + # ) + # self.insert_op(new_op2, insertion_point=InsertPoint.before(op1)) + + # # Reorder the block so that the linear order is valid + + +# pylint: disable=too-few-public-methods +class PLPatternRewriteWalker(PatternRewriteWalker): + """A ``PatternRewriteWalker`` for traversing and rewriting modules. + + This is a subclass of ``xdsl.pattern_rewriter.PatternRewriteWalker that uses a custom + rewriter that contains abstractions for quantum compilation passes.""" + + def _process_worklist(self, listener: PatternRewriterListener) -> bool: + """ + Process the worklist until it is empty. + Returns true if any modification was done. + """ + rewriter_has_done_action = False + + # Handle empty worklist + op = self._worklist.pop() + if op is None: + return rewriter_has_done_action + + # Create a rewriter on the first operation + # Here, we use our custom rewriter instead of the default PatternRewriter. + rewriter = PLPatternRewriter(op) + rewriter.extend_from_listener(listener) + + # do/while loop + while True: + # Reset the rewriter on `op` + rewriter.has_done_action = False + rewriter.current_operation = op + rewriter.insertion_point = InsertPoint.before(op) + rewriter.name_hint = None + + # Apply the pattern on the operation + try: + self.pattern.match_and_rewrite(op, rewriter) + except Exception as err: # pylint: disable=broad-exception-caught + op.emit_error( + f"Error while applying pattern: {err}", + underlying_error=err, + ) + rewriter_has_done_action |= rewriter.has_done_action + + # If the worklist is empty, we are done + op = self._worklist.pop() + if op is None: + return rewriter_has_done_action