Skip to content

Commit dca2422

Browse files
committed
Updated ValidationError reporting
1 parent 1b773d4 commit dca2422

File tree

2 files changed

+64
-16
lines changed

2 files changed

+64
-16
lines changed

src/bloqade/analysis/validation/nocloning/analysis.py

Lines changed: 41 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from kirin import ir
44
from kirin.analysis import Forward, TypeInference
55
from kirin.dialects import func
6+
from kirin.ir.exception import ValidationError
67
from kirin.analysis.forward import ForwardFrame
78

89
from bloqade.analysis.address import (
@@ -14,6 +15,19 @@
1415
from .lattice import QubitValidation
1516

1617

18+
class QubitValidationError(ValidationError):
19+
"""ValidationError that records which qubit and gate caused the violation."""
20+
21+
qubit_id: int
22+
gate_name: str
23+
24+
def __init__(self, node: ir.IRNode, qubit_id: int, gate_name: str):
25+
# message stored in ValidationError so formatting/hint() will include it
26+
super().__init__(node, f"Qubit[{qubit_id}] cloned at {gate_name} gate.")
27+
self.qubit_id = qubit_id
28+
self.gate_name = gate_name
29+
30+
1731
class NoCloningValidation(Forward[QubitValidation]):
1832
"""
1933
Validates the no-cloning theorem by tracking qubit addresses.
@@ -26,7 +40,9 @@ class NoCloningValidation(Forward[QubitValidation]):
2640
_address_frame: ForwardFrame[Address] = field(init=False)
2741
_type_frame: ForwardFrame = field(init=False)
2842
method: ir.Method
29-
_validation_errors: list[str] = field(default_factory=list, init=False)
43+
_validation_errors: list[QubitValidationError] = field(
44+
default_factory=list, init=False
45+
)
3046

3147
def __init__(self, mtd: ir.Method):
3248
"""
@@ -63,13 +79,9 @@ def get_qubit_addresses(self, addr: Address) -> frozenset[int]:
6379
case _:
6480
return frozenset()
6581

66-
def get_stmt_info(self, stmt: ir.Statement) -> str:
67-
"""String Report about the statement for violation messages."""
68-
if isinstance(stmt, func.Invoke) and hasattr(stmt, "callee"):
69-
gate_name = stmt.callee.sym_name.upper()
70-
return f"{gate_name} Gate"
71-
72-
return f"{stmt.__class__.__name__}@{stmt}"
82+
def format_violation(self, qubit_id: int, gate_name: str) -> str:
83+
"""Return the violation string for a qubit + gate."""
84+
return f"Qubit[{qubit_id}] on {gate_name} Gate"
7385

7486
def eval_stmt_fallback(
7587
self, frame: ForwardFrame[QubitValidation], stmt: ir.Statement
@@ -101,13 +113,13 @@ def eval_stmt_fallback(
101113

102114
seen: set[int] = set()
103115
violations: list[str] = []
104-
stmt_info = self.get_stmt_info(stmt)
105116

106117
for qubit_addr in used_addrs:
107118
if qubit_addr in seen:
108-
violations.append(f"Qubit[{qubit_addr}] on {stmt_info}")
119+
gate_name = stmt.callee.sym_name.upper()
120+
violations.append(self.format_violation(qubit_addr, gate_name))
109121
self._validation_errors.append(
110-
f"Qubit[{qubit_addr}] on {stmt_info} in {stmt.source}"
122+
QubitValidationError(stmt, qubit_addr, gate_name)
111123
)
112124
seen.add(qubit_addr)
113125

@@ -123,6 +135,21 @@ def run_method(
123135
self_mt = self.method_self(method)
124136
return self.run_callable(method.code, (self_mt,) + args)
125137

126-
def get_validation_errors(self) -> str:
127-
"""Retrieve collected validation error messages."""
128-
return "\n".join(self._validation_errors)
138+
def raise_validation_errors(self):
139+
"""Raise validation error for each no-cloning violation found.
140+
Points to source file and line with snippet.
141+
"""
142+
if not self._validation_errors:
143+
return
144+
145+
# If multiple errors, print all with snippets first
146+
if len(self._validation_errors) > 1:
147+
for err in self._validation_errors:
148+
err.attach(self.method)
149+
# Print error message before snippet
150+
print(
151+
f"\033[31mValidation Error\033[0m: Cloned qubit [{err.qubit_id}] at {err.gate_name} gate."
152+
)
153+
print(err.hint())
154+
print(f"Raised {len(self._validation_errors)} error(s).")
155+
raise

test/analysis/validation/test_no_cloning.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ def bad_control():
2525
bad_control.print(analysis=frame.entries)
2626
validation_errors = collect_validation_errors(frame, QubitValidation)
2727
assert len(validation_errors) == 1
28+
with pytest.raises(Exception):
29+
validation.raise_validation_errors()
2830

2931

3032
@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz])
@@ -45,6 +47,8 @@ def bad_control(cond: bool):
4547
bad_control.print(analysis=frame.entries)
4648
validation_errors = collect_validation_errors(frame, QubitValidation)
4749
assert len(validation_errors) == 2
50+
with pytest.raises(Exception):
51+
validation.raise_validation_errors()
4852

4953

5054
@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz])
@@ -75,7 +79,8 @@ def good_kernel():
7579
frame, _ = validation.run_analysis(good_kernel)
7680
validation_errors = collect_validation_errors(frame, QubitValidation)
7781
assert len(validation_errors) == 1
78-
print(validation.get_validation_errors())
82+
with pytest.raises(Exception):
83+
validation.raise_validation_errors()
7984

8085

8186
def test_parallel_fail():
@@ -91,4 +96,20 @@ def bad_kernel():
9196
bad_kernel.print(analysis=frame.entries)
9297
validation_errors = collect_validation_errors(frame, QubitValidation)
9398
assert len(validation_errors) == 2
94-
print(validation.get_validation_errors())
99+
with pytest.raises(Exception):
100+
validation.raise_validation_errors()
101+
102+
103+
# def test_potential_fail():
104+
# @squin.kernel
105+
# def bad_kernel(a: int, b: int):
106+
# q = squin.qalloc(5)
107+
# squin.cx(q[a], q[b])
108+
109+
# validation = NoCloningValidation(bad_kernel)
110+
# validation.initialize()
111+
# frame, _ = validation.run_analysis(bad_kernel)
112+
# print()
113+
# bad_kernel.print(analysis=frame.entries)
114+
# validation_errors = collect_validation_errors(frame, QubitValidation)
115+
# assert len(validation_errors) == 0

0 commit comments

Comments
 (0)