Skip to content

Commit 40f5d6c

Browse files
committed
Refactor no-cloning validation: enhance error handling and improve test coverage
1 parent dca2422 commit 40f5d6c

File tree

4 files changed

+264
-99
lines changed

4 files changed

+264
-99
lines changed

src/bloqade/analysis/validation/nocloning/analysis.py

Lines changed: 110 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,24 +10,43 @@
1010
Address,
1111
AddressAnalysis,
1212
)
13-
from bloqade.analysis.address.lattice import AddressReg, AddressQubit
13+
from bloqade.analysis.address.lattice import (
14+
Unknown,
15+
AddressReg,
16+
UnknownReg,
17+
AddressQubit,
18+
PartialIList,
19+
PartialTuple,
20+
UnknownQubit,
21+
)
1422

15-
from .lattice import QubitValidation
23+
from .lattice import May, Top, Must, Bottom, QubitValidation
1624

1725

1826
class QubitValidationError(ValidationError):
19-
"""ValidationError that records which qubit and gate caused the violation."""
27+
"""ValidationError for definite (Must) violations with concrete qubit addresses."""
2028

2129
qubit_id: int
2230
gate_name: str
2331

2432
def __init__(self, node: ir.IRNode, qubit_id: int, gate_name: str):
25-
# message stored in ValidationError so formatting/hint() will include it
2633
super().__init__(node, f"Qubit[{qubit_id}] cloned at {gate_name} gate.")
2734
self.qubit_id = qubit_id
2835
self.gate_name = gate_name
2936

3037

38+
class PotentialQubitValidationError(ValidationError):
39+
"""ValidationError for potential (May) violations with unknown addresses."""
40+
41+
gate_name: str
42+
condition: str
43+
44+
def __init__(self, node: ir.IRNode, gate_name: str, condition: str):
45+
super().__init__(node, f"Potential cloning at {gate_name} gate{condition}.")
46+
self.gate_name = gate_name
47+
self.condition = condition
48+
49+
3150
class NoCloningValidation(Forward[QubitValidation]):
3251
"""
3352
Validates the no-cloning theorem by tracking qubit addresses.
@@ -40,9 +59,7 @@ class NoCloningValidation(Forward[QubitValidation]):
4059
_address_frame: ForwardFrame[Address] = field(init=False)
4160
_type_frame: ForwardFrame = field(init=False)
4261
method: ir.Method
43-
_validation_errors: list[QubitValidationError] = field(
44-
default_factory=list, init=False
45-
)
62+
_validation_errors: list[ValidationError] = field(default_factory=list, init=False)
4663

4764
def __init__(self, mtd: ir.Method):
4865
"""
@@ -88,68 +105,126 @@ def eval_stmt_fallback(
88105
) -> tuple[QubitValidation, ...]:
89106
"""
90107
Default statement evaluation: check for qubit usage violations.
108+
Returns Bottom, May, Must, or Top depending on what we can prove.
91109
"""
92110

93111
if not isinstance(stmt, func.Invoke):
94-
return tuple(QubitValidation.bottom() for _ in stmt.results)
112+
return tuple(Bottom() for _ in stmt.results)
95113

96114
address_frame = self._address_frame
97115
if address_frame is None:
98-
return tuple(QubitValidation.top() for _ in stmt.results)
116+
return tuple(Top() for _ in stmt.results)
99117

100-
has_qubit_args = any(
101-
isinstance(address_frame.get(arg), (AddressQubit, AddressReg))
102-
for arg in stmt.args
103-
)
118+
concrete_addrs: list[int] = []
119+
has_unknown = False
120+
has_qubit_args = False
121+
unknown_arg_names: list[str] = []
104122

105-
if not has_qubit_args:
106-
return tuple(QubitValidation.bottom() for _ in stmt.results)
107-
108-
used_addrs: list[int] = []
109123
for arg in stmt.args:
110124
addr = address_frame.get(arg)
111-
qubit_addrs = self.get_qubit_addresses(addr)
112-
used_addrs.extend(qubit_addrs)
125+
match addr:
126+
case AddressQubit(data=qubit_addr):
127+
has_qubit_args = True
128+
concrete_addrs.append(qubit_addr)
129+
case AddressReg(data=addrs):
130+
has_qubit_args = True
131+
concrete_addrs.extend(addrs)
132+
case UnknownQubit() | UnknownReg() | Unknown():
133+
has_qubit_args = True
134+
has_unknown = True
135+
arg_name = self._get_source_name(arg)
136+
unknown_arg_names.append(arg_name)
137+
case _:
138+
pass
139+
140+
if not has_qubit_args:
141+
return tuple(Bottom() for _ in stmt.results)
113142

114143
seen: set[int] = set()
115-
violations: list[str] = []
144+
must_violations: list[str] = []
145+
gate_name = stmt.callee.sym_name.upper()
116146

117-
for qubit_addr in used_addrs:
147+
for qubit_addr in concrete_addrs:
118148
if qubit_addr in seen:
119-
gate_name = stmt.callee.sym_name.upper()
120-
violations.append(self.format_violation(qubit_addr, gate_name))
149+
violation = self.format_violation(qubit_addr, gate_name)
150+
must_violations.append(violation)
121151
self._validation_errors.append(
122152
QubitValidationError(stmt, qubit_addr, gate_name)
123153
)
124154
seen.add(qubit_addr)
125155

126-
if not violations:
127-
return tuple(QubitValidation(violations=frozenset()) for _ in stmt.results)
156+
if must_violations:
157+
usage = Must(violations=frozenset(must_violations))
158+
elif has_unknown:
159+
args_str = " == ".join(unknown_arg_names)
160+
if len(unknown_arg_names) > 1:
161+
condition = f", when {args_str}"
162+
else:
163+
condition = f", with unknown index {args_str}"
164+
165+
self._validation_errors.append(
166+
PotentialQubitValidationError(stmt, gate_name, condition)
167+
)
168+
169+
usage = May(
170+
violations=frozenset([f"Unknown qubits at {gate_name} Gate{condition}"])
171+
)
172+
else:
173+
usage = Bottom()
128174

129-
usage = QubitValidation(violations=frozenset(violations))
130175
return tuple(usage for _ in stmt.results) if stmt.results else (usage,)
131176

177+
def _get_source_name(self, value: ir.SSAValue) -> str:
178+
"""Trace back to get the source variable name for a value.
179+
180+
For getitem operations like q[a], returns 'a'.
181+
For direct values, returns the value's name.
182+
"""
183+
from kirin.dialects.py.indexing import GetItem
184+
185+
if isinstance(value, ir.ResultValue) and isinstance(value.stmt, GetItem):
186+
index_arg = value.stmt.args[1]
187+
return self._get_source_name(index_arg)
188+
189+
if isinstance(value, ir.BlockArgument):
190+
return value.name or f"arg{value.index}"
191+
192+
if hasattr(value, "name") and value.name:
193+
return value.name
194+
195+
return str(value)
196+
132197
def run_method(
133198
self, method: ir.Method, args: tuple[QubitValidation, ...]
134199
) -> tuple[ForwardFrame[QubitValidation], QubitValidation]:
135200
self_mt = self.method_self(method)
136201
return self.run_callable(method.code, (self_mt,) + args)
137202

138203
def raise_validation_errors(self):
139-
"""Raise validation error for each no-cloning violation found.
204+
"""Raise validation errors for both definite and potential violations.
140205
Points to source file and line with snippet.
141206
"""
142207
if not self._validation_errors:
143208
return
144209

145-
# If multiple errors, print all with snippets first
146-
if len(self._validation_errors) > 1:
147-
for err in self._validation_errors:
148-
err.attach(self.method)
149-
# Print error message before snippet
210+
# Print all errors with snippets
211+
for err in self._validation_errors:
212+
err.attach(self.method)
213+
214+
# Format error message based on type
215+
if isinstance(err, QubitValidationError):
150216
print(
151-
f"\033[31mValidation Error\033[0m: Cloned qubit [{err.qubit_id}] at {err.gate_name} gate."
217+
f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate"
152218
)
153-
print(err.hint())
154-
print(f"Raised {len(self._validation_errors)} error(s).")
219+
elif isinstance(err, PotentialQubitValidationError):
220+
print(
221+
f"\n\033[33mWarning\033[0m: Potential cloning at {err.gate_name} gate{err.condition}"
222+
)
223+
else:
224+
print(
225+
f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}"
226+
)
227+
228+
print(err.hint())
229+
155230
raise
Lines changed: 74 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from abc import abstractmethod
12
from typing import FrozenSet, final
23
from dataclasses import field, dataclass
34

@@ -15,9 +16,11 @@ class QubitValidation(
1516
SimpleMeetMixin["QubitValidation"],
1617
BoundedLattice["QubitValidation"],
1718
):
18-
"""Tracks cloning violations detected during analysis."""
19+
"""Base class for qubit cloning validation lattice.
1920
20-
violations: FrozenSet[str] = field(default_factory=frozenset)
21+
Lattice ordering:
22+
Bottom ⊑ May{...} ⊑ Must{...} ⊑ Top
23+
"""
2124

2225
@classmethod
2326
def bottom(cls) -> "QubitValidation":
@@ -26,52 +29,95 @@ def bottom(cls) -> "QubitValidation":
2629

2730
@classmethod
2831
def top(cls) -> "QubitValidation":
29-
"""Unknown state - assume potential violations"""
32+
"""Unknown state"""
3033
return Top()
3134

35+
@abstractmethod
3236
def is_subseteq(self, other: "QubitValidation") -> bool:
33-
"""Check if this state is a subset of another.
34-
35-
Lattice ordering:
36-
Bottom ⊑ {{'Qubit[1] at CX Gate'}} ⊑ {{'Qubit[0] at CX Gate'},{'Qubit[1] at CX Gate'}} ⊑ Top
37-
"""
38-
if isinstance(other, Top):
39-
return True
40-
if isinstance(self, Bottom):
41-
return True
42-
if isinstance(other, Bottom):
43-
return False
44-
45-
return self.violations.issubset(other.violations)
46-
47-
def __repr__(self) -> str:
48-
"""Custom repr to show violations clearly."""
49-
if not self.violations:
50-
return "QubitValidation()"
51-
return f"QubitValidation(violations={self.violations})"
37+
"""Check if this state is a subset of another."""
38+
...
5239

5340

5441
@final
5542
class Bottom(QubitValidation, metaclass=SingletonMeta):
56-
"""Bottom element representing no violations."""
43+
"""Bottom element: no violations detected (safe)."""
5744

5845
def is_subseteq(self, other: QubitValidation) -> bool:
5946
"""Bottom is subset of everything."""
6047
return True
6148

6249
def __repr__(self) -> str:
63-
"""Cleaner printing."""
64-
return "⊥ (Bottom)"
50+
return "⊥ (No Errors)"
6551

6652

6753
@final
6854
class Top(QubitValidation, metaclass=SingletonMeta):
69-
"""Top element representing unknown state with potential violations."""
55+
"""Top element: unknown state (worst case - assume violations possible)."""
7056

7157
def is_subseteq(self, other: QubitValidation) -> bool:
7258
"""Top is only subset of Top."""
7359
return isinstance(other, Top)
7460

7561
def __repr__(self) -> str:
76-
"""Cleaner printing."""
77-
return "⊤ (Top)"
62+
return "⊤ (Unknown)"
63+
64+
65+
@final
66+
@dataclass
67+
class May(QubitValidation):
68+
"""Potential violations that may occur depending on runtime values.
69+
70+
Used when we have unknown addresses (UnknownQubit, UnknownReg, Unknown).
71+
"""
72+
73+
violations: FrozenSet[str] = field(default_factory=frozenset)
74+
75+
def is_subseteq(self, other: QubitValidation) -> bool:
76+
"""May ⊑ May' if violations ⊆ violations'
77+
May ⊑ Must (any may is less precise than must)
78+
May ⊑ Top
79+
"""
80+
match other:
81+
case Bottom():
82+
return False
83+
case May(violations=other_violations):
84+
return self.violations.issubset(other_violations)
85+
case Must():
86+
return True # May is less precise than Must
87+
case Top():
88+
return True
89+
return False
90+
91+
def __repr__(self) -> str:
92+
if not self.violations:
93+
return "MayError(∅)"
94+
return f"MayError({self.violations})"
95+
96+
97+
@final
98+
@dataclass
99+
class Must(QubitValidation):
100+
"""Definite violations with concrete qubit addresses.
101+
102+
These are violations we can prove will definitely occur.
103+
"""
104+
105+
violations: FrozenSet[str] = field(default_factory=frozenset)
106+
107+
def is_subseteq(self, other: QubitValidation) -> bool:
108+
"""Must ⊑ Must' if violations ⊆ violations'
109+
Must ⊑ Top
110+
"""
111+
match other:
112+
case Bottom() | May():
113+
return False
114+
case Must(violations=other_violations):
115+
return self.violations.issubset(other_violations)
116+
case Top():
117+
return True
118+
return False
119+
120+
def __repr__(self) -> str:
121+
if not self.violations:
122+
return "MustError(∅)"
123+
return f"MustError({self.violations})"

0 commit comments

Comments
 (0)