diff --git a/src/bloqade/analysis/validation/nocloning/__init__.py b/src/bloqade/analysis/validation/nocloning/__init__.py new file mode 100644 index 00000000..a61f8ba0 --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/__init__.py @@ -0,0 +1,2 @@ +from . import impls as impls +from .analysis import NoCloningValidation as NoCloningValidation diff --git a/src/bloqade/analysis/validation/nocloning/analysis.py b/src/bloqade/analysis/validation/nocloning/analysis.py new file mode 100644 index 00000000..fc523e41 --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/analysis.py @@ -0,0 +1,215 @@ +from typing import Any + +from kirin import ir +from kirin.analysis import Forward +from kirin.dialects import func +from kirin.ir.exception import ValidationError +from kirin.analysis.forward import ForwardFrame +from kirin.validation.validationpass import ValidationPass + +from bloqade.analysis.address import ( + Address, + AddressAnalysis, +) +from bloqade.analysis.address.lattice import ( + Unknown, + AddressReg, + UnknownReg, + AddressQubit, + PartialIList, + PartialTuple, + UnknownQubit, +) + +from .lattice import May, Top, Must, Bottom, QubitValidation + + +class QubitValidationError(ValidationError): + """ValidationError for definite (Must) violations with concrete qubit addresses.""" + + qubit_id: int + gate_name: str + + def __init__(self, node: ir.IRNode, qubit_id: int, gate_name: str): + super().__init__(node, f"Qubit[{qubit_id}] cloned at {gate_name} gate.") + self.qubit_id = qubit_id + self.gate_name = gate_name + + +class PotentialQubitValidationError(ValidationError): + """ValidationError for potential (May) violations with unknown addresses.""" + + gate_name: str + condition: str + + def __init__(self, node: ir.IRNode, gate_name: str, condition: str): + super().__init__(node, f"Potential cloning at {gate_name} gate{condition}.") + self.gate_name = gate_name + self.condition = condition + + +class _NoCloningAnalysis(Forward[QubitValidation]): + """Internal forward analysis for tracking qubit cloning violations.""" + + keys = ("validate.nocloning",) + lattice = QubitValidation + + def __init__(self, dialects): + super().__init__(dialects) + self._address_frame: ForwardFrame[Address] | None = None + + def method_self(self, method: ir.Method) -> QubitValidation: + return self.lattice.bottom() + + def run(self, method: ir.Method, *args: QubitValidation, **kwargs: QubitValidation): + if self._address_frame is None: + addr_analysis = AddressAnalysis(self.dialects) + addr_analysis.initialize() + self._address_frame, _ = addr_analysis.run(method) + return super().run(method, *args, **kwargs) + + def eval_fallback( + self, frame: ForwardFrame[QubitValidation], node: ir.Statement + ) -> tuple[QubitValidation, ...]: + """Check for qubit usage violations.""" + if not isinstance(node, func.Invoke): + return tuple(Bottom() for _ in node.results) + + address_frame = self._address_frame + if address_frame is None: + return tuple(Top() for _ in node.results) + + concrete_addrs: list[int] = [] + has_unknown = False + has_qubit_args = False + unknown_arg_names: list[str] = [] + for arg in node.args: + addr = address_frame.get(arg) + match addr: + case AddressQubit(data=qubit_addr): + has_qubit_args = True + concrete_addrs.append(qubit_addr) + case AddressReg(data=addrs): + has_qubit_args = True + concrete_addrs.extend(addrs) + case ( + UnknownQubit() + | UnknownReg() + | PartialIList() + | PartialTuple() + | Unknown() + ): + has_qubit_args = True + has_unknown = True + arg_name = self._get_source_name(arg) + unknown_arg_names.append(arg_name) + case _: + pass + + if not has_qubit_args: + return tuple(Bottom() for _ in node.results) + + seen: set[int] = set() + must_violations: list[str] = [] + s_name = getattr(node.callee, "sym_name", " 1: + condition = f", when {args_str}" + else: + condition = f", with unknown argument {args_str}" + + self.add_validation_error( + node, PotentialQubitValidationError(node, gate_name, condition) + ) + + usage = May(violations=frozenset([f"{gate_name} Gate{condition}"])) + else: + usage = Bottom() + + return tuple(usage for _ in node.results) if node.results else (usage,) + + def _get_source_name(self, value: ir.SSAValue) -> str: + """Trace back to get the source variable name.""" + from kirin.dialects.py.indexing import GetItem + + if isinstance(value, ir.ResultValue) and isinstance(value.stmt, GetItem): + index_arg = value.stmt.args[1] + return self._get_source_name(index_arg) + + if isinstance(value, ir.BlockArgument): + return value.name or f"arg{value.index}" + + if hasattr(value, "name") and value.name: + return value.name + + return str(value) + + +class NoCloningValidation(ValidationPass): + """Validates the no-cloning theorem by tracking qubit addresses.""" + + def __init__(self): + self._analysis: _NoCloningAnalysis | None = None + self._cached_address_frame = None + + def name(self) -> str: + return "No-Cloning Validation" + + def get_required_analyses(self) -> list[type]: + """Declare dependency on AddressAnalysis.""" + return [AddressAnalysis] + + def set_analysis_cache(self, cache: dict[type, Any]) -> None: + """Use cached AddressAnalysis result.""" + self._cached_address_frame = cache.get(AddressAnalysis) + + def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]: + """Run the no-cloning validation analysis. + + Returns: + - frame: ForwardFrame with QubitValidation lattice values + - errors: List of validation errors found + """ + if self._analysis is None: + self._analysis = _NoCloningAnalysis(method.dialects) + + self._analysis.initialize() + if self._cached_address_frame is not None: + self._analysis._address_frame = self._cached_address_frame + frame, _ = self._analysis.run(method) + return frame, self._analysis.get_validation_errors() + + def print_validation_errors(self): + """Print all collected errors with formatted snippets.""" + if self._analysis is None: + return + validation_errors = self._analysis.get_validation_errors() + for err in validation_errors: + if isinstance(err, QubitValidationError): + print( + f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate" + ) + elif isinstance(err, PotentialQubitValidationError): + print( + f"\n\033[33mWarning\033[0m: Potential cloning at {err.gate_name} gate{err.condition}" + ) + else: + print( + f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}" + ) + print(err.hint()) diff --git a/src/bloqade/analysis/validation/nocloning/impls.py b/src/bloqade/analysis/validation/nocloning/impls.py new file mode 100644 index 00000000..7dae4d7b --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/impls.py @@ -0,0 +1,86 @@ +from kirin import interp +from kirin.analysis import ForwardFrame +from kirin.dialects import scf + +from .lattice import May, Top, Must, Bottom, QubitValidation +from .analysis import ( + QubitValidationError, + PotentialQubitValidationError, + _NoCloningAnalysis, +) + + +@scf.dialect.register(key="validate.nocloning") +class Scf(interp.MethodTable): + @interp.impl(scf.IfElse) + def if_else( + self, + interp_: _NoCloningAnalysis, + frame: ForwardFrame[QubitValidation], + stmt: scf.IfElse, + ): + try: + cond_validation = frame.get(stmt.cond) + except Exception: + cond_validation = Top() + + errors_before = set(interp_._validation_errors.keys()) + + with interp_.new_frame(stmt, has_parent_access=True) as then_frame: + interp_.frame_call_region(then_frame, stmt, stmt.then_body, cond_validation) + frame.set_values(then_frame.entries.keys(), then_frame.entries.values()) + + then_keys = set(interp_._validation_errors.keys()) - errors_before + then_errors = interp_.get_validation_errors(keys=then_keys) + + then_state = ( + Must(violations=frozenset(err.args[0] for err in then_errors)) + if then_keys + else Bottom() + ) + + if stmt.else_body: + errors_before_else = set(interp_._validation_errors.keys()) + + with interp_.new_frame(stmt, has_parent_access=True) as else_frame: + interp_.frame_call_region( + else_frame, stmt, stmt.else_body, cond_validation + ) + frame.set_values(else_frame.entries.keys(), else_frame.entries.values()) + + else_keys = set(interp_._validation_errors.keys()) - errors_before_else + else_errors = interp_.get_validation_errors(keys=else_keys) + + else_state = ( + Must(violations=frozenset(err.args[0] for err in else_errors)) + if else_keys + else Bottom() + ) + else: + else_state = Bottom() + else_keys = set() + else_errors = [] + merged = then_state.join(else_state) + all_branch_keys = then_keys | else_keys + for k in all_branch_keys: + interp_._validation_errors.pop(k, None) + + if isinstance(merged, Must): + for err in then_errors + else_errors: + if isinstance(err, QubitValidationError): + interp_.add_validation_error(err.node, err) + elif isinstance(merged, May): + for err in then_errors: + if isinstance(err, QubitValidationError): + potential_err = PotentialQubitValidationError( + err.node, err.gate_name, ", when condition is true" + ) + interp_.add_validation_error(err.node, potential_err) + + for err in else_errors: + if isinstance(err, QubitValidationError): + potential_err = PotentialQubitValidationError( + err.node, err.gate_name, ", when condition is false" + ) + interp_.add_validation_error(err.node, potential_err) + return (merged,) diff --git a/src/bloqade/analysis/validation/nocloning/lattice.py b/src/bloqade/analysis/validation/nocloning/lattice.py new file mode 100644 index 00000000..24ce01ae --- /dev/null +++ b/src/bloqade/analysis/validation/nocloning/lattice.py @@ -0,0 +1,175 @@ +from abc import abstractmethod +from typing import FrozenSet, final +from dataclasses import field, dataclass + +from kirin.lattice import SingletonMeta, BoundedLattice + + +@dataclass +class QubitValidation(BoundedLattice["QubitValidation"]): + r"""Base class for qubit-cloning validation lattice. + + Linear ordering (more precise --> less precise): + Bottom ⊑ Must ⊑ May ⊑ Top + + Semantics: + - Bottom: proven safe / never occurs + - Must: definitely occurs (strong) + - May: possibly occurs (weak) + - Top: unknown / no information + """ + + @classmethod + def bottom(cls) -> "QubitValidation": + return Bottom() + + @classmethod + def top(cls) -> "QubitValidation": + return Top() + + @abstractmethod + def is_subseteq(self, other: "QubitValidation") -> bool: ... + + @abstractmethod + def join(self, other: "QubitValidation") -> "QubitValidation": ... + + @abstractmethod + def meet(self, other: "QubitValidation") -> "QubitValidation": ... + + +@final +class Bottom(QubitValidation, metaclass=SingletonMeta): + def is_subseteq(self, other: QubitValidation) -> bool: + return True + + def join(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return self + case Must(violations=v): + return May(violations=v) + case May() | Top(): + return other + return other + + def meet(self, other: QubitValidation) -> QubitValidation: + return self + + def __repr__(self) -> str: + return "⊥ (No Errors)" + + +@final +class Top(QubitValidation, metaclass=SingletonMeta): + def is_subseteq(self, other: QubitValidation) -> bool: + return isinstance(other, Top) + + def join(self, other: QubitValidation) -> QubitValidation: + return self + + def meet(self, other: QubitValidation) -> QubitValidation: + return other + + def __repr__(self) -> str: + return "⊤ (Unknown)" + + +@final +@dataclass +class Must(QubitValidation): + """Definite violations.""" + + violations: FrozenSet[str] = field(default_factory=frozenset) + + def is_subseteq(self, other: QubitValidation) -> bool: + match other: + case Bottom(): + return False + case Must(violations=ov): + return self.violations.issubset(ov) + case May(violations=_): + return True + case Top(): + return True + return False + + def join(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return May(violations=self.violations) + case Must(violations=ov): + if self.violations == ov: + return Must(violations=self.violations) + else: + return May(violations=self.violations | ov) + case May(violations=ov): + return May(violations=self.violations | ov) + case Top(): + return Top() + return Top() + + def meet(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return Bottom() + case Must(violations=ov): + inter = self.violations & ov + return Must(violations=inter) if inter else Bottom() + case May(violations=ov): + inter = self.violations & ov + return Must(violations=inter) if inter else Bottom() + case Top(): + return self + return Bottom() + + def __repr__(self) -> str: + return f"Must({self.violations or '∅'})" + + +@final +@dataclass +class May(QubitValidation): + """Potential violations.""" + + violations: FrozenSet[str] = field(default_factory=frozenset) + + def is_subseteq(self, other: QubitValidation) -> bool: + match other: + case Bottom(): + return False + case Must(): + return False + case May(violations=ov): + return self.violations.issubset(ov) + case Top(): + return True + return False + + def join(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return self + case Must(violations=ov): + return May(violations=self.violations | ov) + case May(violations=ov): + return May(violations=self.violations | ov) + case Top(): + return Top() + return Top() + + def meet(self, other: QubitValidation) -> QubitValidation: + match other: + case Bottom(): + return Bottom() + case Must(violations=ov): + inter = self.violations & ov + return Must(violations=inter) if inter else Bottom() + case May(violations=ov): + inter = self.violations & ov + return May(violations=inter) if inter else Bottom() + case Top(): + return self + return Bottom() + + def __repr__(self) -> str: + return f"May({self.violations or '∅'})" diff --git a/test/analysis/validation/nocloning/test_no_cloning.py b/test/analysis/validation/nocloning/test_no_cloning.py new file mode 100644 index 00000000..2faf9725 --- /dev/null +++ b/test/analysis/validation/nocloning/test_no_cloning.py @@ -0,0 +1,168 @@ +from typing import Any, TypeVar + +import pytest +from kirin import ir +from kirin.dialects.ilist.runtime import IList + +from bloqade import squin +from bloqade.types import Qubit +from bloqade.analysis.validation.nocloning.lattice import May, Must +from bloqade.analysis.validation.nocloning.analysis import ( + NoCloningValidation, + QubitValidationError, + PotentialQubitValidationError, +) + +T = TypeVar("T", bound=Must | May) + + +def collect_errors_from_validation( + validation: NoCloningValidation, +) -> tuple[int, int]: + """Count Must (definite) and May (potential) errors from the validation pass. + + Returns: + (must_count, may_count) - number of definite and potential errors + """ + must_count = 0 + may_count = 0 + + if validation._analysis is None: + return (must_count, may_count) + print(validation._analysis.get_validation_errors()) + for err in validation._analysis.get_validation_errors(): + if isinstance(err, QubitValidationError): + must_count += 1 + elif isinstance(err, PotentialQubitValidationError): + may_count += 1 + + return must_count, may_count + + +@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) +def test_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def bad_control(): + q = squin.qalloc(1) + control_gate(q[0], q[0]) + + validation = NoCloningValidation() + + frame, _ = validation.run(bad_control) + print() + bad_control.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation) + assert must_count == 1 + assert may_count == 0 + validation.print_validation_errors() + + +@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) +def test_conditionals_fail(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def bad_control(cond: bool): + q = squin.qalloc(10) + if cond: + control_gate(q[0], q[0]) + else: + control_gate(q[0], q[1]) + squin.cx(q[1], q[1]) + + validation = NoCloningValidation() + frame, _ = validation.run(bad_control) + print() + bad_control.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation) + assert must_count == 1 # squin.cx(q[1], q[1]) outside conditional + assert may_count == 1 # control_gate(q[0], q[0]) inside conditional + validation.print_validation_errors() + + +@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz]) +def test_pass(control_gate: ir.Method[[Qubit, Qubit], Any]): + @squin.kernel + def test(): + q = squin.qalloc(3) + control_gate(q[0], q[1]) + squin.rx(1.57, q[0]) + squin.measure(q[0]) + control_gate(q[0], q[2]) + + validation = NoCloningValidation() + frame, _ = validation.run(test) + print() + test.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation) + assert must_count == 0 + assert may_count == 0 + + +def test_fail_2(): + @squin.kernel + def good_kernel(): + q = squin.qalloc(2) + a = 1 + squin.cx(q[0], q[1]) + squin.cy(q[1], q[a]) + + validation = NoCloningValidation() + frame, _ = validation.run(good_kernel) + + must_count, may_count = collect_errors_from_validation(validation) + assert must_count == 1 + assert may_count == 0 + validation.print_validation_errors() + + +def test_parallel_fail(): + @squin.kernel + def bad_kernel(): + q = squin.qalloc(5) + squin.broadcast.cx(IList([q[0], q[1], q[2]]), IList([q[1], q[2], q[3]])) + + validation = NoCloningValidation() + frame, _ = validation.run(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation) + assert must_count == 2 + assert may_count == 0 + validation.print_validation_errors() + + +def test_potential_fail(): + @squin.kernel + def bad_kernel(a: int, b: int): + q = squin.qalloc(5) + squin.cx(q[a], q[2]) + + validation = NoCloningValidation() + frame, _ = validation.run(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation) + assert must_count == 0 + assert may_count == 1 + validation.print_validation_errors() + + +def test_potential_parallel_fail(): + @squin.kernel + def bad_kernel(a: IList): + q = squin.qalloc(5) + squin.broadcast.cx(a, IList([q[2], q[3], q[4]])) + + validation = NoCloningValidation() + frame, _ = validation.run(bad_kernel) + print() + bad_kernel.print(analysis=frame.entries) + + must_count, may_count = collect_errors_from_validation(validation) + assert must_count == 0 + assert may_count == 1 + validation.print_validation_errors() diff --git a/test/analysis/validation/test_compose_validation.py b/test/analysis/validation/test_compose_validation.py new file mode 100644 index 00000000..86030a74 --- /dev/null +++ b/test/analysis/validation/test_compose_validation.py @@ -0,0 +1,51 @@ +import pytest +from kirin.validation.validationpass import ValidationSuite + +from bloqade import squin +from bloqade.analysis.validation.nocloning import NoCloningValidation + + +def test_validation_suite(): + @squin.kernel + def bad_kernel(a: int): + q = squin.qalloc(2) + squin.cx(q[0], q[0]) # definite cloning error + squin.cx(q[a], q[1]) # potential cloning error + + # Running no-cloning validation multiple times + suite = ValidationSuite( + [ + NoCloningValidation, + NoCloningValidation, + NoCloningValidation, + ] + ) + result = suite.validate(bad_kernel) + + assert not result.is_valid() + assert ( + result.error_count() == 2 + ) # Report 2 errors, even when validated multiple times + print(result.format_errors()) + with pytest.raises(Exception): + result.raise_if_invalid() + + +def test_validation_suite2(): + @squin.kernel + def good_kernel(): + q = squin.qalloc(2) + squin.cx(q[0], q[1]) + + suite = ValidationSuite( + [ + NoCloningValidation, + ], + fail_fast=True, + ) + result = suite.validate(good_kernel) + + assert result.is_valid() + assert result.error_count() == 0 + print(result.format_errors()) + result.raise_if_invalid()