Skip to content

Commit a3b7124

Browse files
committed
merging main
2 parents 0807726 + d079cc2 commit a3b7124

File tree

25 files changed

+310
-266
lines changed

25 files changed

+310
-266
lines changed

src/bloqade/cirq_utils/emit/gate.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,19 @@ def rot(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: gate.stmts.RotationGat
8686

8787
frame.circuit.append(cirq_op.on_each(qubits))
8888
return ()
89+
90+
@impl(gate.stmts.U3)
91+
def u3(self, emit: EmitCirq, frame: EmitCirqFrame, stmt: gate.stmts.U3):
92+
qubits = frame.get(stmt.qubits)
93+
94+
theta = frame.get(stmt.theta) * 2 * math.pi
95+
phi = frame.get(stmt.phi) * 2 * math.pi
96+
lam = frame.get(stmt.lam) * 2 * math.pi
97+
98+
frame.circuit.append(cirq.Rz(rads=lam).on_each(*qubits))
99+
100+
frame.circuit.append(cirq.Ry(rads=theta).on_each(*qubits))
101+
102+
frame.circuit.append(cirq.Rz(rads=phi).on_each(*qubits))
103+
104+
return ()

src/bloqade/native/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
ry as ry,
1515
rz as rz,
1616
u3 as u3,
17-
rot as rot,
1817
s_dag as s_dag,
1918
shift as shift,
2019
sqrt_x as sqrt_x,

src/bloqade/native/stdlib/broadcast.py

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -187,33 +187,23 @@ def shift(angle: float, qubits: ilist.IList[qubit.Qubit, Any]):
187187
rz(angle / 2.0, qubits)
188188

189189

190-
@kernel
191-
def rot(phi: float, theta: float, omega: float, qubits: ilist.IList[qubit.Qubit, Any]):
192-
"""Apply a general single-qubit rotation on a group of qubits.
193-
194-
Args:
195-
phi (float): Z rotation before Y (radians).
196-
theta (float): Y rotation (radians).
197-
omega (float): Z rotation after Y (radians).
198-
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
199-
"""
200-
rz(phi, qubits)
201-
ry(theta, qubits)
202-
rz(omega, qubits)
203-
204-
205190
@kernel
206191
def u3(theta: float, phi: float, lam: float, qubits: ilist.IList[qubit.Qubit, Any]):
207192
"""Apply the U3 gate on a group of qubits.
208193
194+
The applied gate is represented by the unitary matrix given by:
195+
196+
$$ U3(\\theta, \\phi, \\lambda) = R_z(\\phi)R_y(\\theta)R_z(\\lambda) $$
197+
209198
Args:
210199
theta (float): Rotation around Y axis (radians).
211200
phi (float): Global phase shift component (radians).
212201
lam (float): Z rotations in decomposition (radians).
213202
qubits (ilist.IList[qubit.Qubit, Any]): Target qubits.
214203
"""
215-
rot(lam, theta, -lam, qubits)
216-
shift(phi + lam, qubits)
204+
rz(lam, qubits)
205+
ry(theta, qubits)
206+
rz(phi, qubits)
217207

218208

219209
N = TypeVar("N")

src/bloqade/native/stdlib/simple.py

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -170,23 +170,14 @@ def shift(angle: float, qubit: qubit.Qubit):
170170
broadcast.shift(angle, ilist.IList([qubit]))
171171

172172

173-
@kernel
174-
def rot(phi: float, theta: float, omega: float, qubit: qubit.Qubit):
175-
"""Apply a general single-qubit rotation on a single qubit.
176-
177-
Args:
178-
phi (float): Z rotation before Y (radians).
179-
theta (float): Y rotation (radians).
180-
omega (float): Z rotation after Y (radians).
181-
qubit (qubit.Qubit): The qubit to apply the rotation to.
182-
"""
183-
broadcast.rot(phi, theta, omega, ilist.IList([qubit]))
184-
185-
186173
@kernel
187174
def u3(theta: float, phi: float, lam: float, qubit: qubit.Qubit):
188175
"""Apply the U3 gate on a single qubit.
189176
177+
The applied gate is represented by the unitary matrix given by:
178+
179+
$$ U3(\\theta, \\phi, \\lambda) = R_z(\\phi)R_y(\\theta)R_z(\\lambda) $$
180+
190181
Args:
191182
theta (float): Rotation angle around the Y axis in radians.
192183
phi (float): Rotation angle around the Z axis in radians.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .squin2native import (
2+
GateRule as GateRule,
23
SquinToNative as SquinToNative,
34
SquinToNativePass as SquinToNativePass,
45
)

src/bloqade/native/upstream/squin2native.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class GateRule(RewriteRule):
2727
stmts.CX: (broadcast.cx,),
2828
stmts.CY: (broadcast.cy,),
2929
stmts.CZ: (broadcast.cz,),
30+
stmts.U3: (broadcast.u3,),
3031
}
3132

3233
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:

src/bloqade/pyqrack/squin/gate/gate.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
CX,
1313
CY,
1414
CZ,
15+
U3,
1516
H,
1617
S,
1718
T,
@@ -120,3 +121,16 @@ def control(
120121
for control, target in zip(controls, targets):
121122
if control.is_active() and target.is_active():
122123
getattr(control.sim_reg, method_name)([control.addr], target.addr)
124+
125+
@interp.impl(U3)
126+
def u3(self, interp: PyQrackInterpreter, frame: interp.Frame, stmt: U3):
127+
theta = frame.get(stmt.theta) * 2 * math.pi
128+
phi = frame.get(stmt.phi) * 2 * math.pi
129+
lam = frame.get(stmt.lam) * 2 * math.pi
130+
qubits: ilist.IList[PyQrackQubit, Any] = frame.get(stmt.qubits)
131+
132+
for qbit in qubits:
133+
if not qbit.is_active():
134+
continue
135+
136+
qbit.sim_reg.u(qbit.addr, theta, phi, lam)

src/bloqade/qasm2/passes/fold.py

Lines changed: 14 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,12 @@
11
from dataclasses import field, dataclass
22

33
from kirin import ir
4-
from kirin.passes import Pass, TypeInfer
5-
from kirin.rewrite import (
6-
Walk,
7-
Chain,
8-
Inline,
9-
Fixpoint,
10-
WrapConst,
11-
Call2Invoke,
12-
ConstantFold,
13-
CFGCompactify,
14-
InlineGetItem,
15-
InlineGetField,
16-
DeadCodeElimination,
17-
CommonSubexpressionElimination,
18-
)
19-
from kirin.analysis import const
20-
from kirin.dialects import scf, ilist
4+
from kirin.passes import Pass
215
from kirin.ir.method import Method
226
from kirin.rewrite.abc import RewriteResult
237

248
from bloqade.qasm2.dialects import expr
9+
from bloqade.rewrite.passes import AggressiveUnroll
2510

2611
from .unroll_if import UnrollIfs
2712

@@ -30,71 +15,27 @@
3015
class QASM2Fold(Pass):
3116
"""Fold pass for qasm2.extended"""
3217

33-
constprop: const.Propagate = field(init=False)
3418
inline_gate_subroutine: bool = True
3519
unroll_ifs: bool = True
20+
aggressive_unroll: AggressiveUnroll = field(init=False)
3621

3722
def __post_init__(self):
38-
self.constprop = const.Propagate(self.dialects)
39-
self.typeinfer = TypeInfer(self.dialects)
23+
def inline_simple(node: ir.Statement):
24+
if isinstance(node, expr.GateFunction):
25+
return self.inline_gate_subroutine
4026

41-
def unsafe_run(self, mt: Method) -> RewriteResult:
42-
result = RewriteResult()
43-
frame, _ = self.constprop.run_analysis(mt)
44-
result = Walk(WrapConst(frame)).rewrite(mt.code).join(result)
45-
rule = Chain(
46-
ConstantFold(),
47-
Call2Invoke(),
48-
InlineGetField(),
49-
InlineGetItem(),
50-
DeadCodeElimination(),
51-
CommonSubexpressionElimination(),
52-
)
53-
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
27+
return True
5428

55-
result = (
56-
Walk(
57-
Chain(
58-
scf.unroll.PickIfElse(),
59-
scf.unroll.ForLoop(),
60-
scf.trim.UnusedYield(),
61-
)
62-
)
63-
.rewrite(mt.code)
64-
.join(result)
29+
self.aggressive_unroll = AggressiveUnroll(
30+
self.dialects, inline_simple, no_raise=self.no_raise
6531
)
6632

67-
if self.unroll_ifs:
68-
UnrollIfs(mt.dialects).unsafe_run(mt).join(result)
69-
70-
# run typeinfer again after unroll etc. because we now insert
71-
# a lot of new nodes, which might have more precise types
72-
self.typeinfer.unsafe_run(mt)
73-
result = (
74-
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
75-
.rewrite(mt.code)
76-
.join(result)
77-
)
78-
79-
def inline_simple(node: ir.Statement):
80-
if isinstance(node, expr.GateFunction):
81-
return self.inline_gate_subroutine
33+
def unsafe_run(self, mt: Method) -> RewriteResult:
34+
result = RewriteResult()
8235

83-
if not isinstance(node.parent_stmt, (scf.For, scf.IfElse)):
84-
return True # always inline calls outside of loops and if-else
36+
if self.unroll_ifs:
37+
result = UnrollIfs(mt.dialects).unsafe_run(mt).join(result)
8538

86-
# inside loops and if-else, only inline simple functions, i.e. functions with a single block
87-
if (trait := node.get_trait(ir.CallableStmtInterface)) is None:
88-
return False # not a callable, don't inline to be safe
89-
region = trait.get_callable_region(node)
90-
return len(region.blocks) == 1
39+
result = self.aggressive_unroll.unsafe_run(mt).join(result)
9140

92-
result = (
93-
Walk(
94-
Inline(inline_simple),
95-
)
96-
.rewrite(mt.code)
97-
.join(result)
98-
)
99-
result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
10041
return result
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
from .aggressive_unroll import AggressiveUnroll as AggressiveUnroll
12
from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from typing import Callable
2+
from dataclasses import field, dataclass
3+
4+
from kirin import ir
5+
from kirin.passes import Pass, HintConst, TypeInfer
6+
from kirin.rewrite import (
7+
Walk,
8+
Chain,
9+
Inline,
10+
Fixpoint,
11+
Call2Invoke,
12+
ConstantFold,
13+
CFGCompactify,
14+
InlineGetItem,
15+
InlineGetField,
16+
DeadCodeElimination,
17+
)
18+
from kirin.dialects import scf, ilist
19+
from kirin.ir.method import Method
20+
from kirin.rewrite.abc import RewriteResult
21+
from kirin.rewrite.cse import CommonSubexpressionElimination
22+
from kirin.passes.aggressive import UnrollScf
23+
24+
25+
@dataclass
26+
class Fold(Pass):
27+
hint_const: HintConst = field(init=False)
28+
29+
def __post_init__(self):
30+
self.hint_const = HintConst(self.dialects, no_raise=self.no_raise)
31+
32+
def unsafe_run(self, mt: Method) -> RewriteResult:
33+
result = RewriteResult()
34+
result = self.hint_const.unsafe_run(mt).join(result)
35+
rule = Chain(
36+
ConstantFold(),
37+
Call2Invoke(),
38+
InlineGetField(),
39+
InlineGetItem(),
40+
ilist.rewrite.InlineGetItem(),
41+
ilist.rewrite.HintLen(),
42+
)
43+
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
44+
45+
return result
46+
47+
48+
@dataclass
49+
class AggressiveUnroll(Pass):
50+
"""A pass to unroll structured control flow"""
51+
52+
additional_inline_heuristic: Callable[[ir.Statement], bool] = lambda node: True
53+
54+
fold: Fold = field(init=False)
55+
typeinfer: TypeInfer = field(init=False)
56+
scf_unroll: UnrollScf = field(init=False)
57+
58+
def __post_init__(self):
59+
self.fold = Fold(self.dialects, no_raise=self.no_raise)
60+
self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise)
61+
self.scf_unroll = UnrollScf(self.dialects, no_raise=self.no_raise)
62+
63+
def unsafe_run(self, mt: Method) -> RewriteResult:
64+
result = RewriteResult()
65+
result = self.scf_unroll.unsafe_run(mt).join(result)
66+
result = (
67+
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
68+
.rewrite(mt.code)
69+
.join(result)
70+
)
71+
result = self.typeinfer.unsafe_run(mt).join(result)
72+
result = self.fold.unsafe_run(mt).join(result)
73+
result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result)
74+
result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
75+
76+
rule = Chain(
77+
CommonSubexpressionElimination(),
78+
DeadCodeElimination(),
79+
)
80+
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
81+
82+
return result
83+
84+
def inline_heuristic(self, node: ir.Statement) -> bool:
85+
"""The heuristic to decide whether to inline a function call or not.
86+
inside loops and if-else, only inline simple functions, i.e.
87+
functions with a single block
88+
"""
89+
return not isinstance(
90+
node.parent_stmt, (scf.For, scf.IfElse)
91+
) and self.additional_inline_heuristic(
92+
node
93+
) # always inline calls outside of loops and if-else

0 commit comments

Comments
 (0)