Skip to content

Commit a2564d5

Browse files
committed
fixing tests
1 parent 28d1782 commit a2564d5

File tree

1 file changed

+19
-22
lines changed

1 file changed

+19
-22
lines changed

test/squin/rewrite/test_U3_to_clifford.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,21 @@
33
from kirin import ir
44
from kirin.rewrite import Walk, Chain
55
from kirin.passes.abc import Pass
6-
from kirin.passes.fold import Fold
76
from kirin.rewrite.dce import DeadCodeElimination
8-
from kirin.passes.inline import InlinePass
97

108
from bloqade import squin as sq
119
from bloqade.squin import gate
10+
from bloqade.rewrite.passes import AggressiveUnroll
1211
from bloqade.squin.rewrite.U3_to_clifford import SquinU3ToClifford
1312

1413

1514
class SquinToCliffordTestPass(Pass):
1615

1716
def unsafe_run(self, mt: ir.Method):
1817

19-
rewrite_result = InlinePass(dialects=mt.dialects).fixpoint(mt)
20-
rewrite_result = Fold(dialects=mt.dialects)(mt).join(rewrite_result)
18+
rewrite_result = AggressiveUnroll(mt.dialects).fixpoint(mt)
2119

22-
print("after inline and fold")
20+
print("after unroll")
2321
mt.print()
2422

2523
return (
@@ -71,8 +69,8 @@ def test():
7169

7270
SquinToCliffordTestPass(test.dialects)(test)
7371
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.S)
72+
assert isinstance(get_stmt_at_idx(test, 7), gate.stmts.S)
7473
assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.S)
75-
assert isinstance(get_stmt_at_idx(test, 13), gate.stmts.S)
7674
S_stmts = filter_statements_by_type(test, (gate.stmts.S,))
7775
# Should be normal S gates, not adjoint/dagger
7876
assert not S_stmts[0].adjoint
@@ -95,11 +93,10 @@ def test():
9593
sq.u3(theta=0.0, phi=0.5 * math.tau, lam=0.0, qubit=q[3])
9694

9795
SquinToCliffordTestPass(test.dialects)(test)
98-
9996
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.Z)
97+
assert isinstance(get_stmt_at_idx(test, 7), gate.stmts.Z)
10098
assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.Z)
101-
assert isinstance(get_stmt_at_idx(test, 13), gate.stmts.Z)
102-
assert isinstance(get_stmt_at_idx(test, 17), gate.stmts.Z)
99+
assert isinstance(get_stmt_at_idx(test, 11), gate.stmts.Z)
103100

104101

105102
def test_sdag():
@@ -119,10 +116,10 @@ def test():
119116
test.print()
120117

121118
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.S)
119+
assert isinstance(get_stmt_at_idx(test, 7), gate.stmts.S)
122120
assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.S)
123-
assert isinstance(get_stmt_at_idx(test, 13), gate.stmts.S)
124-
assert isinstance(get_stmt_at_idx(test, 17), gate.stmts.S)
125-
assert isinstance(get_stmt_at_idx(test, 21), gate.stmts.S)
121+
assert isinstance(get_stmt_at_idx(test, 11), gate.stmts.S)
122+
assert isinstance(get_stmt_at_idx(test, 12), gate.stmts.S)
126123

127124
sdag_stmts = filter_statements_by_type(test, (gate.stmts.S,))
128125
for sdag_stmt in sdag_stmts:
@@ -156,7 +153,7 @@ def test():
156153
SquinToCliffordTestPass(test.dialects)(test)
157154

158155
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.SqrtY)
159-
assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.SqrtY)
156+
assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.SqrtY)
160157
sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,))
161158
assert not sqrt_y_stmts[0].adjoint
162159
assert not sqrt_y_stmts[1].adjoint
@@ -178,8 +175,8 @@ def test():
178175

179176
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.S)
180177
assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.SqrtY)
181-
assert isinstance(get_stmt_at_idx(test, 10), gate.stmts.S)
182-
assert isinstance(get_stmt_at_idx(test, 11), gate.stmts.SqrtY)
178+
assert isinstance(get_stmt_at_idx(test, 8), gate.stmts.S)
179+
assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.SqrtY)
183180

184181
s_stmts = filter_statements_by_type(test, (gate.stmts.S,))
185182
sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,))
@@ -202,7 +199,7 @@ def test():
202199

203200
SquinToCliffordTestPass(test.dialects)(test)
204201
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.H)
205-
assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.H)
202+
assert isinstance(get_stmt_at_idx(test, 7), gate.stmts.H)
206203

207204

208205
def test_sdg_sqrt_y():
@@ -221,8 +218,8 @@ def test():
221218
SquinToCliffordTestPass(test.dialects)(test)
222219
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.S)
223220
assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.SqrtY)
224-
assert isinstance(get_stmt_at_idx(test, 10), gate.stmts.S)
225-
assert isinstance(get_stmt_at_idx(test, 11), gate.stmts.SqrtY)
221+
assert isinstance(get_stmt_at_idx(test, 8), gate.stmts.S)
222+
assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.SqrtY)
226223

227224
s_stmts = filter_statements_by_type(test, (gate.stmts.S,))
228225
sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,))
@@ -251,8 +248,8 @@ def test():
251248
test.print()
252249
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.SqrtY)
253250
assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.S)
254-
assert isinstance(get_stmt_at_idx(test, 10), gate.stmts.SqrtY)
255-
assert isinstance(get_stmt_at_idx(test, 11), gate.stmts.S)
251+
assert isinstance(get_stmt_at_idx(test, 8), gate.stmts.SqrtY)
252+
assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.S)
256253

257254
sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,))
258255
s_stmts = filter_statements_by_type(test, (gate.stmts.S,))
@@ -390,8 +387,8 @@ def test():
390387

391388
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.SqrtY)
392389
assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.Z)
393-
assert isinstance(get_stmt_at_idx(test, 10), gate.stmts.SqrtY)
394-
assert isinstance(get_stmt_at_idx(test, 11), gate.stmts.Z)
390+
assert isinstance(get_stmt_at_idx(test, 8), gate.stmts.SqrtY)
391+
assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.Z)
395392

396393
sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,))
397394
for sqrt_y_stmt in sqrt_y_stmts:

0 commit comments

Comments
 (0)