Skip to content

Commit 5cb47e1

Browse files
committed
refactor tests to check equivilance and fixing more potential bugs
1 parent 16d2460 commit 5cb47e1

File tree

3 files changed

+11
-7
lines changed

3 files changed

+11
-7
lines changed

src/bloqade/native/stdlib/broadcast.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,9 @@ def u3(theta: float, phi: float, lam: float, qubits: ilist.IList[qubit.Qubit, An
225225
lam (float): Z rotations in decomposition (radians).
226226
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
227227
"""
228-
rz(lam, qubits)
229-
ry(theta, qubits)
230-
rz(phi, qubits)
228+
_u3_turns(
229+
_radian_to_turn(theta), _radian_to_turn(phi), _radian_to_turn(lam), qubits
230+
)
231231

232232

233233
N = TypeVar("N")

src/bloqade/rewrite/passes/callgraph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult:
9898
mt_map = {}
9999

100100
cg = CallGraph(mt)
101-
all_methods = set([mt])
102-
all_methods.update(cg.edges.keys())
101+
all_methods = set(cg.edges.keys())
102+
all_methods.add(mt)
103103
for original_mt in all_methods:
104104
if original_mt is mt:
105105
new_mt = original_mt

test/native/upstream/test_squin2native.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def ghz_native(angle: float):
7676

7777
AggressiveUnroll(ghz_native.dialects).fixpoint(ghz_native)
7878
test_utils.assert_nodes(
79-
ghz_native_rewrite.callable_region, ghz_native.callable_region
79+
ghz_native_rewrite.callable_region.blocks[0],
80+
ghz_native.callable_region.blocks[0],
8081
)
8182

8283

@@ -92,10 +93,13 @@ def ghz_native(theta: float, phi: float, lam: float):
9293
qubits = squin.qalloc(1)
9394
native.u3(theta, phi, lam, qubits[0])
9495

96+
# unroll first to check that gete rewrites happen in ghz body
97+
AggressiveUnroll(ghz.dialects).fixpoint(ghz)
9598
ghz_native_rewrite = SquinToNative().emit(ghz)
9699
AggressiveUnroll(ghz.dialects).fixpoint(ghz_native_rewrite)
97100

98101
AggressiveUnroll(ghz_native.dialects).fixpoint(ghz_native)
99102
test_utils.assert_nodes(
100-
ghz_native_rewrite.callable_region, ghz_native.callable_region
103+
ghz_native_rewrite.callable_region.blocks[0],
104+
ghz_native.callable_region.blocks[0],
101105
)

0 commit comments

Comments
 (0)