1- from typing import Any
1+ from typing import Any , List , TypeVar
22
33import pytest
4- from .util import collect_may_errors , collect_must_errors
54from kirin import ir
5+ from kirin .analysis import ForwardFrame
66from kirin .dialects .ilist .runtime import IList
77
88from bloqade import squin
99from bloqade .types import Qubit
10+ from bloqade .analysis .validation .nocloning .lattice import May , Must , QubitValidation
1011from bloqade .analysis .validation .nocloning .analysis import NoCloningValidation
1112
13+ T = TypeVar ("T" , bound = Must | May )
14+
15+
16+ def collect_errors (frame : ForwardFrame [QubitValidation ], typ : type [T ]) -> List [str ]:
17+ """Collect individual violation strings from all QubitValidation entries of type `typ`."""
18+ violations : List [str ] = []
19+ for validation_val in frame .entries .values ():
20+ if isinstance (validation_val , typ ):
21+ for v in validation_val .violations :
22+ violations .append (v )
23+ return violations
24+
1225
1326@pytest .mark .parametrize ("control_gate" , [squin .cx , squin .cy , squin .cz ])
1427def test_fail (control_gate : ir .Method [[Qubit , Qubit ], Any ]):
@@ -22,8 +35,8 @@ def bad_control():
2235 frame , _ = validation .run_analysis (bad_control )
2336 print ()
2437 bad_control .print (analysis = frame .entries )
25- must_errors = collect_must_errors (frame )
26- may_errors = collect_may_errors (frame )
38+ must_errors = collect_errors (frame , Must )
39+ may_errors = collect_errors (frame , May )
2740 assert len (must_errors ) == 1
2841 assert len (may_errors ) == 0
2942 with pytest .raises (Exception ):
@@ -46,8 +59,8 @@ def bad_control(cond: bool):
4659 frame , _ = validation .run_analysis (bad_control )
4760 print ()
4861 bad_control .print (analysis = frame .entries )
49- must_errors = collect_must_errors (frame )
50- may_errors = collect_may_errors (frame )
62+ must_errors = collect_errors (frame , Must )
63+ may_errors = collect_errors (frame , May )
5164 assert len (must_errors ) == 2
5265 assert len (may_errors ) == 0
5366 with pytest .raises (Exception ):
@@ -69,8 +82,8 @@ def test():
6982 frame , _ = validation .run_analysis (test )
7083 print ()
7184 test .print (analysis = frame .entries )
72- must_errors = collect_must_errors (frame )
73- may_errors = collect_may_errors (frame )
85+ must_errors = collect_errors (frame , Must )
86+ may_errors = collect_errors (frame , May )
7487 assert len (must_errors ) == 0
7588 assert len (may_errors ) == 0
7689
@@ -86,8 +99,8 @@ def good_kernel():
8699 validation = NoCloningValidation (good_kernel )
87100 validation .initialize ()
88101 frame , _ = validation .run_analysis (good_kernel )
89- must_errors = collect_must_errors (frame )
90- may_errors = collect_may_errors (frame )
102+ must_errors = collect_errors (frame , Must )
103+ may_errors = collect_errors (frame , May )
91104 assert len (must_errors ) == 1
92105 assert len (may_errors ) == 0
93106 with pytest .raises (Exception ):
@@ -105,8 +118,8 @@ def bad_kernel():
105118 frame , _ = validation .run_analysis (bad_kernel )
106119 print ()
107120 bad_kernel .print (analysis = frame .entries )
108- must_errors = collect_must_errors (frame )
109- may_errors = collect_may_errors (frame )
121+ must_errors = collect_errors (frame , Must )
122+ may_errors = collect_errors (frame , May )
110123 assert len (must_errors ) == 2
111124 assert len (may_errors ) == 0
112125 with pytest .raises (Exception ):
@@ -124,8 +137,8 @@ def bad_kernel(a: int, b: int):
124137 frame , _ = validation .run_analysis (bad_kernel )
125138 print ()
126139 bad_kernel .print (analysis = frame .entries )
127- must_errors = collect_must_errors (frame )
128- may_errors = collect_may_errors (frame )
140+ must_errors = collect_errors (frame , Must )
141+ may_errors = collect_errors (frame , May )
129142 assert len (must_errors ) == 0
130143 assert len (may_errors ) == 1
131144 with pytest .raises (Exception ):
@@ -143,8 +156,8 @@ def bad_kernel(a: IList):
143156 frame , _ = validation .run_analysis (bad_kernel )
144157 print ()
145158 bad_kernel .print (analysis = frame .entries )
146- must_errors = collect_must_errors (frame )
147- may_errors = collect_may_errors (frame )
159+ must_errors = collect_errors (frame , Must )
160+ may_errors = collect_errors (frame , May )
148161 assert len (must_errors ) == 0
149162 assert len (may_errors ) == 1
150163 with pytest .raises (Exception ):
0 commit comments