Skip to content

Commit e107a87

Browse files
committed
Fix the UopToParallel pass for qasm2 constants (#477)
This fixes a bug, where the rewrite bailed if the qubit number was a `qasm2.expr.ConstInt` rather than a `py.Constant`. That meant that loaded qasm2 kernels weren't rewritten at all.
1 parent 1b5999c commit e107a87

File tree

3 files changed

+13
-3
lines changed

3 files changed

+13
-3
lines changed

src/bloqade/qasm2/rewrite/register.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from kirin.dialects import py
33
from kirin.rewrite.abc import RewriteRule, RewriteResult
44

5-
from bloqade.qasm2.dialects import core
5+
from bloqade.qasm2.dialects import core, expr
66

77

88
class RaiseRegisterRule(RewriteRule):
@@ -26,7 +26,7 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
2626
n_qubits_ref = node.n_qubits
2727

2828
n_qubits = n_qubits_ref.owner
29-
if isinstance(n_qubits, py.Constant):
29+
if isinstance(n_qubits, py.Constant | expr.ConstInt):
3030
# case where the n_qubits comes from a constant
3131
new_n_qubits = n_qubits.from_stmt(n_qubits)
3232
new_n_qubits.insert_before(first_stmt)

src/bloqade/qasm2/rewrite/uop_to_parallel.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from kirin.analysis.const import lattice
99

1010
from bloqade.analysis import address
11-
from bloqade.qasm2.dialects import uop, core, parallel
11+
from bloqade.qasm2.dialects import uop, core, expr, parallel
1212
from bloqade.squin.analysis.schedule import StmtDag
1313

1414

@@ -194,6 +194,8 @@ def move_and_collect_qubit_list(
194194
new_qubits.append(new_qubit.result)
195195
case core.QRegGet(
196196
reg=reg, idx=ir.ResultValue(stmt=py.Constant() as idx)
197+
) | core.QRegGet(
198+
reg=reg, idx=ir.ResultValue(stmt=expr.ConstInt() as idx)
197199
):
198200
(new_idx := idx.from_stmt(idx)).insert_before(node)
199201
(

test/qasm2/passes/test_uop_to_parallel.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ def test():
3838
# add this to raise error if there are broken ssa references
3939
_, _ = address.AddressAnalysis(test.dialects).run_analysis(test, no_raise=False)
4040

41+
# check that there's parallel statements now
42+
assert any(
43+
[
44+
isinstance(stmt, qasm2.dialects.parallel.UGate)
45+
for stmt in test.callable_region.blocks[0].stmts
46+
]
47+
)
48+
4149

4250
def test_two():
4351

0 commit comments

Comments
 (0)