33from kirin import ir
44from kirin .rewrite import Walk , Chain
55from kirin .passes .abc import Pass
6- from kirin .passes .fold import Fold
76from kirin .rewrite .dce import DeadCodeElimination
8- from kirin .passes .inline import InlinePass
97
108from bloqade import squin as sq
119from bloqade .squin import gate
10+ from bloqade .rewrite .passes import AggressiveUnroll
1211from bloqade .squin .rewrite .U3_to_clifford import SquinU3ToClifford
1312
1413
1514class SquinToCliffordTestPass (Pass ):
1615
1716 def unsafe_run (self , mt : ir .Method ):
1817
19- rewrite_result = InlinePass (dialects = mt .dialects ).fixpoint (mt )
20- rewrite_result = Fold (dialects = mt .dialects )(mt ).join (rewrite_result )
18+ rewrite_result = AggressiveUnroll (mt .dialects ).fixpoint (mt )
2119
22- print ("after inline and fold " )
20+ print ("after unroll " )
2321 mt .print ()
2422
2523 return (
@@ -71,8 +69,8 @@ def test():
7169
7270 SquinToCliffordTestPass (test .dialects )(test )
7371 assert isinstance (get_stmt_at_idx (test , 5 ), gate .stmts .S )
72+ assert isinstance (get_stmt_at_idx (test , 7 ), gate .stmts .S )
7473 assert isinstance (get_stmt_at_idx (test , 9 ), gate .stmts .S )
75- assert isinstance (get_stmt_at_idx (test , 13 ), gate .stmts .S )
7674 S_stmts = filter_statements_by_type (test , (gate .stmts .S ,))
7775 # Should be normal S gates, not adjoint/dagger
7876 assert not S_stmts [0 ].adjoint
@@ -95,11 +93,10 @@ def test():
9593 sq .u3 (theta = 0.0 , phi = 0.5 * math .tau , lam = 0.0 , qubit = q [3 ])
9694
9795 SquinToCliffordTestPass (test .dialects )(test )
98-
9996 assert isinstance (get_stmt_at_idx (test , 5 ), gate .stmts .Z )
97+ assert isinstance (get_stmt_at_idx (test , 7 ), gate .stmts .Z )
10098 assert isinstance (get_stmt_at_idx (test , 9 ), gate .stmts .Z )
101- assert isinstance (get_stmt_at_idx (test , 13 ), gate .stmts .Z )
102- assert isinstance (get_stmt_at_idx (test , 17 ), gate .stmts .Z )
99+ assert isinstance (get_stmt_at_idx (test , 11 ), gate .stmts .Z )
103100
104101
105102def test_sdag ():
@@ -119,10 +116,10 @@ def test():
119116 test .print ()
120117
121118 assert isinstance (get_stmt_at_idx (test , 5 ), gate .stmts .S )
119+ assert isinstance (get_stmt_at_idx (test , 7 ), gate .stmts .S )
122120 assert isinstance (get_stmt_at_idx (test , 9 ), gate .stmts .S )
123- assert isinstance (get_stmt_at_idx (test , 13 ), gate .stmts .S )
124- assert isinstance (get_stmt_at_idx (test , 17 ), gate .stmts .S )
125- assert isinstance (get_stmt_at_idx (test , 21 ), gate .stmts .S )
121+ assert isinstance (get_stmt_at_idx (test , 11 ), gate .stmts .S )
122+ assert isinstance (get_stmt_at_idx (test , 12 ), gate .stmts .S )
126123
127124 sdag_stmts = filter_statements_by_type (test , (gate .stmts .S ,))
128125 for sdag_stmt in sdag_stmts :
@@ -156,7 +153,7 @@ def test():
156153 SquinToCliffordTestPass (test .dialects )(test )
157154
158155 assert isinstance (get_stmt_at_idx (test , 5 ), gate .stmts .SqrtY )
159- assert isinstance (get_stmt_at_idx (test , 9 ), gate .stmts .SqrtY )
156+ assert isinstance (get_stmt_at_idx (test , 6 ), gate .stmts .SqrtY )
160157 sqrt_y_stmts = filter_statements_by_type (test , (gate .stmts .SqrtY ,))
161158 assert not sqrt_y_stmts [0 ].adjoint
162159 assert not sqrt_y_stmts [1 ].adjoint
@@ -178,8 +175,8 @@ def test():
178175
179176 assert isinstance (get_stmt_at_idx (test , 5 ), gate .stmts .S )
180177 assert isinstance (get_stmt_at_idx (test , 6 ), gate .stmts .SqrtY )
181- assert isinstance (get_stmt_at_idx (test , 10 ), gate .stmts .S )
182- assert isinstance (get_stmt_at_idx (test , 11 ), gate .stmts .SqrtY )
178+ assert isinstance (get_stmt_at_idx (test , 8 ), gate .stmts .S )
179+ assert isinstance (get_stmt_at_idx (test , 9 ), gate .stmts .SqrtY )
183180
184181 s_stmts = filter_statements_by_type (test , (gate .stmts .S ,))
185182 sqrt_y_stmts = filter_statements_by_type (test , (gate .stmts .SqrtY ,))
@@ -202,7 +199,7 @@ def test():
202199
203200 SquinToCliffordTestPass (test .dialects )(test )
204201 assert isinstance (get_stmt_at_idx (test , 5 ), gate .stmts .H )
205- assert isinstance (get_stmt_at_idx (test , 9 ), gate .stmts .H )
202+ assert isinstance (get_stmt_at_idx (test , 7 ), gate .stmts .H )
206203
207204
208205def test_sdg_sqrt_y ():
@@ -221,8 +218,8 @@ def test():
221218 SquinToCliffordTestPass (test .dialects )(test )
222219 assert isinstance (get_stmt_at_idx (test , 5 ), gate .stmts .S )
223220 assert isinstance (get_stmt_at_idx (test , 6 ), gate .stmts .SqrtY )
224- assert isinstance (get_stmt_at_idx (test , 10 ), gate .stmts .S )
225- assert isinstance (get_stmt_at_idx (test , 11 ), gate .stmts .SqrtY )
221+ assert isinstance (get_stmt_at_idx (test , 8 ), gate .stmts .S )
222+ assert isinstance (get_stmt_at_idx (test , 9 ), gate .stmts .SqrtY )
226223
227224 s_stmts = filter_statements_by_type (test , (gate .stmts .S ,))
228225 sqrt_y_stmts = filter_statements_by_type (test , (gate .stmts .SqrtY ,))
@@ -251,8 +248,8 @@ def test():
251248 test .print ()
252249 assert isinstance (get_stmt_at_idx (test , 5 ), gate .stmts .SqrtY )
253250 assert isinstance (get_stmt_at_idx (test , 6 ), gate .stmts .S )
254- assert isinstance (get_stmt_at_idx (test , 10 ), gate .stmts .SqrtY )
255- assert isinstance (get_stmt_at_idx (test , 11 ), gate .stmts .S )
251+ assert isinstance (get_stmt_at_idx (test , 8 ), gate .stmts .SqrtY )
252+ assert isinstance (get_stmt_at_idx (test , 9 ), gate .stmts .S )
256253
257254 sqrt_y_stmts = filter_statements_by_type (test , (gate .stmts .SqrtY ,))
258255 s_stmts = filter_statements_by_type (test , (gate .stmts .S ,))
@@ -390,8 +387,8 @@ def test():
390387
391388 assert isinstance (get_stmt_at_idx (test , 5 ), gate .stmts .SqrtY )
392389 assert isinstance (get_stmt_at_idx (test , 6 ), gate .stmts .Z )
393- assert isinstance (get_stmt_at_idx (test , 10 ), gate .stmts .SqrtY )
394- assert isinstance (get_stmt_at_idx (test , 11 ), gate .stmts .Z )
390+ assert isinstance (get_stmt_at_idx (test , 8 ), gate .stmts .SqrtY )
391+ assert isinstance (get_stmt_at_idx (test , 9 ), gate .stmts .Z )
395392
396393 sqrt_y_stmts = filter_statements_by_type (test , (gate .stmts .SqrtY ,))
397394 for sqrt_y_stmt in sqrt_y_stmts :
0 commit comments