Skip to content

Commit 045fa00

Browse files
committed
broke everything up into smaller rulesgst
1 parent 99e66af commit 045fa00

File tree

14 files changed

+289
-297
lines changed

14 files changed

+289
-297
lines changed

src/bloqade/qasm2/dialects/uop/stmts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class TwoQubitCtrlGate(ir.Statement):
2828

2929
@statement(dialect=dialect)
3030
class CX(TwoQubitCtrlGate):
31-
"""Alias for the CNOT or CH gate operations."""
31+
"""Alias for the CNOT or CX gate operations."""
3232

3333
name = "CX" # Note this is capitalized
3434

src/bloqade/squin/passes/qasm2_gate_func_to_squin.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,12 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult:
4040
Chain(
4141
QASM2ToPyRule(),
4242
qasm2_rule.QASM2CoreToSquin(),
43-
qasm2_rule.QASM2UOPToSquin(),
44-
qasm2_rule.QASM2NoiseToSquin(),
4543
qasm2_rule.QASM2GlobParallelToSquin(),
44+
qasm2_rule.QASM2NoiseToSquin(),
45+
qasm2_rule.QASM2IdToSquin(),
46+
qasm2_rule.QASM2UOp1QToSquin(),
47+
qasm2_rule.QASM2ParametrizedUOp1QToSquin(),
48+
qasm2_rule.QASM2UOp2QToSquin(),
4649
)
4750
)
4851

src/bloqade/squin/passes/qasm2_to_squin.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,13 @@
88

99
from bloqade import squin
1010
from bloqade.squin.rewrite.qasm2 import (
11-
QASM2UOPToSquin,
11+
QASM2IdToSquin,
1212
QASM2CoreToSquin,
1313
QASM2NoiseToSquin,
14+
QASM2UOp1QToSquin,
15+
QASM2UOp2QToSquin,
1416
QASM2GlobParallelToSquin,
17+
QASM2ParametrizedUOp1QToSquin,
1518
)
1619

1720
# There's a QASM2Py pass that only applies an _QASM2Py rewrite rule,
@@ -31,9 +34,12 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult:
3134
Chain(
3235
QASM2ToPyRule(),
3336
QASM2CoreToSquin(),
34-
QASM2UOPToSquin(),
3537
QASM2GlobParallelToSquin(),
3638
QASM2NoiseToSquin(),
39+
QASM2IdToSquin(),
40+
QASM2UOp1QToSquin(),
41+
QASM2ParametrizedUOp1QToSquin(),
42+
QASM2UOp2QToSquin(),
3743
)
3844
).rewrite(mt.code)
3945

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
from .uop_to_squin import QASM2UOPToSquin as QASM2UOPToSquin
1+
from .id_to_squin import QASM2IdToSquin as QASM2IdToSquin
22
from .core_to_squin import QASM2CoreToSquin as QASM2CoreToSquin
33
from .noise_to_squin import QASM2NoiseToSquin as QASM2NoiseToSquin
4+
from .uop_1q_to_squin import QASM2UOp1QToSquin as QASM2UOp1QToSquin
5+
from .uop_2q_to_squin import QASM2UOp2QToSquin as QASM2UOp2QToSquin
46
from .glob_parallel_to_squin import QASM2GlobParallelToSquin as QASM2GlobParallelToSquin
7+
from .parametrized_uop_1q_to_squin import (
8+
QASM2ParametrizedUOp1QToSquin as QASM2ParametrizedUOp1QToSquin,
9+
)

src/bloqade/squin/rewrite/qasm2/core_to_squin.py

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,34 @@
55
from bloqade import squin
66
from bloqade.qasm2.dialects.core import stmts as core_stmts
77

8+
CORE_TO_SQUIN_MAP = {
9+
core_stmts.QRegNew: squin.qubit.qalloc,
10+
core_stmts.Reset: squin.qubit.reset,
11+
}
12+
813

914
class QASM2CoreToSquin(RewriteRule):
1015

1116
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
1217

13-
match node:
14-
case core_stmts.QRegNew(n_qubits=n_qubits):
15-
qalloc_invoke_stmt = func.Invoke(
16-
callee=squin.qubit.qalloc, inputs=(n_qubits,)
17-
)
18-
node.replace_by(qalloc_invoke_stmt)
19-
case core_stmts.Reset(qarg=qarg):
20-
reset_invoke_stmt = func.Invoke(
21-
callee=squin.qubit.reset, inputs=(qarg,)
22-
)
23-
node.replace_by(reset_invoke_stmt)
24-
case core_stmts.QRegGet(reg=reg, idx=idx):
25-
get_item_stmt = py.GetItem(
26-
obj=reg,
27-
index=idx,
28-
)
29-
node.replace_by(get_item_stmt)
30-
case _:
31-
return RewriteResult()
18+
if isinstance(node, core_stmts.QRegGet):
19+
py_get_item = py.GetItem(
20+
obj=node.reg,
21+
index=node.idx,
22+
)
23+
node.replace_by(py_get_item)
24+
return RewriteResult(has_done_something=True)
25+
26+
if isinstance(node, core_stmts.QRegNew):
27+
args = (node.n_qubits,)
28+
elif isinstance(node, core_stmts.Reset):
29+
args = (node.qarg,)
30+
else:
31+
return RewriteResult()
3232

33+
new_stmt = func.Invoke(
34+
callee=CORE_TO_SQUIN_MAP[type(node)],
35+
inputs=args,
36+
)
37+
node.replace_by(new_stmt)
3338
return RewriteResult(has_done_something=True)

src/bloqade/squin/rewrite/qasm2/glob_parallel_to_squin.py

Lines changed: 21 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,47 +5,30 @@
55
from bloqade import squin
66
from bloqade.qasm2.dialects import glob, parallel
77

8+
GLOBAL_PARALLEL_TO_SQUIN_MAP = {
9+
glob.UGate: squin.broadcast.u3,
10+
parallel.UGate: squin.broadcast.u3,
11+
parallel.RZ: squin.broadcast.rz,
12+
}
13+
814

915
class QASM2GlobParallelToSquin(RewriteRule):
1016

1117
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
1218

13-
match node:
14-
case glob.UGate() | parallel.UGate() | parallel.RZ():
15-
return self.rewrite_1q_gates(node)
16-
case _:
17-
return RewriteResult()
18-
19-
return RewriteResult(has_done_something=True)
20-
21-
def rewrite_1q_gates(
22-
self, stmt: glob.UGate | parallel.UGate | parallel.RZ
23-
) -> RewriteResult:
24-
25-
match stmt:
26-
case glob.UGate(theta=theta, phi=phi, lam=lam) | parallel.UGate(
27-
theta=theta, phi=phi, lam=lam
28-
):
29-
# ever so slight naming difference,
30-
# exists because intended semantics are different
31-
match stmt:
32-
case glob.UGate():
33-
qargs = stmt.registers
34-
case parallel.UGate():
35-
qargs = stmt.qargs
36-
37-
invoke_u_broadcast_stmt = func.Invoke(
38-
callee=squin.broadcast.u3,
39-
inputs=(theta, phi, lam, qargs),
40-
)
41-
stmt.replace_by(invoke_u_broadcast_stmt)
42-
case parallel.RZ(theta=theta, qargs=qargs):
43-
invoke_rz_broadcast_stmt = func.Invoke(
44-
callee=squin.broadcast.rz,
45-
inputs=(theta, qargs),
46-
)
47-
stmt.replace_by(invoke_rz_broadcast_stmt)
48-
case _:
49-
return RewriteResult()
50-
19+
if isinstance(node, glob.UGate):
20+
args = (node.theta, node.phi, node.lam, node.registers)
21+
elif isinstance(node, parallel.UGate):
22+
args = (node.theta, node.phi, node.lam, node.qargs)
23+
elif isinstance(node, parallel.RZ):
24+
args = (node.theta, node.qargs)
25+
else:
26+
return RewriteResult()
27+
28+
squin_equivalent_stmt = GLOBAL_PARALLEL_TO_SQUIN_MAP[type(node)]
29+
invoke_stmt = func.Invoke(
30+
callee=squin_equivalent_stmt,
31+
inputs=args,
32+
)
33+
node.replace_by(invoke_stmt)
5134
return RewriteResult(has_done_something=True)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
from kirin import ir
2+
from kirin.rewrite.abc import RewriteRule, RewriteResult
3+
4+
import bloqade.qasm2.dialects.uop.stmts as uop_stmts
5+
6+
7+
class QASM2IdToSquin(RewriteRule):
8+
9+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
10+
11+
if not isinstance(node, uop_stmts.Id):
12+
return RewriteResult()
13+
14+
node.delete()
15+
return RewriteResult(has_done_something=True)

src/bloqade/squin/rewrite/qasm2/noise_to_squin.py

Lines changed: 30 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,43 @@
11
from kirin import ir
2-
from kirin.dialects import py, func
2+
from kirin.dialects import func
33
from kirin.rewrite.abc import RewriteRule, RewriteResult
44

55
from bloqade import squin
66
from bloqade.qasm2.dialects.noise import stmts as noise_stmts
77

8+
from .util import num_to_py_constant
89

9-
class QASM2NoiseToSquin(RewriteRule):
10-
11-
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
12-
13-
match node:
14-
case noise_stmts.AtomLossChannel():
15-
return self.rewrite_AtomLossChannel(node)
16-
case noise_stmts.PauliChannel():
17-
return self.rewrite_PauliChannel(node)
18-
case noise_stmts.CZPauliChannel():
19-
return self.rewrite_CZPauliChannel(node)
20-
case _:
21-
return RewriteResult()
22-
23-
return RewriteResult()
10+
NOISE_TO_SQUIN_MAP = {
11+
noise_stmts.AtomLossChannel: squin.broadcast.qubit_loss,
12+
noise_stmts.PauliChannel: squin.broadcast.single_qubit_pauli_channel,
13+
}
2414

25-
def rewrite_AtomLossChannel(
26-
self, stmt: noise_stmts.AtomLossChannel
27-
) -> RewriteResult:
2815

29-
qargs = stmt.qargs
30-
# this is a raw float, not in SSA form yet!
31-
prob = stmt.prob
32-
prob_stmt = py.Constant(value=prob)
33-
prob_stmt.insert_before(stmt)
34-
35-
invoke_loss_stmt = func.Invoke(
36-
callee=squin.broadcast.qubit_loss,
37-
inputs=(prob_stmt.result, qargs),
38-
)
39-
40-
stmt.replace_by(invoke_loss_stmt)
41-
42-
return RewriteResult(has_done_something=True)
43-
44-
def rewrite_PauliChannel(self, stmt: noise_stmts.PauliChannel) -> RewriteResult:
45-
46-
qargs = stmt.qargs
47-
p_x = stmt.px
48-
p_y = stmt.py
49-
p_z = stmt.pz
50-
51-
probs = [p_x, p_y, p_z]
52-
probs_ssas = []
16+
class QASM2NoiseToSquin(RewriteRule):
5317

54-
for prob in probs:
55-
prob_stmt = py.Constant(value=prob)
56-
prob_stmt.insert_before(stmt)
57-
probs_ssas.append(prob_stmt.result)
18+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
5819

59-
invoke_pauli_channel_stmt = func.Invoke(
60-
callee=squin.broadcast.single_qubit_pauli_channel,
61-
inputs=(*probs_ssas, qargs),
20+
if isinstance(node, noise_stmts.AtomLossChannel):
21+
qargs = node.qargs
22+
prob = node.prob
23+
prob_ssas = num_to_py_constant([prob], stmt_to_insert_before=node)
24+
elif isinstance(node, noise_stmts.PauliChannel):
25+
qargs = node.qargs
26+
p_x = node.px
27+
p_y = node.py
28+
p_z = node.pz
29+
prob_ssas = num_to_py_constant([p_x, p_y, p_z], stmt_to_insert_before=node)
30+
elif isinstance(node, noise_stmts.CZPauliChannel):
31+
return self.rewrite_CZPauliChannel(node)
32+
else:
33+
return RewriteResult()
34+
35+
squin_noise_stmt = NOISE_TO_SQUIN_MAP[type(node)]
36+
invoke_stmt = func.Invoke(
37+
callee=squin_noise_stmt,
38+
inputs=(*prob_ssas, qargs),
6239
)
63-
64-
stmt.replace_by(invoke_pauli_channel_stmt)
40+
node.replace_by(invoke_stmt)
6541
return RewriteResult(has_done_something=True)
6642

6743
def rewrite_CZPauliChannel(self, stmt: noise_stmts.CZPauliChannel) -> RewriteResult:
@@ -78,11 +54,8 @@ def rewrite_CZPauliChannel(self, stmt: noise_stmts.CZPauliChannel) -> RewriteRes
7854

7955
error_probs = [px_ctrl, py_ctrl, pz_ctrl, px_qarg, py_qarg, pz_qarg]
8056
# first half of entries for control qubits, other half for targets
81-
error_prob_ssas = []
82-
for error_prob in error_probs:
83-
error_prob_stmt = py.Constant(value=error_prob)
84-
error_prob_stmt.insert_before(stmt)
85-
error_prob_ssas.append(error_prob_stmt.result)
57+
58+
error_prob_ssas = num_to_py_constant(error_probs, stmt_to_insert_before=stmt)
8659

8760
ctrl_pauli_channel_invoke = func.Invoke(
8861
callee=squin.broadcast.single_qubit_pauli_channel,
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from math import pi
2+
3+
from kirin import ir
4+
from kirin.dialects import py, func
5+
from kirin.rewrite.abc import RewriteRule, RewriteResult
6+
7+
from bloqade import squin
8+
from bloqade.qasm2.dialects.uop import stmts as uop_stmts
9+
10+
PARAMETRIZED_1Q_GATES_TO_SQUIN_MAP = {
11+
uop_stmts.UGate: squin.u3,
12+
uop_stmts.U1: squin.u3,
13+
uop_stmts.U2: squin.u3,
14+
uop_stmts.RZ: squin.rz,
15+
uop_stmts.RX: squin.rx,
16+
uop_stmts.RY: squin.ry,
17+
}
18+
19+
20+
class QASM2ParametrizedUOp1QToSquin(RewriteRule):
21+
22+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
23+
24+
if isinstance(node, (uop_stmts.RX, uop_stmts.RY, uop_stmts.RZ)):
25+
args = (node.theta, node.qarg)
26+
elif isinstance(node, (uop_stmts.UGate)):
27+
args = (node.theta, node.phi, node.lam, node.qarg)
28+
elif isinstance(node, (uop_stmts.U1)):
29+
zero_stmt = py.Constant(value=0.0)
30+
zero_stmt.insert_before(node)
31+
args = (zero_stmt.result, zero_stmt.result, node.lam, node.qarg)
32+
elif isinstance(node, (uop_stmts.U2)):
33+
half_pi_stmt = py.Constant(value=pi / 2)
34+
half_pi_stmt.insert_before(node)
35+
args = (half_pi_stmt.result, node.phi, node.lam, node.qarg)
36+
else:
37+
return RewriteResult()
38+
39+
squin_equivalent_stmt = PARAMETRIZED_1Q_GATES_TO_SQUIN_MAP[type(node)]
40+
invoke_stmt = func.Invoke(
41+
callee=squin_equivalent_stmt,
42+
inputs=args,
43+
)
44+
node.replace_by(invoke_stmt)
45+
46+
return RewriteResult(has_done_something=True)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from kirin import ir
2+
from kirin.dialects import func
3+
from kirin.rewrite.abc import RewriteRule, RewriteResult
4+
5+
from bloqade import squin
6+
from bloqade.qasm2.dialects.uop import stmts as uop_stmts
7+
8+
ONE_Q_GATES_TO_SQUIN_MAP = {
9+
uop_stmts.X: squin.x,
10+
uop_stmts.Y: squin.y,
11+
uop_stmts.Z: squin.z,
12+
uop_stmts.H: squin.h,
13+
uop_stmts.S: squin.s,
14+
uop_stmts.T: squin.t,
15+
uop_stmts.SX: squin.sqrt_x,
16+
}
17+
18+
19+
class QASM2UOp1QToSquin(RewriteRule):
20+
21+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
22+
23+
squin_1q_gate = ONE_Q_GATES_TO_SQUIN_MAP.get(type(node))
24+
if squin_1q_gate is None:
25+
return RewriteResult()
26+
27+
invoke_stmt = func.Invoke(
28+
callee=squin_1q_gate,
29+
inputs=(node.qarg,),
30+
)
31+
node.replace_by(invoke_stmt)
32+
33+
return RewriteResult(has_done_something=True)

0 commit comments

Comments
 (0)