Skip to content

Commit 2e155d9

Browse files
committed
fix import errors
1 parent 6429148 commit 2e155d9

File tree

2 files changed

+29
-41
lines changed

2 files changed

+29
-41
lines changed

test/analysis/validation/test_no_cloning.py

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,27 @@
1-
from typing import Any
1+
from typing import Any, List, TypeVar
22

33
import pytest
4-
from .util import collect_may_errors, collect_must_errors
54
from kirin import ir
5+
from kirin.analysis import ForwardFrame
66
from kirin.dialects.ilist.runtime import IList
77

88
from bloqade import squin
99
from bloqade.types import Qubit
10+
from bloqade.analysis.validation.nocloning.lattice import May, Must, QubitValidation
1011
from bloqade.analysis.validation.nocloning.analysis import NoCloningValidation
1112

13+
T = TypeVar("T", bound=Must | May)
14+
15+
16+
def collect_errors(frame: ForwardFrame[QubitValidation], typ: type[T]) -> List[str]:
17+
"""Collect individual violation strings from all QubitValidation entries of type `typ`."""
18+
violations: List[str] = []
19+
for validation_val in frame.entries.values():
20+
if isinstance(validation_val, typ):
21+
for v in validation_val.violations:
22+
violations.append(v)
23+
return violations
24+
1225

1326
@pytest.mark.parametrize("control_gate", [squin.cx, squin.cy, squin.cz])
1427
def test_fail(control_gate: ir.Method[[Qubit, Qubit], Any]):
@@ -22,8 +35,8 @@ def bad_control():
2235
frame, _ = validation.run_analysis(bad_control)
2336
print()
2437
bad_control.print(analysis=frame.entries)
25-
must_errors = collect_must_errors(frame)
26-
may_errors = collect_may_errors(frame)
38+
must_errors = collect_errors(frame, Must)
39+
may_errors = collect_errors(frame, May)
2740
assert len(must_errors) == 1
2841
assert len(may_errors) == 0
2942
with pytest.raises(Exception):
@@ -46,8 +59,8 @@ def bad_control(cond: bool):
4659
frame, _ = validation.run_analysis(bad_control)
4760
print()
4861
bad_control.print(analysis=frame.entries)
49-
must_errors = collect_must_errors(frame)
50-
may_errors = collect_may_errors(frame)
62+
must_errors = collect_errors(frame, Must)
63+
may_errors = collect_errors(frame, May)
5164
assert len(must_errors) == 2
5265
assert len(may_errors) == 0
5366
with pytest.raises(Exception):
@@ -69,8 +82,8 @@ def test():
6982
frame, _ = validation.run_analysis(test)
7083
print()
7184
test.print(analysis=frame.entries)
72-
must_errors = collect_must_errors(frame)
73-
may_errors = collect_may_errors(frame)
85+
must_errors = collect_errors(frame, Must)
86+
may_errors = collect_errors(frame, May)
7487
assert len(must_errors) == 0
7588
assert len(may_errors) == 0
7689

@@ -86,8 +99,8 @@ def good_kernel():
8699
validation = NoCloningValidation(good_kernel)
87100
validation.initialize()
88101
frame, _ = validation.run_analysis(good_kernel)
89-
must_errors = collect_must_errors(frame)
90-
may_errors = collect_may_errors(frame)
102+
must_errors = collect_errors(frame, Must)
103+
may_errors = collect_errors(frame, May)
91104
assert len(must_errors) == 1
92105
assert len(may_errors) == 0
93106
with pytest.raises(Exception):
@@ -105,8 +118,8 @@ def bad_kernel():
105118
frame, _ = validation.run_analysis(bad_kernel)
106119
print()
107120
bad_kernel.print(analysis=frame.entries)
108-
must_errors = collect_must_errors(frame)
109-
may_errors = collect_may_errors(frame)
121+
must_errors = collect_errors(frame, Must)
122+
may_errors = collect_errors(frame, May)
110123
assert len(must_errors) == 2
111124
assert len(may_errors) == 0
112125
with pytest.raises(Exception):
@@ -124,8 +137,8 @@ def bad_kernel(a: int, b: int):
124137
frame, _ = validation.run_analysis(bad_kernel)
125138
print()
126139
bad_kernel.print(analysis=frame.entries)
127-
must_errors = collect_must_errors(frame)
128-
may_errors = collect_may_errors(frame)
140+
must_errors = collect_errors(frame, Must)
141+
may_errors = collect_errors(frame, May)
129142
assert len(must_errors) == 0
130143
assert len(may_errors) == 1
131144
with pytest.raises(Exception):
@@ -143,8 +156,8 @@ def bad_kernel(a: IList):
143156
frame, _ = validation.run_analysis(bad_kernel)
144157
print()
145158
bad_kernel.print(analysis=frame.entries)
146-
must_errors = collect_must_errors(frame)
147-
may_errors = collect_may_errors(frame)
159+
must_errors = collect_errors(frame, Must)
160+
may_errors = collect_errors(frame, May)
148161
assert len(must_errors) == 0
149162
assert len(may_errors) == 1
150163
with pytest.raises(Exception):

test/analysis/validation/util.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

0 commit comments

Comments
 (0)