Skip to content

Commit 9ccc42c

Browse files
committed
moved collecting errors to Kirin's InterpreterABC
1 parent b572471 commit 9ccc42c

File tree

4 files changed

+40
-47
lines changed

4 files changed

+40
-47
lines changed

src/bloqade/analysis/validation/nocloning/analysis.py

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Any
1+
from typing import Any, Sequence
22

33
from kirin import ir
44
from kirin.analysis import Forward
@@ -57,19 +57,15 @@ class _NoCloningAnalysis(Forward[QubitValidation]):
5757
def __init__(self, dialects):
5858
super().__init__(dialects)
5959
self._address_frame: ForwardFrame[Address] | None = None
60-
self._validation_errors: list[ValidationError] = []
6160

6261
def method_self(self, method: ir.Method) -> QubitValidation:
6362
return self.lattice.bottom()
6463

6564
def run(self, method: ir.Method, *args: QubitValidation, **kwargs: QubitValidation):
66-
# Set up address frame before analysis if not already cached
6765
if self._address_frame is None:
6866
addr_analysis = AddressAnalysis(self.dialects)
6967
addr_analysis.initialize()
7068
self._address_frame, _ = addr_analysis.run(method)
71-
72-
# Now run the forward analysis with address frame populated
7369
return super().run(method, *args, **kwargs)
7470

7571
def eval_fallback(
@@ -122,9 +118,10 @@ def eval_fallback(
122118
if qubit_addr in seen:
123119
violation = f"Qubit[{qubit_addr}] on {gate_name} Gate"
124120
must_violations.append(violation)
125-
self._validation_errors.append(
126-
QubitValidationError(node, qubit_addr, gate_name)
121+
self.add_validation_error(
122+
node, QubitValidationError(node, qubit_addr, gate_name)
127123
)
124+
128125
seen.add(qubit_addr)
129126

130127
if must_violations:
@@ -136,8 +133,8 @@ def eval_fallback(
136133
else:
137134
condition = f", with unknown argument {args_str}"
138135

139-
self._validation_errors.append(
140-
PotentialQubitValidationError(node, gate_name, condition)
136+
self.add_validation_error(
137+
node, PotentialQubitValidationError(node, gate_name, condition)
141138
)
142139

143140
usage = May(violations=frozenset([f"{gate_name} Gate{condition}"]))
@@ -195,14 +192,14 @@ def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]:
195192
if self._cached_address_frame is not None:
196193
self._analysis._address_frame = self._cached_address_frame
197194
frame, _ = self._analysis.run(method)
198-
199-
return frame, self._analysis._validation_errors
195+
return frame, self._analysis.get_validation_errors()
200196

201197
def print_validation_errors(self):
202198
"""Print all collected errors with formatted snippets."""
203199
if self._analysis is None:
204200
return
205-
for err in self._analysis._validation_errors:
201+
validation_errors = self._analysis.get_validation_errors()
202+
for err in validation_errors:
206203
if isinstance(err, QubitValidationError):
207204
print(
208205
f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate"

src/bloqade/analysis/validation/nocloning/impls.py

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,71 +24,68 @@ def if_else(
2424
except Exception:
2525
cond_validation = Top()
2626

27-
errors_before_then = len(interp_._validation_errors)
27+
errors_before_then_keys = set(interp_._validation_errors.keys())
28+
2829
with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
2930
interp_.frame_call_region(then_frame, stmt, stmt.then_body, cond_validation)
3031
frame.set_values(then_frame.entries.keys(), then_frame.entries.values())
31-
errors_after_then = len(interp_._validation_errors)
32+
then_keys = set(interp_._validation_errors.keys()) - errors_before_then_keys
33+
then_errors = interp_.get_validation_errors(keys=then_keys)
3234

33-
then_had_errors = errors_after_then > errors_before_then
34-
then_errors = interp_._validation_errors[errors_before_then:errors_after_then]
3535
then_state = (
3636
Must(violations=frozenset(err.args[0] for err in then_errors))
37-
if then_had_errors
37+
if bool(then_keys)
3838
else Bottom()
3939
)
4040

4141
if stmt.else_body:
42-
errors_before_else = len(interp_._validation_errors)
42+
errors_before_else_keys = set(interp_._validation_errors.keys())
43+
4344
with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
4445
interp_.frame_call_region(
4546
else_frame, stmt, stmt.else_body, cond_validation
4647
)
4748
frame.set_values(else_frame.entries.keys(), else_frame.entries.values())
48-
errors_after_else = len(interp_._validation_errors)
49+
else_keys = set(interp_._validation_errors.keys()) - errors_before_else_keys
50+
else_errors = interp_.get_validation_errors(keys=else_keys)
4951

50-
else_had_errors = errors_after_else > errors_before_else
51-
else_errors = interp_._validation_errors[
52-
errors_before_else:errors_after_else
53-
]
5452
else_state = (
5553
Must(violations=frozenset(err.args[0] for err in else_errors))
56-
if else_had_errors
54+
if bool(else_keys)
5755
else Bottom()
5856
)
5957

6058
merged = then_state.join(else_state)
6159

6260
if isinstance(merged, May):
63-
interp_._validation_errors = interp_._validation_errors[
64-
:errors_before_then
65-
]
61+
branch_keys = then_keys | else_keys
62+
for k in branch_keys:
63+
interp_._validation_errors.pop(k, None)
6664

67-
for err in then_errors + else_errors:
65+
for err in then_errors:
6866
if isinstance(err, QubitValidationError):
6967
potential_err = PotentialQubitValidationError(
70-
err.node,
71-
err.gate_name,
72-
(
73-
", when condition is true"
74-
if err in then_errors
75-
else ", when condition is false"
76-
),
68+
err.node, err.gate_name, ", when condition is true"
7769
)
78-
interp_._validation_errors.append(potential_err)
70+
interp_.add_validation_error(err.node, potential_err)
71+
72+
for err in else_errors:
73+
if isinstance(err, QubitValidationError):
74+
potential_err = PotentialQubitValidationError(
75+
err.node, err.gate_name, ", when condition is false"
76+
)
77+
interp_.add_validation_error(err.node, potential_err)
7978
else:
8079
merged = then_state.join(Bottom())
8180

8281
if isinstance(merged, May):
83-
interp_._validation_errors = interp_._validation_errors[
84-
:errors_before_then
85-
]
86-
82+
for k in then_keys:
83+
interp_._validation_errors.pop(k, None)
8784
for err in then_errors:
8885
if isinstance(err, QubitValidationError):
8986
potential_err = PotentialQubitValidationError(
9087
err.node, err.gate_name, ", when condition is true"
9188
)
92-
interp_._validation_errors.append(potential_err)
89+
interp_.add_validation_error(err.node, potential_err)
9390

9491
return (merged,)

test/analysis/validation/nocloning/test_no_cloning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ def collect_errors_from_validation(
2929

3030
if validation._analysis is None:
3131
return (must_count, may_count)
32-
33-
for err in validation._analysis._validation_errors:
32+
print(validation._analysis.get_validation_errors())
33+
for err in validation._analysis.get_validation_errors():
3434
if isinstance(err, QubitValidationError):
3535
must_count += 1
3636
elif isinstance(err, PotentialQubitValidationError):

test/analysis/validation/test_compose_validation.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ def test_validation_suite():
99
@squin.kernel
1010
def bad_kernel(a: int):
1111
q = squin.qalloc(2)
12-
squin.cx(q[0], q[0]) # cloning error
13-
squin.cx(q[a], q[1]) # cloning error
12+
squin.cx(q[0], q[0]) # definite cloning error
13+
squin.cx(q[a], q[1]) # potential cloning error
1414

1515
# Running no-cloning validation multiple times
1616
suite = ValidationSuite(
@@ -35,9 +35,8 @@ def test_validation_suite2():
3535
@squin.kernel
3636
def good_kernel():
3737
q = squin.qalloc(2)
38-
squin.cx(q[0], q[1]) # cloning error
38+
squin.cx(q[0], q[1])
3939

40-
# Running no-cloning validation multiple times
4140
suite = ValidationSuite(
4241
[
4342
NoCloningValidation,

0 commit comments

Comments
 (0)