Skip to content

Commit b572471

Browse files
committed
updated to work with new Kirin version
1 parent c3680c4 commit b572471

File tree

3 files changed

+28
-24
lines changed

3 files changed

+28
-24
lines changed

src/bloqade/analysis/validation/nocloning/analysis.py

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__(self, node: ir.IRNode, gate_name: str, condition: str):
5151
class _NoCloningAnalysis(Forward[QubitValidation]):
5252
"""Internal forward analysis for tracking qubit cloning violations."""
5353

54-
keys = ["validate.nocloning"]
54+
keys = ("validate.nocloning",)
5555
lattice = QubitValidation
5656

5757
def __init__(self, dialects):
@@ -62,33 +62,32 @@ def __init__(self, dialects):
6262
def method_self(self, method: ir.Method) -> QubitValidation:
6363
return self.lattice.bottom()
6464

65-
def run_method(
66-
self, method: ir.Method, args: tuple[QubitValidation, ...]
67-
) -> tuple[ForwardFrame[QubitValidation], QubitValidation]:
65+
def run(self, method: ir.Method, *args: QubitValidation, **kwargs: QubitValidation):
66+
# Set up address frame before analysis if not already cached
6867
if self._address_frame is None:
6968
addr_analysis = AddressAnalysis(self.dialects)
7069
addr_analysis.initialize()
71-
self._address_frame, _ = addr_analysis.run_analysis(method)
70+
self._address_frame, _ = addr_analysis.run(method)
7271

73-
return self.run_callable(method.code, args)
72+
# Now run the forward analysis with address frame populated
73+
return super().run(method, *args, **kwargs)
7474

75-
def eval_stmt_fallback(
76-
self, frame: ForwardFrame[QubitValidation], stmt: ir.Statement
75+
def eval_fallback(
76+
self, frame: ForwardFrame[QubitValidation], node: ir.Statement
7777
) -> tuple[QubitValidation, ...]:
7878
"""Check for qubit usage violations."""
79-
if not isinstance(stmt, func.Invoke):
80-
return tuple(Bottom() for _ in stmt.results)
79+
if not isinstance(node, func.Invoke):
80+
return tuple(Bottom() for _ in node.results)
8181

8282
address_frame = self._address_frame
8383
if address_frame is None:
84-
return tuple(Top() for _ in stmt.results)
84+
return tuple(Top() for _ in node.results)
8585

8686
concrete_addrs: list[int] = []
8787
has_unknown = False
8888
has_qubit_args = False
8989
unknown_arg_names: list[str] = []
90-
91-
for arg in stmt.args:
90+
for arg in node.args:
9291
addr = address_frame.get(arg)
9392
match addr:
9493
case AddressQubit(data=qubit_addr):
@@ -112,18 +111,19 @@ def eval_stmt_fallback(
112111
pass
113112

114113
if not has_qubit_args:
115-
return tuple(Bottom() for _ in stmt.results)
114+
return tuple(Bottom() for _ in node.results)
116115

117116
seen: set[int] = set()
118117
must_violations: list[str] = []
119-
gate_name = stmt.callee.sym_name.upper()
118+
s_name = getattr(node.callee, "sym_name", "<unknown")
119+
gate_name = s_name.upper()
120120

121121
for qubit_addr in concrete_addrs:
122122
if qubit_addr in seen:
123123
violation = f"Qubit[{qubit_addr}] on {gate_name} Gate"
124124
must_violations.append(violation)
125125
self._validation_errors.append(
126-
QubitValidationError(stmt, qubit_addr, gate_name)
126+
QubitValidationError(node, qubit_addr, gate_name)
127127
)
128128
seen.add(qubit_addr)
129129

@@ -137,14 +137,14 @@ def eval_stmt_fallback(
137137
condition = f", with unknown argument {args_str}"
138138

139139
self._validation_errors.append(
140-
PotentialQubitValidationError(stmt, gate_name, condition)
140+
PotentialQubitValidationError(node, gate_name, condition)
141141
)
142142

143143
usage = May(violations=frozenset([f"{gate_name} Gate{condition}"]))
144144
else:
145145
usage = Bottom()
146146

147-
return tuple(usage for _ in stmt.results) if stmt.results else (usage,)
147+
return tuple(usage for _ in node.results) if node.results else (usage,)
148148

149149
def _get_source_name(self, value: ir.SSAValue) -> str:
150150
"""Trace back to get the source variable name."""
@@ -194,7 +194,7 @@ def run(self, method: ir.Method) -> tuple[Any, list[ValidationError]]:
194194
self._analysis.initialize()
195195
if self._cached_address_frame is not None:
196196
self._analysis._address_frame = self._cached_address_frame
197-
frame, _ = self._analysis.run_analysis(method)
197+
frame, _ = self._analysis.run(method)
198198

199199
return frame, self._analysis._validation_errors
200200

src/bloqade/analysis/validation/nocloning/impls.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ def if_else(
2525
cond_validation = Top()
2626

2727
errors_before_then = len(interp_._validation_errors)
28-
_ = interp_.run_callable_region(frame, stmt, stmt.then_body, (cond_validation,))
28+
with interp_.new_frame(stmt, has_parent_access=True) as then_frame:
29+
interp_.frame_call_region(then_frame, stmt, stmt.then_body, cond_validation)
30+
frame.set_values(then_frame.entries.keys(), then_frame.entries.values())
2931
errors_after_then = len(interp_._validation_errors)
3032

3133
then_had_errors = errors_after_then > errors_before_then
@@ -38,9 +40,11 @@ def if_else(
3840

3941
if stmt.else_body:
4042
errors_before_else = len(interp_._validation_errors)
41-
_ = interp_.run_callable_region(
42-
frame, stmt, stmt.else_body, (cond_validation,)
43-
)
43+
with interp_.new_frame(stmt, has_parent_access=True) as else_frame:
44+
interp_.frame_call_region(
45+
else_frame, stmt, stmt.else_body, cond_validation
46+
)
47+
frame.set_values(else_frame.entries.keys(), else_frame.entries.values())
4448
errors_after_else = len(interp_._validation_errors)
4549

4650
else_had_errors = errors_after_else > errors_before_else

src/bloqade/analysis/validation/validationpass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def validate(self, method: ir.Method) -> "ValidationResult":
8787
if required_analysis not in self._analysis_cache:
8888
analysis = required_analysis(method.dialects)
8989
analysis.initialize()
90-
frame, _ = analysis.run_analysis(method)
90+
frame, _ = analysis.run(method)
9191
self._analysis_cache[required_analysis] = frame
9292

9393
validator.set_analysis_cache(self._analysis_cache)

0 commit comments

Comments
 (0)