Skip to content

Commit a35636a

Browse files
committed
Improve Validation framework to compose multiple validation analyses.
1 parent 2a011fc commit a35636a

File tree

7 files changed

+544
-246
lines changed

7 files changed

+544
-246
lines changed

src/bloqade/analysis/validation/nocloning/analysis.py

Lines changed: 73 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from dataclasses import field
1+
from typing import Any
22

33
from kirin import ir
4-
from kirin.analysis import Forward, TypeInference
4+
from kirin.analysis import Forward
55
from kirin.dialects import func
66
from kirin.ir.exception import ValidationError
77
from kirin.analysis.forward import ForwardFrame
@@ -15,8 +15,11 @@
1515
AddressReg,
1616
UnknownReg,
1717
AddressQubit,
18+
PartialIList,
19+
PartialTuple,
1820
UnknownQubit,
1921
)
22+
from bloqade.analysis.validation.validationpass import ValidationPass
2023

2124
from .lattice import May, Top, Must, Bottom, QubitValidation
2225

@@ -45,66 +48,39 @@ def __init__(self, node: ir.IRNode, gate_name: str, condition: str):
4548
self.condition = condition
4649

4750

48-
class NoCloningValidation(Forward[QubitValidation]):
49-
"""
50-
Validates the no-cloning theorem by tracking qubit addresses.
51-
52-
Built on top of AddressAnalysis to get qubit address information.
53-
"""
51+
class _NoCloningAnalysis(Forward[QubitValidation]):
52+
"""Internal forward analysis for tracking qubit cloning violations."""
5453

5554
keys = ["validate.nocloning"]
5655
lattice = QubitValidation
57-
_address_frame: ForwardFrame[Address] = field(init=False)
58-
_type_frame: ForwardFrame = field(init=False)
59-
method: ir.Method
60-
_validation_errors: list[ValidationError] = field(default_factory=list, init=False)
6156

62-
def __init__(self, mtd: ir.Method):
63-
"""
64-
Input:
65-
- an ir.Method / kernel function
66-
infer dialects from it and remember method.
67-
"""
68-
self.method = mtd
69-
super().__init__(mtd.dialects)
57+
def __init__(self, dialects):
58+
super().__init__(dialects)
59+
self._address_frame: ForwardFrame[Address] | None = None
60+
self._validation_errors: list[ValidationError] = []
7061

7162
def initialize(self):
7263
super().initialize()
7364
self._validation_errors = []
74-
address_analysis = AddressAnalysis(self.dialects)
75-
address_analysis.initialize()
76-
self._address_frame, _ = address_analysis.run_analysis(self.method)
77-
78-
type_inference = TypeInference(self.dialects)
79-
type_inference.initialize()
80-
self._type_frame, _ = type_inference.run_analysis(self.method)
81-
8265
return self
8366

84-
def method_self(self, method: ir.Method) -> QubitValidation:
85-
return self.lattice.bottom()
67+
def run_method(
68+
self, method: ir.Method, args: tuple[QubitValidation, ...]
69+
) -> tuple[ForwardFrame[QubitValidation], QubitValidation]:
70+
if self._address_frame is None:
71+
if getattr(self, "_address_analysis", None) is None:
72+
addr_analysis = AddressAnalysis(self.dialects)
73+
addr_analysis.initialize()
74+
self._address_analysis = addr_analysis
8675

87-
def get_qubit_addresses(self, addr: Address) -> frozenset[int]:
88-
"""Extract concrete qubit addresses from an Address lattice element."""
89-
match addr:
90-
case AddressQubit(data=qubit_addr):
91-
return frozenset([qubit_addr])
92-
case AddressReg(data=addrs):
93-
return frozenset(addrs)
94-
case _:
95-
return frozenset()
76+
self._address_frame, _ = self._address_analysis.run_analysis(method)
9677

97-
def format_violation(self, qubit_id: int, gate_name: str) -> str:
98-
"""Return the violation string for a qubit + gate."""
99-
return f"Qubit[{qubit_id}] on {gate_name} Gate"
78+
return self.run_callable(method.code, args)
10079

10180
def eval_stmt_fallback(
10281
self, frame: ForwardFrame[QubitValidation], stmt: ir.Statement
10382
) -> tuple[QubitValidation, ...]:
104-
"""
105-
Default statement evaluation: check for qubit usage violations.
106-
Returns Bottom, May, Must, or Top depending on what we can prove.
107-
"""
83+
"""Check for qubit usage violations."""
10884

10985
if not isinstance(stmt, func.Invoke):
11086
return tuple(Bottom() for _ in stmt.results)
@@ -127,7 +103,13 @@ def eval_stmt_fallback(
127103
case AddressReg(data=addrs):
128104
has_qubit_args = True
129105
concrete_addrs.extend(addrs)
130-
case UnknownQubit() | UnknownReg() | Unknown():
106+
case (
107+
UnknownQubit()
108+
| UnknownReg()
109+
| PartialIList()
110+
| PartialTuple()
111+
| Unknown()
112+
):
131113
has_qubit_args = True
132114
has_unknown = True
133115
arg_name = self._get_source_name(arg)
@@ -144,7 +126,7 @@ def eval_stmt_fallback(
144126

145127
for qubit_addr in concrete_addrs:
146128
if qubit_addr in seen:
147-
violation = self.format_violation(qubit_addr, gate_name)
129+
violation = f"Qubit[{qubit_addr}] on {gate_name} Gate"
148130
must_violations.append(violation)
149131
self._validation_errors.append(
150132
QubitValidationError(stmt, qubit_addr, gate_name)
@@ -171,11 +153,7 @@ def eval_stmt_fallback(
171153
return tuple(usage for _ in stmt.results) if stmt.results else (usage,)
172154

173155
def _get_source_name(self, value: ir.SSAValue) -> str:
174-
"""Trace back to get the source variable name for a value.
175-
176-
For getitem operations like q[a], returns 'a'.
177-
For direct values, returns the value's name.
178-
"""
156+
"""Trace back to get the source variable name."""
179157
from kirin.dialects.py.indexing import GetItem
180158

181159
if isinstance(value, ir.ResultValue) and isinstance(value.stmt, GetItem):
@@ -190,24 +168,52 @@ def _get_source_name(self, value: ir.SSAValue) -> str:
190168

191169
return str(value)
192170

193-
def run_method(
194-
self, method: ir.Method, args: tuple[QubitValidation, ...]
195-
) -> tuple[ForwardFrame[QubitValidation], QubitValidation]:
196-
self_mt = self.method_self(method)
197-
return self.run_callable(method.code, (self_mt,) + args)
198171

199-
def raise_validation_errors(self):
200-
"""Raise validation errors for both definite and potential violations.
201-
Points to source file and line with snippet.
172+
class NoCloningValidation(ValidationPass):
173+
"""Validates the no-cloning theorem by tracking qubit addresses."""
174+
175+
def __init__(self):
176+
self.method: ir.Method | None = None
177+
self._analysis: _NoCloningAnalysis | None = None
178+
self._cached_address_frame = None
179+
180+
def name(self) -> str:
181+
return "No-Cloning Validation"
182+
183+
def get_required_analyses(self) -> list[type]:
184+
"""Declare dependency on AddressAnalysis."""
185+
return [AddressAnalysis]
186+
187+
def set_analysis_cache(self, cache: dict[type, Any]) -> None:
188+
"""Use cached AddressAnalysis result."""
189+
self._cached_address_frame = cache.get(AddressAnalysis)
190+
191+
def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]:
192+
"""Run the no-cloning validation analysis.
193+
194+
Returns:
195+
- frame: ForwardFrame with QubitValidation lattice values
196+
- errors: List of validation errors found
202197
"""
203-
if not self._validation_errors:
204-
return
198+
if self._analysis is None:
199+
self._analysis = _NoCloningAnalysis(method.dialects)
200+
201+
self.method = method
202+
self._analysis.initialize()
203+
if self._cached_address_frame is not None:
204+
self._analysis._address_frame = self._cached_address_frame
205+
frame, _ = self._analysis.run_analysis(method, args=None)
205206

206-
# Print all errors with snippets
207-
for err in self._validation_errors:
208-
err.attach(self.method)
207+
return frame, self._analysis._validation_errors
209208

210-
# Format error message based on type
209+
def print_validation_errors(self):
210+
"""Print all collected errors with formatted snippets."""
211+
if self._analysis is None:
212+
return
213+
errors = self._analysis._validation_errors
214+
if not errors:
215+
return
216+
for err in errors:
211217
if isinstance(err, QubitValidationError):
212218
print(
213219
f"\n\033[31mError\033[0m: Cloning qubit [{err.qubit_id}] at {err.gate_name} gate"
@@ -220,7 +226,4 @@ def raise_validation_errors(self):
220226
print(
221227
f"\n\033[31mError\033[0m: {err.args[0] if err.args else type(err).__name__}"
222228
)
223-
224229
print(err.hint())
225-
226-
raise

src/bloqade/analysis/validation/nocloning/impls.py

Lines changed: 67 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,32 +2,89 @@
22
from kirin.analysis import ForwardFrame
33
from kirin.dialects import scf
44

5-
from .lattice import QubitValidation
6-
from .analysis import NoCloningValidation
5+
from .lattice import May, Top, Must, Bottom, QubitValidation
6+
from .analysis import (
7+
QubitValidationError,
8+
PotentialQubitValidationError,
9+
_NoCloningAnalysis,
10+
)
711

812

913
@scf.dialect.register(key="validate.nocloning")
1014
class Scf(interp.MethodTable):
1115
@interp.impl(scf.IfElse)
1216
def if_else(
1317
self,
14-
interp_: NoCloningValidation,
18+
interp_: _NoCloningAnalysis,
1519
frame: ForwardFrame[QubitValidation],
1620
stmt: scf.IfElse,
1721
):
18-
cond_validation = frame.get(stmt.cond)
22+
try:
23+
cond_validation = frame.get(stmt.cond)
24+
except Exception:
25+
cond_validation = Top()
1926

20-
then_results = interp_.run_callable_region(
21-
frame, stmt, stmt.then_body, (cond_validation,)
27+
errors_before_then = len(interp_._validation_errors)
28+
_ = interp_.run_callable_region(frame, stmt, stmt.then_body, (cond_validation,))
29+
errors_after_then = len(interp_._validation_errors)
30+
31+
then_had_errors = errors_after_then > errors_before_then
32+
then_errors = interp_._validation_errors[errors_before_then:errors_after_then]
33+
then_state = (
34+
Must(violations=frozenset(err.args[0] for err in then_errors))
35+
if then_had_errors
36+
else Bottom()
2237
)
2338

2439
if stmt.else_body:
25-
else_results = interp_.run_callable_region(
40+
errors_before_else = len(interp_._validation_errors)
41+
_ = interp_.run_callable_region(
2642
frame, stmt, stmt.else_body, (cond_validation,)
2743
)
44+
errors_after_else = len(interp_._validation_errors)
45+
46+
else_had_errors = errors_after_else > errors_before_else
47+
else_errors = interp_._validation_errors[
48+
errors_before_else:errors_after_else
49+
]
50+
else_state = (
51+
Must(violations=frozenset(err.args[0] for err in else_errors))
52+
if else_had_errors
53+
else Bottom()
54+
)
55+
56+
merged = then_state.join(else_state)
2857

29-
merged = tuple(then_results.join(else_results) for _ in stmt.results)
58+
if isinstance(merged, May):
59+
interp_._validation_errors = interp_._validation_errors[
60+
:errors_before_then
61+
]
62+
63+
for err in then_errors + else_errors:
64+
if isinstance(err, QubitValidationError):
65+
potential_err = PotentialQubitValidationError(
66+
err.node,
67+
err.gate_name,
68+
(
69+
", when condition is true"
70+
if err in then_errors
71+
else ", when condition is false"
72+
),
73+
)
74+
interp_._validation_errors.append(potential_err)
3075
else:
31-
merged = tuple(then_results for _ in stmt.results)
76+
merged = then_state.join(Bottom())
77+
78+
if isinstance(merged, May):
79+
interp_._validation_errors = interp_._validation_errors[
80+
:errors_before_then
81+
]
82+
83+
for err in then_errors:
84+
if isinstance(err, QubitValidationError):
85+
potential_err = PotentialQubitValidationError(
86+
err.node, err.gate_name, ", when condition is true"
87+
)
88+
interp_._validation_errors.append(potential_err)
3289

33-
return merged if merged else (QubitValidation.bottom(),)
90+
return (merged,)

src/bloqade/analysis/validation/nocloning/lattice.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,11 +87,25 @@ def is_subseteq(self, other: QubitValidation) -> bool:
8787
return False
8888

8989
def join(self, other: QubitValidation) -> QubitValidation:
90+
"""Join with another validation state.
91+
92+
Key insight: Must ⊔ Bottom = May (error on one path, not all)
93+
"""
9094
match other:
9195
case Bottom():
92-
return self
96+
# Error in one branch, safe in other = May (conditional error)
97+
result = May(violations=self.violations)
98+
return result
9399
case Must(violations=ov):
94-
return Must(violations=self.violations | ov)
100+
# Errors in both branches
101+
common = self.violations & ov
102+
all_violations = self.violations | ov
103+
if common == all_violations:
104+
# Same errors on all paths = Must
105+
return Must(violations=all_violations)
106+
else:
107+
# Different errors on different paths = May
108+
return May(violations=all_violations)
95109
case May(violations=ov):
96110
return May(violations=self.violations | ov)
97111
case Top():

0 commit comments

Comments
 (0)