Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/bloqade/analysis/validation/nocloning/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from . import impls as impls
from .analysis import NoCloningValidation as NoCloningValidation
215 changes: 215 additions & 0 deletions src/bloqade/analysis/validation/nocloning/analysis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
from typing import Any

from kirin import ir
from kirin.analysis import Forward
from kirin.dialects import func
from kirin.ir.exception import ValidationError
from kirin.analysis.forward import ForwardFrame
from kirin.validation.validationpass import ValidationPass

from bloqade.analysis.address import (
Address,
AddressAnalysis,
)
from bloqade.analysis.address.lattice import (
Unknown,
AddressReg,
UnknownReg,
AddressQubit,
PartialIList,
PartialTuple,
UnknownQubit,
)

from .lattice import May, Top, Must, Bottom, QubitValidation


class QubitValidationError(ValidationError):
"""ValidationError for definite (Must) violations with concrete qubit addresses."""

qubit_id: int
gate_name: str

def __init__(self, node: ir.IRNode, qubit_id: int, gate_name: str):
super().__init__(node, f"Qubit[{qubit_id}] cloned at {gate_name} gate.")
self.qubit_id = qubit_id
self.gate_name = gate_name


class PotentialQubitValidationError(ValidationError):
"""ValidationError for potential (May) violations with unknown addresses."""

gate_name: str
condition: str

def __init__(self, node: ir.IRNode, gate_name: str, condition: str):
super().__init__(node, f"Potential cloning at {gate_name} gate{condition}.")
self.gate_name = gate_name
self.condition = condition


class _NoCloningAnalysis(Forward[QubitValidation]):
"""Internal forward analysis for tracking qubit cloning violations."""

keys = ("validate.nocloning",)
lattice = QubitValidation

def __init__(self, dialects):
super().__init__(dialects)
self._address_frame: ForwardFrame[Address] | None = None

def method_self(self, method: ir.Method) -> QubitValidation:
return self.lattice.bottom()

def run(self, method: ir.Method, *args: QubitValidation, **kwargs: QubitValidation):
if self._address_frame is None:
addr_analysis = AddressAnalysis(self.dialects)
addr_analysis.initialize()
self._address_frame, _ = addr_analysis.run(method)
return super().run(method, *args, **kwargs)

def eval_fallback(
self, frame: ForwardFrame[QubitValidation], node: ir.Statement
) -> tuple[QubitValidation, ...]:
"""Check for qubit usage violations."""
if not isinstance(node, func.Invoke):
return tuple(Bottom() for _ in node.results)

address_frame = self._address_frame
if address_frame is None:
return tuple(Top() for _ in node.results)

concrete_addrs: list[int] = []
has_unknown = False
has_qubit_args = False
unknown_arg_names: list[str] = []
for arg in node.args:
addr = address_frame.get(arg)
match addr:
case AddressQubit(data=qubit_addr):
has_qubit_args = True
concrete_addrs.append(qubit_addr)
case AddressReg(data=addrs):
has_qubit_args = True
concrete_addrs.extend(addrs)
case (
UnknownQubit()
| UnknownReg()
| PartialIList()
| PartialTuple()
| Unknown()
):
has_qubit_args = True
has_unknown = True
arg_name = self._get_source_name(arg)
unknown_arg_names.append(arg_name)
case _:
pass

if not has_qubit_args:
return tuple(Bottom() for _ in node.results)

seen: set[int] = set()
must_violations: list[str] = []
s_name = getattr(node.callee, "sym_name", "<unknown")
gate_name = s_name.upper()

for qubit_addr in concrete_addrs:
if qubit_addr in seen:
violation = f"Qubit[{qubit_addr}] on {gate_name} Gate"
must_violations.append(violation)
self.add_validation_error(
node, QubitValidationError(node, qubit_addr, gate_name)
)

seen.add(qubit_addr)

if must_violations:
usage = Must(violations=frozenset(must_violations))
elif has_unknown:
args_str = " == ".join(unknown_arg_names)
if len(unknown_arg_names) > 1:
condition = f", when {args_str}"
else:
condition = f", with unknown argument {args_str}"

self.add_validation_error(
node, PotentialQubitValidationError(node, gate_name, condition)
)

usage = May(violations=frozenset([f"{gate_name} Gate{condition}"]))
else:
usage = Bottom()

return tuple(usage for _ in node.results) if node.results else (usage,)

def _get_source_name(self, value: ir.SSAValue) -> str:
"""Trace back to get the source variable name."""
from kirin.dialects.py.indexing import GetItem

if isinstance(value, ir.ResultValue) and isinstance(value.stmt, GetItem):
index_arg = value.stmt.args[1]
return self._get_source_name(index_arg)

if isinstance(value, ir.BlockArgument):
return value.name or f"arg{value.index}"

if hasattr(value, "name") and value.name:
return value.name

return str(value)


class NoCloningValidation(ValidationPass):
"""Validates the no-cloning theorem by tracking qubit addresses."""

def __init__(self):
self._analysis: _NoCloningAnalysis | None = None
self._cached_address_frame = None

def name(self) -> str:
return "No-Cloning Validation"

def get_required_analyses(self) -> list[type]:
"""Declare dependency on AddressAnalysis."""
return [AddressAnalysis]

def set_analysis_cache(self, cache: dict[type, Any]) -> None:
"""Use cached AddressAnalysis result."""
self._cached_address_frame = cache.get(AddressAnalysis)

def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]:
"""Run the no-cloning validation analysis.

Returns:
- frame: ForwardFrame with QubitValidation lattice values
- errors: List of validation errors found
"""
if self._analysis is None:
self._analysis = _NoCloningAnalysis(method.dialects)

self._analysis.initialize()
if self._cached_address_frame is not None:
self._analysis._address_frame = self._cached_address_frame
frame, _ = self._analysis.run(method)
return frame, self._analysis.get_validation_errors()

def print_validation_errors(self):
"""Print all collected errors with formatted snippets."""
if self._analysis is None:
return
validation_errors = self._analysis.get_validation_errors()
for err in validation_errors:
if isinstance(err, QubitValidationError):
print(
f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate"
)
elif isinstance(err, PotentialQubitValidationError):
print(
f"\n\033[33mWarning\033[0m: Potential cloning at {err.gate_name} gate{err.condition}"
)
else:
print(
f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}"
)
print(err.hint())
86 changes: 86 additions & 0 deletions src/bloqade/analysis/validation/nocloning/impls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from kirin import interp
from kirin.analysis import ForwardFrame
from kirin.dialects import scf

from .lattice import May, Top, Must, Bottom, QubitValidation
from .analysis import (
QubitValidationError,
PotentialQubitValidationError,
_NoCloningAnalysis,
)


@scf.dialect.register(key="validate.nocloning")
class Scf(interp.MethodTable):
@interp.impl(scf.IfElse)
def if_else(
self,
interp_: _NoCloningAnalysis,
frame: ForwardFrame[QubitValidation],
stmt: scf.IfElse,
):
try:
cond_validation = frame.get(stmt.cond)
except Exception:
cond_validation = Top()

errors_before = set(interp_._validation_errors.keys())

with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
interp_.frame_call_region(then_frame, stmt, stmt.then_body, cond_validation)
frame.set_values(then_frame.entries.keys(), then_frame.entries.values())

then_keys = set(interp_._validation_errors.keys()) - errors_before
then_errors = interp_.get_validation_errors(keys=then_keys)

then_state = (
Must(violations=frozenset(err.args[0] for err in then_errors))
if then_keys
else Bottom()
)

if stmt.else_body:
errors_before_else = set(interp_._validation_errors.keys())

with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
interp_.frame_call_region(
else_frame, stmt, stmt.else_body, cond_validation
)
frame.set_values(else_frame.entries.keys(), else_frame.entries.values())

else_keys = set(interp_._validation_errors.keys()) - errors_before_else
else_errors = interp_.get_validation_errors(keys=else_keys)

else_state = (
Must(violations=frozenset(err.args[0] for err in else_errors))
if else_keys
else Bottom()
)
else:
else_state = Bottom()
else_keys = set()
else_errors = []
merged = then_state.join(else_state)
all_branch_keys = then_keys | else_keys
for k in all_branch_keys:
interp_._validation_errors.pop(k, None)

if isinstance(merged, Must):
for err in then_errors + else_errors:
if isinstance(err, QubitValidationError):
interp_.add_validation_error(err.node, err)
elif isinstance(merged, May):
for err in then_errors:
if isinstance(err, QubitValidationError):
potential_err = PotentialQubitValidationError(
err.node, err.gate_name, ", when condition is true"
)
interp_.add_validation_error(err.node, potential_err)

for err in else_errors:
if isinstance(err, QubitValidationError):
potential_err = PotentialQubitValidationError(
err.node, err.gate_name, ", when condition is false"
)
interp_.add_validation_error(err.node, potential_err)
return (merged,)
Loading
Loading