Skip to content

Commit 924d60a

Browse files
committed
Refactor validation analysis and error handling in NoCloningValidation. Simplified error handling of Scf.Ifelse
1 parent 06c7bbb commit 924d60a

File tree

5 files changed

+166
-170
lines changed

5 files changed

+166
-170
lines changed

src/bloqade/analysis/validation/nocloning/analysis.py

Lines changed: 74 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def run(self, method: ir.Method, *args: QubitValidation, **kwargs: QubitValidati
7171
def eval_fallback(
7272
self, frame: ForwardFrame[QubitValidation], node: ir.Statement
7373
) -> tuple[QubitValidation, ...]:
74-
"""Check for qubit usage violations."""
74+
"""Check for qubit usage violations and return lattice values."""
7575
if not isinstance(node, func.Invoke):
7676
return tuple(Bottom() for _ in node.results)
7777

@@ -83,6 +83,7 @@ def eval_fallback(
8383
has_unknown = False
8484
has_qubit_args = False
8585
unknown_arg_names: list[str] = []
86+
8687
for arg in node.args:
8788
addr = address_frame.get(arg)
8889
match addr:
@@ -110,34 +111,25 @@ def eval_fallback(
110111
return tuple(Bottom() for _ in node.results)
111112

112113
seen: set[int] = set()
113-
must_violations: list[str] = []
114-
s_name = getattr(node.callee, "sym_name", "<unknown")
114+
violations: set[tuple[int, str]] = set()
115+
s_name = getattr(node.callee, "sym_name", "<unknown>")
115116
gate_name = s_name.upper()
116117

117118
for qubit_addr in concrete_addrs:
118119
if qubit_addr in seen:
119-
violation = f"Qubit[{qubit_addr}] on {gate_name} Gate"
120-
must_violations.append(violation)
121-
self.add_validation_error(
122-
node, QubitValidationError(node, qubit_addr, gate_name)
123-
)
124-
120+
violations.add((qubit_addr, gate_name))
125121
seen.add(qubit_addr)
126122

127-
if must_violations:
128-
usage = Must(violations=frozenset(must_violations))
123+
if violations:
124+
usage = Must(violations=frozenset(violations))
129125
elif has_unknown:
130126
args_str = " == ".join(unknown_arg_names)
131127
if len(unknown_arg_names) > 1:
132128
condition = f", when {args_str}"
133129
else:
134130
condition = f", with unknown argument {args_str}"
135131

136-
self.add_validation_error(
137-
node, PotentialQubitValidationError(node, gate_name, condition)
138-
)
139-
140-
usage = May(violations=frozenset([f"{gate_name} Gate{condition}"]))
132+
usage = May(violations=frozenset([(gate_name, condition)]))
141133
else:
142134
usage = Bottom()
143135

@@ -159,6 +151,48 @@ def _get_source_name(self, value: ir.SSAValue) -> str:
159151

160152
return str(value)
161153

154+
def extract_errors_from_frame(
155+
self, frame: ForwardFrame[QubitValidation]
156+
) -> list[ValidationError]:
157+
"""Extract validation errors from final lattice values.
158+
159+
Only extracts errors from top-level statements (not nested in regions).
160+
"""
161+
errors = []
162+
seen_statements = set()
163+
164+
for node, value in frame.entries.items():
165+
if isinstance(node, ir.ResultValue):
166+
stmt = node.stmt
167+
elif isinstance(node, ir.Statement):
168+
stmt = node
169+
else:
170+
continue
171+
if stmt in seen_statements:
172+
continue
173+
seen_statements.add(stmt)
174+
if isinstance(value, Must):
175+
for qubit_id, gate_name in value.violations:
176+
errors.append(QubitValidationError(stmt, qubit_id, gate_name))
177+
elif isinstance(value, May):
178+
for gate_name, condition in value.violations:
179+
errors.append(
180+
PotentialQubitValidationError(stmt, gate_name, condition)
181+
)
182+
return errors
183+
184+
def count_violations(self, frame: Any) -> int:
185+
"""Count individual violations from the frame, same as test helper."""
186+
from .lattice import May, Must
187+
188+
total = 0
189+
for node, value in frame.entries.items():
190+
if isinstance(value, Must):
191+
total += len(value.violations)
192+
elif isinstance(value, May):
193+
total += len(value.violations)
194+
return total
195+
162196

163197
class NoCloningValidation(ValidationPass):
164198
"""Validates the no-cloning theorem by tracking qubit addresses."""
@@ -179,37 +213,39 @@ def set_analysis_cache(self, cache: dict[type, Any]) -> None:
179213
self._cached_address_frame = cache.get(AddressAnalysis)
180214

181215
def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]:
182-
"""Run the no-cloning validation analysis.
183-
184-
Returns:
185-
- frame: ForwardFrame with QubitValidation lattice values
186-
- errors: List of validation errors found
187-
"""
216+
"""Run the no-cloning validation analysis."""
188217
if self._analysis is None:
189218
self._analysis = _NoCloningAnalysis(method.dialects)
190219

191220
self._analysis.initialize()
192221
if self._cached_address_frame is not None:
193222
self._analysis._address_frame = self._cached_address_frame
223+
194224
frame, _ = self._analysis.run(method)
195-
return frame, self._analysis.get_validation_errors()
225+
errors = self._analysis.extract_errors_from_frame(frame)
226+
227+
return frame, errors
196228

197229
def print_validation_errors(self):
198230
"""Print all collected errors with formatted snippets."""
199231
if self._analysis is None:
200232
return
201-
validation_errors = self._analysis.get_validation_errors()
202-
for err in validation_errors:
203-
if isinstance(err, QubitValidationError):
204-
print(
205-
f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate"
206-
)
207-
elif isinstance(err, PotentialQubitValidationError):
208-
print(
209-
f"\n\033[33mWarning\033[0m: Potential cloning at {err.gate_name} gate{err.condition}"
210-
)
211-
else:
212-
print(
213-
f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}"
214-
)
215-
print(err.hint())
233+
234+
if self._analysis.state._current_frame:
235+
frame = self._analysis.state._current_frame
236+
errors = self._analysis.extract_errors_from_frame(frame)
237+
238+
for err in errors:
239+
if isinstance(err, QubitValidationError):
240+
print(
241+
f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate"
242+
)
243+
elif isinstance(err, PotentialQubitValidationError):
244+
print(
245+
f"\n\033[33mWarning\033[0m: Potential cloning at {err.gate_name} gate{err.condition}"
246+
)
247+
else:
248+
print(
249+
f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}"
250+
)
251+
print(err.hint())

src/bloqade/analysis/validation/nocloning/impls.py

Lines changed: 23 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,7 @@
33
from kirin.dialects import scf
44

55
from .lattice import May, Top, Must, Bottom, QubitValidation
6-
from .analysis import (
7-
QubitValidationError,
8-
PotentialQubitValidationError,
9-
_NoCloningAnalysis,
10-
)
6+
from .analysis import _NoCloningAnalysis
117

128

139
@scf.dialect.register(key="validate.nocloning")
@@ -24,63 +20,40 @@ def if_else(
2420
except Exception:
2521
cond_validation = Top()
2622

27-
errors_before = set(interp_._validation_errors.keys())
28-
2923
with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
3024
interp_.frame_call_region(then_frame, stmt, stmt.then_body, cond_validation)
31-
frame.set_values(then_frame.entries.keys(), then_frame.entries.values())
32-
33-
then_keys = set(interp_._validation_errors.keys()) - errors_before
34-
then_errors = interp_.get_validation_errors(keys=then_keys)
3525

36-
then_state = (
37-
Must(violations=frozenset(err.args[0] for err in then_errors))
38-
if then_keys
39-
else Bottom()
40-
)
26+
then_state = Bottom()
27+
for node, val in then_frame.entries.items():
28+
if isinstance(val, (Must, May)):
29+
then_state = then_state.join(val)
4130

31+
else_state = Bottom()
4232
if stmt.else_body:
43-
errors_before_else = set(interp_._validation_errors.keys())
44-
4533
with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
4634
interp_.frame_call_region(
4735
else_frame, stmt, stmt.else_body, cond_validation
4836
)
49-
frame.set_values(else_frame.entries.keys(), else_frame.entries.values())
5037

51-
else_keys = set(interp_._validation_errors.keys()) - errors_before_else
52-
else_errors = interp_.get_validation_errors(keys=else_keys)
38+
for node, val in else_frame.entries.items():
39+
if isinstance(val, (Must, May)):
40+
else_state = else_state.join(val)
5341

54-
else_state = (
55-
Must(violations=frozenset(err.args[0] for err in else_errors))
56-
if else_keys
57-
else Bottom()
58-
)
59-
else:
60-
else_state = Bottom()
61-
else_keys = set()
62-
else_errors = []
6342
merged = then_state.join(else_state)
64-
all_branch_keys = then_keys | else_keys
65-
for k in all_branch_keys:
66-
interp_._validation_errors.pop(k, None)
6743

68-
if isinstance(merged, Must):
69-
for err in then_errors + else_errors:
70-
if isinstance(err, QubitValidationError):
71-
interp_.add_validation_error(err.node, err)
72-
elif isinstance(merged, May):
73-
for err in then_errors:
74-
if isinstance(err, QubitValidationError):
75-
potential_err = PotentialQubitValidationError(
76-
err.node, err.gate_name, ", when condition is true"
77-
)
78-
interp_.add_validation_error(err.node, potential_err)
44+
if isinstance(merged, May):
45+
then_has = not isinstance(then_state, Bottom)
46+
else_has = not isinstance(else_state, Bottom)
47+
48+
if then_has and not else_has:
49+
new_violations = frozenset(
50+
(gate, ", when condition is true") for gate, _ in merged.violations
51+
)
52+
merged = May(violations=new_violations)
53+
elif else_has and not then_has:
54+
new_violations = frozenset(
55+
(gate, ", when condition is false") for gate, _ in merged.violations
56+
)
57+
merged = May(violations=new_violations)
7958

80-
for err in else_errors:
81-
if isinstance(err, QubitValidationError):
82-
potential_err = PotentialQubitValidationError(
83-
err.node, err.gate_name, ", when condition is false"
84-
)
85-
interp_.add_validation_error(err.node, potential_err)
8659
return (merged,)

0 commit comments

Comments
 (0)