Skip to content

Commit 1b773d4

Browse files
committed
improve error reporting and update test cases for validation errors
1 parent 64dd86a commit 1b773d4

File tree

3 files changed

+44
-24
lines changed

3 files changed

+44
-24
lines changed

src/bloqade/analysis/validation/nocloning/analysis.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,9 @@
77

88
from bloqade.analysis.address import (
99
Address,
10-
AddressReg,
11-
AddressQubit,
1210
AddressAnalysis,
1311
)
14-
from bloqade.analysis.address.lattice import QubitLike
12+
from bloqade.analysis.address.lattice import AddressReg, AddressQubit
1513

1614
from .lattice import QubitValidation
1715

@@ -28,7 +26,7 @@ class NoCloningValidation(Forward[QubitValidation]):
2826
_address_frame: ForwardFrame[Address] = field(init=False)
2927
_type_frame: ForwardFrame = field(init=False)
3028
method: ir.Method
31-
violations: int = field(default=0, init=False)
29+
_validation_errors: list[str] = field(default_factory=list, init=False)
3230

3331
def __init__(self, mtd: ir.Method):
3432
"""
@@ -41,7 +39,7 @@ def __init__(self, mtd: ir.Method):
4139

4240
def initialize(self):
4341
super().initialize()
44-
42+
self._validation_errors = []
4543
address_analysis = AddressAnalysis(self.dialects)
4644
address_analysis.initialize()
4745
self._address_frame, _ = address_analysis.run_analysis(self.method)
@@ -88,7 +86,8 @@ def eval_stmt_fallback(
8886
return tuple(QubitValidation.top() for _ in stmt.results)
8987

9088
has_qubit_args = any(
91-
isinstance(address_frame.get(arg), QubitLike) for arg in stmt.args
89+
isinstance(address_frame.get(arg), (AddressQubit, AddressReg))
90+
for arg in stmt.args
9291
)
9392

9493
if not has_qubit_args:
@@ -106,7 +105,10 @@ def eval_stmt_fallback(
106105

107106
for qubit_addr in used_addrs:
108107
if qubit_addr in seen:
109-
violations.append(f"Qubit[{qubit_addr}] at {stmt_info}")
108+
violations.append(f"Qubit[{qubit_addr}] on {stmt_info}")
109+
self._validation_errors.append(
110+
f"Qubit[{qubit_addr}] on {stmt_info} in {stmt.source}"
111+
)
110112
seen.add(qubit_addr)
111113

112114
if not violations:
@@ -120,3 +122,7 @@ def run_method(
120122
) -> tuple[ForwardFrame[QubitValidation], QubitValidation]:
121123
self_mt = self.method_self(method)
122124
return self.run_callable(method.code, (self_mt,) + args)
125+
126+
def get_validation_errors(self) -> str:
127+
"""Retrieve collected validation error messages."""
128+
return "\n".join(self._validation_errors)

test/analysis/validation/test_no_cloning.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313

1414
@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz])
15-
def test_control_gate_fail(control_gate: ir.Method[[Qubit, Qubit], Any]):
15+
def test_fail(control_gate: ir.Method[[Qubit, Qubit], Any]):
1616
@squin.kernel
1717
def bad_control():
1818
q = squin.qalloc(1)
@@ -28,7 +28,7 @@ def bad_control():
2828

2929

3030
@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz])
31-
def test_control_gate_conditionals_fail(control_gate: ir.Method[[Qubit, Qubit], Any]):
31+
def test_conditionals_fail(control_gate: ir.Method[[Qubit, Qubit], Any]):
3232
@squin.kernel
3333
def bad_control(cond: bool):
3434
q = squin.qalloc(10)
@@ -44,12 +44,11 @@ def bad_control(cond: bool):
4444
print()
4545
bad_control.print(analysis=frame.entries)
4646
validation_errors = collect_validation_errors(frame, QubitValidation)
47-
# print("Violations:", validation_errors)
4847
assert len(validation_errors) == 2
4948

5049

5150
@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz])
52-
def test_control_gate_parallel_fail(control_gate: ir.Method[[Qubit, Qubit], Any]):
51+
def test_pass(control_gate: ir.Method[[Qubit, Qubit], Any]):
5352
@squin.kernel
5453
def bad_control():
5554
q = squin.qalloc(2)
@@ -64,7 +63,7 @@ def bad_control():
6463
assert len(validation_errors) == 0
6564

6665

67-
def test_control_gate_parallel_pass():
66+
def test_fail_2():
6867
@squin.kernel
6968
def good_kernel():
7069
q = squin.qalloc(2)
@@ -74,7 +73,22 @@ def good_kernel():
7473
validation = NoCloningValidation(good_kernel)
7574
validation.initialize()
7675
frame, _ = validation.run_analysis(good_kernel)
77-
print()
78-
good_kernel.print(analysis=frame.entries)
7976
validation_errors = collect_validation_errors(frame, QubitValidation)
8077
assert len(validation_errors) == 1
78+
print(validation.get_validation_errors())
79+
80+
81+
def test_parallel_fail():
82+
@squin.kernel
83+
def bad_kernel():
84+
q = squin.qalloc(5)
85+
squin.broadcast.cx(IList([q[0], q[1], q[2]]), IList([q[1], q[2], q[3]]))
86+
87+
validation = NoCloningValidation(bad_kernel)
88+
validation.initialize()
89+
frame, _ = validation.run_analysis(bad_kernel)
90+
print()
91+
bad_kernel.print(analysis=frame.entries)
92+
validation_errors = collect_validation_errors(frame, QubitValidation)
93+
assert len(validation_errors) == 2
94+
print(validation.get_validation_errors())

test/analysis/validation/util.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
from typing import TypeVar
1+
from typing import List
22

33
from kirin.analysis import ForwardFrame
44

55
from bloqade.analysis.validation.nocloning.lattice import QubitValidation
66

7-
T = TypeVar("T", bound=QubitValidation)
8-
97

108
def collect_validation_errors(
11-
frame: ForwardFrame[QubitValidation], typ: type[T]
12-
) -> list[T]:
13-
return [
14-
validation_errors
15-
for validation_errors in frame.entries.values()
16-
if isinstance(validation_errors, typ) and len(validation_errors.violations) > 0
17-
]
9+
frame: ForwardFrame[QubitValidation], typ: type[QubitValidation]
10+
) -> List[str]:
11+
"""Collect individual violation strings from all QubitValidation entries of type `typ`."""
12+
violations: List[str] = []
13+
for validation_val in frame.entries.values():
14+
if isinstance(validation_val, typ):
15+
for v in getattr(validation_val, "violations", ()):
16+
violations.append(v)
17+
return violations

0 commit comments

Comments
 (0)