Skip to content

Commit ba31259

Browse files
committed
clean up import for stim
1 parent 23c995b commit ba31259

33 files changed

+448
-268
lines changed
Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
11
from . import _emit as _emit, address as address, _typeinfer as _typeinfer
2-
from .stmts import * # noqa: F403
2+
from .stmts import (
3+
QRegNew as QRegNew,
4+
QRegGet as QRegGet,
5+
CRegNew as CRegNew,
6+
CRegGet as CRegGet,
7+
Reset as Reset,
8+
Measure as Measure,
9+
CRegEq as CRegEq,
10+
)
311
from ._dialect import dialect as dialect
Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
11
from . import _emit as _emit, _interp as _interp, _from_python as _from_python
2-
from .stmts import * # noqa: F403
2+
from .stmts import (
3+
GateFunction as GateFunction,
4+
ConstFloat as ConstFloat,
5+
ConstInt as ConstInt,
6+
ConstPI as ConstPI,
7+
Neg as Neg,
8+
Sin as Sin,
9+
Cos as Cos,
10+
Tan as Tan,
11+
Exp as Exp,
12+
Log as Log,
13+
Sqrt as Sqrt,
14+
Add as Add,
15+
Sub as Sub,
16+
Mul as Mul,
17+
Div as Div,
18+
Pow as Pow,
19+
)
320
from ._dialect import dialect as dialect
Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,41 @@
11
from . import _emit as _emit, stmts as stmts
2-
from .stmts import * # noqa: F403
2+
from .stmts import (
3+
UGate as UGate,
4+
CX as CX,
5+
Barrier as Barrier,
6+
Id as Id,
7+
H as H,
8+
X as X,
9+
Y as Y,
10+
Z as Z,
11+
S as S,
12+
Sdag as Sdag,
13+
SX as SX,
14+
SXdag as SXdag,
15+
T as T,
16+
Tdag as Tdag,
17+
RX as RX,
18+
RY as RY,
19+
RZ as RZ,
20+
U1 as U1,
21+
U2 as U2,
22+
CZ as CZ,
23+
CY as CY,
24+
CSX as CSX,
25+
Swap as Swap,
26+
CH as CH,
27+
CCX as CCX,
28+
CSwap as CSwap,
29+
CRX as CRX,
30+
CRY as CRY,
31+
CRZ as CRZ,
32+
CU1 as CU1,
33+
CU3 as CU3,
34+
CU as CU,
35+
RXX as RXX,
36+
RZZ as RZZ,
37+
SingleQubitGate as SingleQubitGate,
38+
TwoQubitCtrlGate as TwoQubitCtrlGate,
39+
)
340
from ._dialect import dialect as dialect
4-
from .schedule import * # noqa: F403
41+
from . import schedule as schedule

src/bloqade/qasm2/dialects/uop/schedule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
@dialect.register(key="qasm2.schedule.dag")
11-
class UOp(interp.MethodTable):
11+
class UOpSchedule(interp.MethodTable):
1212

1313
@interp.impl(stmts.Id)
1414
@interp.impl(stmts.SXdag)

src/bloqade/qasm2/rewrite/desugar.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66
from kirin.dialects import py
77

88
from bloqade.qasm2.dialects import core
9+
from bloqade.qasm2 import types
910

1011

1112
class IndexingDesugarRule(abc.RewriteRule):
1213
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
1314
if isinstance(node, py.indexing.GetItem):
14-
if node.obj.type.is_subseteq(core.QRegType):
15+
if node.obj.type.is_subseteq(types.QRegType):
1516
node.replace_by(core.QRegGet(reg=node.obj, idx=node.index))
1617
return abc.RewriteResult(has_done_something=True)
17-
elif node.obj.type.is_subseteq(core.CRegType):
18+
elif node.obj.type.is_subseteq(types.CRegType):
1819
node.replace_by(core.CRegGet(reg=node.obj, idx=node.index))
1920
return abc.RewriteResult(has_done_something=True)
2021

src/bloqade/squin/groups.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,4 @@ def run_pass(method):
4444
py_mult_to_mult_pass(method)
4545

4646
return run_pass
47+
Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,10 @@
1-
# Put all the proper wrappers here
2-
3-
from kirin.lowering import wraps as _wraps
4-
5-
from bloqade.squin.op.types import Op
6-
71
from . import stmts as stmts
2+
from ._dialect import dialect as dialect
3+
from ._wrapper import (
4+
pauli_channel as pauli_channel,
5+
pp_error as pp_error,
6+
depolarize as depolarize,
7+
pauli_channel as pauli_channel,
8+
qubit_loss as qubit_loss,
9+
)
810

9-
10-
@_wraps(stmts.PauliError)
11-
def pauli_error(basis: Op, p: float) -> Op: ...
12-
13-
14-
@_wraps(stmts.PPError)
15-
def pp_error(op: Op, p: float) -> Op: ...
16-
17-
18-
@_wraps(stmts.Depolarize)
19-
def depolarize(n_qubits: int, p: float) -> Op: ...
20-
21-
22-
@_wraps(stmts.PauliChannel)
23-
def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ...
24-
25-
26-
@_wraps(stmts.QubitLoss)
27-
def qubit_loss(p: float) -> Op: ...
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from kirin.lowering import wraps
2+
3+
from bloqade.squin.op.types import Op
4+
from . import stmts
5+
6+
@wraps(stmts.PauliError)
7+
def pauli_error(basis: Op, p: float) -> Op: ...
8+
9+
10+
@wraps(stmts.PPError)
11+
def pp_error(op: Op, p: float) -> Op: ...
12+
13+
14+
@wraps(stmts.Depolarize)
15+
def depolarize(n_qubits: int, p: float) -> Op: ...
16+
17+
18+
@wraps(stmts.PauliChannel)
19+
def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ...
20+
21+
22+
@wraps(stmts.QubitLoss)
23+
def qubit_loss(p: float) -> Op: ...

src/bloqade/squin/op/__init__.py

Lines changed: 33 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -1,162 +1,36 @@
1-
from kirin import ir as _ir
2-
from kirin.prelude import structural_no_opt as _structural_no_opt
3-
from kirin.lowering import wraps as _wraps
4-
51
from . import stmts as stmts, types as types, rewrite as rewrite
62
from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary
73
from ._dialect import dialect as dialect
8-
9-
10-
@_wraps(stmts.Kron)
11-
def kron(lhs: types.Op, rhs: types.Op) -> types.Op: ...
12-
13-
14-
@_wraps(stmts.Mult)
15-
def mult(lhs: types.Op, rhs: types.Op) -> types.Op: ...
16-
17-
18-
@_wraps(stmts.Scale)
19-
def scale(op: types.Op, factor: complex) -> types.Op: ...
20-
21-
22-
@_wraps(stmts.Adjoint)
23-
def adjoint(op: types.Op) -> types.Op: ...
24-
25-
26-
@_wraps(stmts.Control)
27-
def control(op: types.Op, *, n_controls: int) -> types.Op:
28-
"""
29-
Create a controlled operator.
30-
31-
Note, that when considering atom loss, the operator will not be applied if
32-
any of the controls has been lost.
33-
34-
Args:
35-
operator: The operator to apply under the control.
36-
n_controls: The number qubits to be used as control.
37-
38-
Returns:
39-
Operator
40-
"""
41-
...
42-
43-
44-
@_wraps(stmts.Identity)
45-
def identity(*, sites: int) -> types.Op: ...
46-
47-
48-
@_wraps(stmts.Rot)
49-
def rot(axis: types.Op, angle: float) -> types.Op: ...
50-
51-
52-
@_wraps(stmts.ShiftOp)
53-
def shift(theta: float) -> types.Op: ...
54-
55-
56-
@_wraps(stmts.PhaseOp)
57-
def phase(theta: float) -> types.Op: ...
58-
59-
60-
@_wraps(stmts.X)
61-
def x() -> types.Op: ...
62-
63-
64-
@_wraps(stmts.Y)
65-
def y() -> types.Op: ...
66-
67-
68-
@_wraps(stmts.Z)
69-
def z() -> types.Op: ...
70-
71-
72-
@_wraps(stmts.H)
73-
def h() -> types.Op: ...
74-
75-
76-
@_wraps(stmts.S)
77-
def s() -> types.Op: ...
78-
79-
80-
@_wraps(stmts.T)
81-
def t() -> types.Op: ...
82-
83-
84-
@_wraps(stmts.P0)
85-
def p0() -> types.Op: ...
86-
87-
88-
@_wraps(stmts.P1)
89-
def p1() -> types.Op: ...
90-
91-
92-
@_wraps(stmts.Sn)
93-
def spin_n() -> types.Op: ...
94-
95-
96-
@_wraps(stmts.Sp)
97-
def spin_p() -> types.Op: ...
98-
99-
100-
@_wraps(stmts.U3)
101-
def u(theta: float, phi: float, lam: float) -> types.Op: ...
102-
103-
104-
@_wraps(stmts.PauliString)
105-
def pauli_string(*, string: str) -> types.Op: ...
106-
107-
108-
# stdlibs
109-
@_ir.dialect_group(_structural_no_opt.add(dialect))
110-
def op(self):
111-
def run_pass(method):
112-
pass
113-
114-
return run_pass
115-
116-
117-
@op
118-
def rx(theta: float) -> types.Op:
119-
"""Rotation X gate."""
120-
return rot(x(), theta)
121-
122-
123-
@op
124-
def ry(theta: float) -> types.Op:
125-
"""Rotation Y gate."""
126-
return rot(y(), theta)
127-
128-
129-
@op
130-
def rz(theta: float) -> types.Op:
131-
"""Rotation Z gate."""
132-
return rot(z(), theta)
133-
134-
135-
@op
136-
def cx() -> types.Op:
137-
"""Controlled X gate."""
138-
return control(x(), n_controls=1)
139-
140-
141-
@op
142-
def cy() -> types.Op:
143-
"""Controlled Y gate."""
144-
return control(y(), n_controls=1)
145-
146-
147-
@op
148-
def cz() -> types.Op:
149-
"""Control Z gate."""
150-
return control(z(), n_controls=1)
151-
152-
153-
@op
154-
def ch() -> types.Op:
155-
"""Control H gate."""
156-
return control(h(), n_controls=1)
157-
158-
159-
@op
160-
def cphase(theta: float) -> types.Op:
161-
"""Control Phase gate."""
162-
return control(phase(theta), n_controls=1)
4+
from ._wrapper import (
5+
kron as kron,
6+
mult as mult,
7+
scale as scale,
8+
adjoint as adjoint,
9+
control as control,
10+
identity as identity,
11+
rot as rot,
12+
shift as shift,
13+
phase as phase,
14+
x as x,
15+
y as y,
16+
z as z,
17+
h as h,
18+
s as s,
19+
t as t,
20+
p0 as p0,
21+
p1 as p1,
22+
spin_n as spin_n,
23+
spin_p as spin_p,
24+
u as u,
25+
pauli_string as pauli_string,
26+
)
27+
from .stdlib import (
28+
rx as rx,
29+
ry as ry,
30+
rz as rz,
31+
cx as cx,
32+
cz as cz,
33+
cy as cy,
34+
ch as ch,
35+
cphase as cphase,
36+
)

0 commit comments

Comments
 (0)