Skip to content

Commit 9da3dcf

Browse files
committed
Fix tests
1 parent eee3cda commit 9da3dcf

File tree

4 files changed

+78
-100
lines changed

4 files changed

+78
-100
lines changed

src/bloqade/squin/rewrite/U3_to_clifford.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ def rewrite_U3(self, node: gate.stmts.U3) -> RewriteResult:
114114
for gate_stmt in gates:
115115
if gate_stmt is Sdag:
116116
new_stmt = gate.stmts.S(adjoint=True, qubits=node.qubits)
117-
if gate_stmt is SqrtXdag:
117+
elif gate_stmt is SqrtXdag:
118118
new_stmt = gate.stmts.SqrtX(adjoint=True, qubits=node.qubits)
119-
if gate_stmt is SqrtYdag:
119+
elif gate_stmt is SqrtYdag:
120120
new_stmt = gate.stmts.SqrtY(adjoint=True, qubits=node.qubits)
121121
else:
122122
new_stmt = gate_stmt(qubits=node.qubits)

test/squin/rewrite/test_U3_to_clifford.py

Lines changed: 71 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def test():
231231
assert not sqrt_y_stmt.adjoint
232232

233233

234-
def test_sqrt_y_s():
234+
def test_s_sqrt_x_dag():
235235

236236
@sq.kernel
237237
def test():
@@ -246,22 +246,22 @@ def test():
246246

247247
SquinToCliffordTestPass(test.dialects)(test)
248248
test.print()
249-
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.SqrtY)
250-
assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.S)
251-
assert isinstance(get_stmt_at_idx(test, 8), gate.stmts.SqrtY)
252-
assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.S)
249+
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.S)
250+
assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.SqrtX)
251+
assert isinstance(get_stmt_at_idx(test, 8), gate.stmts.S)
252+
assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.SqrtX)
253253

254-
sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,))
254+
sqrt_x_stmts = filter_statements_by_type(test, (gate.stmts.SqrtX,))
255255
s_stmts = filter_statements_by_type(test, (gate.stmts.S,))
256256

257-
for sqrt_y_stmt in sqrt_y_stmts:
258-
assert not sqrt_y_stmt.adjoint
257+
for sqrt_x_stmt in sqrt_x_stmts:
258+
assert sqrt_x_stmt.adjoint
259259

260260
for s_stmt in s_stmts:
261261
assert not s_stmt.adjoint
262262

263263

264-
def test_s_sqrt_y_s():
264+
def test_z_sqrt_x_dag():
265265

266266
@sq.kernel
267267
def test():
@@ -277,29 +277,29 @@ def test():
277277
SquinToCliffordTestPass(test.dialects)(test)
278278

279279
s_stmts = filter_statements_by_type(test, (gate.stmts.S,))
280-
sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,))
280+
sqrt_x_stmts = filter_statements_by_type(test, (gate.stmts.SqrtX,))
281281

282282
# Should be S, SqrtY, S for each op
283283
assert [
284284
type(stmt)
285-
for stmt in filter_statements_by_type(test, (gate.stmts.S, gate.stmts.SqrtY))
285+
for stmt in filter_statements_by_type(
286+
test, (gate.stmts.S, gate.stmts.Z, gate.stmts.SqrtX, gate.stmts.SqrtY)
287+
)
286288
] == [
287-
gate.stmts.S,
288-
gate.stmts.SqrtY,
289-
gate.stmts.S,
290-
gate.stmts.S,
291-
gate.stmts.SqrtY,
292-
gate.stmts.S,
289+
gate.stmts.Z,
290+
gate.stmts.SqrtX,
291+
gate.stmts.Z,
292+
gate.stmts.SqrtX,
293293
]
294294

295295
# Check adjoint property
296296
for s_stmt in s_stmts:
297297
assert not s_stmt.adjoint
298-
for sqrt_y_stmt in sqrt_y_stmts:
299-
assert not sqrt_y_stmt.adjoint
298+
for sqrt_x_stmt in sqrt_x_stmts:
299+
assert sqrt_x_stmt.adjoint
300300

301301

302-
def test_z_sqrt_y_s():
302+
def test_s_dag_sqrt_x_dag():
303303

304304
@sq.kernel
305305
def test():
@@ -316,25 +316,23 @@ def test():
316316
test.print()
317317

318318
relevant_stmts = filter_statements_by_type(
319-
test, (gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S)
319+
test, (gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S, gate.stmts.SqrtX)
320320
)
321321

322322
expected_types = [
323-
gate.stmts.Z,
324-
gate.stmts.SqrtY,
325323
gate.stmts.S,
326-
gate.stmts.Z,
327-
gate.stmts.SqrtY,
324+
gate.stmts.SqrtX,
328325
gate.stmts.S,
326+
gate.stmts.SqrtX,
329327
]
330328
assert [type(stmt) for stmt in relevant_stmts] == expected_types
331329

332330
for relevant_stmt in relevant_stmts:
333331
if type(relevant_stmt) is not gate.stmts.Z:
334-
assert not relevant_stmt.adjoint
332+
assert relevant_stmt.adjoint
335333

336334

337-
def test_sdg_sqrt_y_s():
335+
def test_sqrt_x_dag():
338336

339337
@sq.kernel
340338
def test():
@@ -349,29 +347,26 @@ def test():
349347

350348
SquinToCliffordTestPass(test.dialects)(test)
351349

352-
relevant_stmts = filter_statements_by_type(test, (gate.stmts.S, gate.stmts.SqrtY))
350+
relevant_stmts = filter_statements_by_type(
351+
test, (gate.stmts.S, gate.stmts.SqrtY, gate.stmts.SqrtX)
352+
)
353353

354354
# Should be Sdg, SqrtY, S for each op
355355
assert [type(stmt) for stmt in relevant_stmts] == [
356-
gate.stmts.S,
357-
gate.stmts.SqrtY,
358-
gate.stmts.S,
359-
gate.stmts.S,
360-
gate.stmts.SqrtY,
361-
gate.stmts.S,
356+
gate.stmts.SqrtX,
357+
gate.stmts.SqrtX,
362358
]
363359

364360
# Check adjoint property: the first S in each group should be adjoint
365361
s_stmts = filter_statements_by_type(test, (gate.stmts.S,))
366-
sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,))
362+
sqrt_x_stmts = filter_statements_by_type(test, (gate.stmts.SqrtX,))
367363

368-
assert s_stmts[0].adjoint
369-
assert s_stmts[2].adjoint
370-
for sqrt_y_stmt in sqrt_y_stmts:
371-
assert not sqrt_y_stmt.adjoint
364+
assert not s_stmts
365+
for sqrt_x_stmt in sqrt_x_stmts:
366+
assert sqrt_x_stmt.adjoint
372367

373368

374-
def test_sqrt_y_z():
369+
def test_z_sqrt_y_dag():
375370

376371
@sq.kernel
377372
def test():
@@ -385,17 +380,17 @@ def test():
385380
SquinToCliffordTestPass(test.dialects)(test)
386381
test.print()
387382

388-
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.SqrtY)
389-
assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.Z)
390-
assert isinstance(get_stmt_at_idx(test, 8), gate.stmts.SqrtY)
391-
assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.Z)
383+
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.Z)
384+
assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.SqrtY)
385+
assert isinstance(get_stmt_at_idx(test, 8), gate.stmts.Z)
386+
assert isinstance(get_stmt_at_idx(test, 9), gate.stmts.SqrtY)
392387

393388
sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,))
394389
for sqrt_y_stmt in sqrt_y_stmts:
395-
assert not sqrt_y_stmt.adjoint
390+
assert sqrt_y_stmt.adjoint
396391

397392

398-
def test_s_sqrt_y_z():
393+
def test_s_dag_sqrt_y_dag():
399394

400395
@sq.kernel
401396
def test():
@@ -417,18 +412,16 @@ def test():
417412
assert [type(stmt) for stmt in relevant_stmts] == [
418413
gate.stmts.S,
419414
gate.stmts.SqrtY,
420-
gate.stmts.Z,
421415
gate.stmts.S,
422416
gate.stmts.SqrtY,
423-
gate.stmts.Z,
424417
]
425418

426419
for stmt in relevant_stmts:
427420
if type(stmt) is not gate.stmts.Z:
428-
assert not stmt.adjoint
421+
assert stmt.adjoint
429422

430423

431-
def test_z_sqrt_y_z():
424+
def test_sqrt_y_dag():
432425

433426
@sq.kernel
434427
def test():
@@ -444,21 +437,17 @@ def test():
444437
relevant_stmts = filter_statements_by_type(test, (gate.stmts.Z, gate.stmts.SqrtY))
445438

446439
expected_types = [
447-
gate.stmts.Z,
448440
gate.stmts.SqrtY,
449-
gate.stmts.Z,
450-
gate.stmts.Z,
451441
gate.stmts.SqrtY,
452-
gate.stmts.Z,
453442
]
454443
assert [type(stmt) for stmt in relevant_stmts] == expected_types
455444

456445
for stmt in relevant_stmts:
457446
if type(stmt) is gate.stmts.SqrtY:
458-
assert not stmt.adjoint
447+
assert stmt.adjoint
459448

460449

461-
def test_sdg_sqrt_y_z():
450+
def test_s_sqrt_y_dag():
462451

463452
@sq.kernel
464453
def test():
@@ -481,24 +470,22 @@ def test():
481470
assert [type(stmt) for stmt in relevant_stmts] == [
482471
gate.stmts.S,
483472
gate.stmts.SqrtY,
484-
gate.stmts.Z,
485473
gate.stmts.S,
486474
gate.stmts.SqrtY,
487-
gate.stmts.Z,
488475
]
489476

490477
# Check adjoint property: Sdag should be adjoint, SqrtY and Z should not
491478
s_stmts = filter_statements_by_type(test, (gate.stmts.S,))
492479
sqrt_y_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY,))
493480

494481
for s_stmt in s_stmts:
495-
assert s_stmt.adjoint
482+
assert not s_stmt.adjoint
496483

497484
for sqrt_y_stmt in sqrt_y_stmts:
498-
assert not sqrt_y_stmt.adjoint
485+
assert sqrt_y_stmt.adjoint
499486

500487

501-
def test_sqrt_y_sdg():
488+
def test_s_dag_sqrt_x():
502489

503490
@sq.kernel
504491
def test():
@@ -510,18 +497,20 @@ def test():
510497

511498
SquinToCliffordTestPass(test.dialects)(test)
512499

513-
relevant_stmts = filter_statements_by_type(test, (gate.stmts.SqrtY, gate.stmts.S))
500+
relevant_stmts = filter_statements_by_type(
501+
test, (gate.stmts.SqrtY, gate.stmts.SqrtX, gate.stmts.S)
502+
)
514503
# Check for SqrtY followed by S (adjoint property can be checked if needed)
515504
assert [type(stmt) for stmt in relevant_stmts] == [
516-
gate.stmts.SqrtY,
517505
gate.stmts.S,
506+
gate.stmts.SqrtX,
518507
]
519508

520-
assert not relevant_stmts[0].adjoint
521-
assert relevant_stmts[1].adjoint
509+
assert relevant_stmts[0].adjoint
510+
assert not relevant_stmts[1].adjoint
522511

523512

524-
def test_s_sqrt_y_sdg():
513+
def test_sqrt_x():
525514

526515
@sq.kernel
527516
def test():
@@ -532,20 +521,18 @@ def test():
532521
)
533522

534523
SquinToCliffordTestPass(test.dialects)(test)
535-
relevant_stmts = filter_statements_by_type(test, (gate.stmts.S, gate.stmts.SqrtY))
524+
relevant_stmts = filter_statements_by_type(
525+
test, (gate.stmts.S, gate.stmts.SqrtY, gate.stmts.SqrtX)
526+
)
536527

537528
assert [type(stmt) for stmt in relevant_stmts] == [
538-
gate.stmts.S,
539-
gate.stmts.SqrtY,
540-
gate.stmts.S,
529+
gate.stmts.SqrtX,
541530
]
542531
# The last S should be adjoint
543532
assert not relevant_stmts[0].adjoint
544-
assert not relevant_stmts[1].adjoint
545-
assert relevant_stmts[2].adjoint
546533

547534

548-
def test_z_sqrt_y_sdg():
535+
def test_s_sqrt_x():
549536

550537
@sq.kernel
551538
def test():
@@ -558,19 +545,17 @@ def test():
558545
SquinToCliffordTestPass(test.dialects)(test)
559546

560547
relevant_stmts = filter_statements_by_type(
561-
test, (gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S)
548+
test, (gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S, gate.stmts.SqrtX)
562549
)
563-
# Should be Z, SqrtY, S (adjoint)
564550
assert [type(stmt) for stmt in relevant_stmts] == [
565-
gate.stmts.Z,
566-
gate.stmts.SqrtY,
567551
gate.stmts.S,
552+
gate.stmts.SqrtX,
568553
]
554+
assert not relevant_stmts[0].adjoint
569555
assert not relevant_stmts[1].adjoint
570-
assert relevant_stmts[2].adjoint
571556

572557

573-
def test_sdg_sqrt_y_sdg():
558+
def test_z_sqrt_x():
574559

575560
@sq.kernel
576561
def test():
@@ -582,18 +567,17 @@ def test():
582567

583568
SquinToCliffordTestPass(test.dialects)(test)
584569

585-
relevant_stmts = filter_statements_by_type(test, (gate.stmts.S, gate.stmts.SqrtY))
570+
relevant_stmts = filter_statements_by_type(
571+
test, (gate.stmts.S, gate.stmts.SqrtY, gate.stmts.Z, gate.stmts.SqrtX)
572+
)
586573

587574
# Should be Sdag, SqrtY, Sdag for the op
588575
assert [type(stmt) for stmt in relevant_stmts] == [
589-
gate.stmts.S,
590-
gate.stmts.SqrtY,
591-
gate.stmts.S,
576+
gate.stmts.Z,
577+
gate.stmts.SqrtX,
592578
]
593579
# The first and last S should be adjoint, SqrtY should not
594-
assert relevant_stmts[0].adjoint
595580
assert not relevant_stmts[1].adjoint
596-
assert relevant_stmts[2].adjoint
597581

598582

599583
def test_y():
@@ -625,7 +609,7 @@ def test():
625609
assert not s_stmt.adjoint
626610

627611

628-
def test_z_y():
612+
def test_x():
629613

630614
@sq.kernel
631615
def test():
@@ -634,8 +618,7 @@ def test():
634618
sq.u3(theta=0.5 * math.tau, phi=0.0 * math.tau, lam=0.5 * math.tau, qubit=q[0])
635619

636620
SquinToCliffordTestPass(test.dialects)(test)
637-
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.Z)
638-
assert isinstance(get_stmt_at_idx(test, 6), gate.stmts.Y)
621+
assert isinstance(get_stmt_at_idx(test, 5), gate.stmts.X)
639622

640623

641624
def test_sdg_y():
Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
11

2-
S 0
3-
SQRT_Y 0
4-
S 0
5-
S_DAG 0
6-
SQRT_Y 0
7-
S 0
8-
S 0
9-
SQRT_Y 0
10-
S_DAG 0
2+
Z 0
3+
SQRT_X_DAG 0
4+
SQRT_X_DAG 0
5+
SQRT_X 0

0 commit comments

Comments
 (0)