Skip to content

Commit 11110b3

Browse files
authored
Clean up import path (#279)
This PR address #276, remove all `from xxx import *` General pattern: 1. to access the Statement definition: ```python from stim.dialects import collapse, auxiliary, gate gate.CX # the statement def auxiliary.ObservableInclude # the statement def ``` 2. for general user access, they should not touch statement, so: ```python from bloqade import stim stim.cx # wrapper stim.observable_include # wrapper ```
1 parent 23c995b commit 11110b3

31 files changed

+446
-265
lines changed
Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
11
from . import _emit as _emit, address as address, _typeinfer as _typeinfer
2-
from .stmts import * # noqa: F403
2+
from .stmts import (
3+
Reset as Reset,
4+
CRegEq as CRegEq,
5+
CRegGet as CRegGet,
6+
CRegNew as CRegNew,
7+
Measure as Measure,
8+
QRegGet as QRegGet,
9+
QRegNew as QRegNew,
10+
)
311
from ._dialect import dialect as dialect
Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
11
from . import _emit as _emit, _interp as _interp, _from_python as _from_python
2-
from .stmts import * # noqa: F403
2+
from .stmts import (
3+
Add as Add,
4+
Cos as Cos,
5+
Div as Div,
6+
Exp as Exp,
7+
Log as Log,
8+
Mul as Mul,
9+
Neg as Neg,
10+
Pow as Pow,
11+
Sin as Sin,
12+
Sub as Sub,
13+
Tan as Tan,
14+
Sqrt as Sqrt,
15+
ConstPI as ConstPI,
16+
ConstInt as ConstInt,
17+
ConstFloat as ConstFloat,
18+
GateFunction as GateFunction,
19+
)
320
from ._dialect import dialect as dialect

src/bloqade/qasm2/dialects/uop/__init__.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,40 @@
1-
from . import _emit as _emit, stmts as stmts
2-
from .stmts import * # noqa: F403
1+
from . import _emit as _emit, stmts as stmts, schedule as schedule
2+
from .stmts import (
3+
CH as CH,
4+
CU as CU,
5+
CX as CX,
6+
CY as CY,
7+
CZ as CZ,
8+
RX as RX,
9+
RY as RY,
10+
RZ as RZ,
11+
SX as SX,
12+
U1 as U1,
13+
U2 as U2,
14+
CCX as CCX,
15+
CRX as CRX,
16+
CRY as CRY,
17+
CRZ as CRZ,
18+
CSX as CSX,
19+
CU1 as CU1,
20+
CU3 as CU3,
21+
RXX as RXX,
22+
RZZ as RZZ,
23+
H as H,
24+
S as S,
25+
T as T,
26+
X as X,
27+
Y as Y,
28+
Z as Z,
29+
Id as Id,
30+
Sdag as Sdag,
31+
Swap as Swap,
32+
Tdag as Tdag,
33+
CSwap as CSwap,
34+
SXdag as SXdag,
35+
UGate as UGate,
36+
Barrier as Barrier,
37+
SingleQubitGate as SingleQubitGate,
38+
TwoQubitCtrlGate as TwoQubitCtrlGate,
39+
)
340
from ._dialect import dialect as dialect
4-
from .schedule import * # noqa: F403

src/bloqade/qasm2/dialects/uop/schedule.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99

1010
@dialect.register(key="qasm2.schedule.dag")
11-
class UOp(interp.MethodTable):
11+
class UOpSchedule(interp.MethodTable):
1212

1313
@interp.impl(stmts.Id)
1414
@interp.impl(stmts.SXdag)

src/bloqade/qasm2/rewrite/desugar.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,17 @@
55
from kirin.rewrite import abc, walk
66
from kirin.dialects import py
77

8+
from bloqade.qasm2 import types
89
from bloqade.qasm2.dialects import core
910

1011

1112
class IndexingDesugarRule(abc.RewriteRule):
1213
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
1314
if isinstance(node, py.indexing.GetItem):
14-
if node.obj.type.is_subseteq(core.QRegType):
15+
if node.obj.type.is_subseteq(types.QRegType):
1516
node.replace_by(core.QRegGet(reg=node.obj, idx=node.index))
1617
return abc.RewriteResult(has_done_something=True)
17-
elif node.obj.type.is_subseteq(core.CRegType):
18+
elif node.obj.type.is_subseteq(types.CRegType):
1819
node.replace_by(core.CRegGet(reg=node.obj, idx=node.index))
1920
return abc.RewriteResult(has_done_something=True)
2021

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,8 @@
1-
# Put all the proper wrappers here
2-
3-
from kirin.lowering import wraps as _wraps
4-
5-
from bloqade.squin.op.types import Op
6-
71
from . import stmts as stmts
8-
9-
10-
@_wraps(stmts.PauliError)
11-
def pauli_error(basis: Op, p: float) -> Op: ...
12-
13-
14-
@_wraps(stmts.PPError)
15-
def pp_error(op: Op, p: float) -> Op: ...
16-
17-
18-
@_wraps(stmts.Depolarize)
19-
def depolarize(n_qubits: int, p: float) -> Op: ...
20-
21-
22-
@_wraps(stmts.PauliChannel)
23-
def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ...
24-
25-
26-
@_wraps(stmts.QubitLoss)
27-
def qubit_loss(p: float) -> Op: ...
2+
from ._dialect import dialect as dialect
3+
from ._wrapper import (
4+
pp_error as pp_error,
5+
depolarize as depolarize,
6+
qubit_loss as qubit_loss,
7+
pauli_channel as pauli_channel,
8+
)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
from kirin.lowering import wraps
2+
3+
from bloqade.squin.op.types import Op
4+
5+
from . import stmts
6+
7+
8+
@wraps(stmts.PauliError)
9+
def pauli_error(basis: Op, p: float) -> Op: ...
10+
11+
12+
@wraps(stmts.PPError)
13+
def pp_error(op: Op, p: float) -> Op: ...
14+
15+
16+
@wraps(stmts.Depolarize)
17+
def depolarize(n_qubits: int, p: float) -> Op: ...
18+
19+
20+
@wraps(stmts.PauliChannel)
21+
def pauli_channel(n_qubits: int, params: tuple[float, ...]) -> Op: ...
22+
23+
24+
@wraps(stmts.QubitLoss)
25+
def qubit_loss(p: float) -> Op: ...

src/bloqade/squin/op/__init__.py

Lines changed: 33 additions & 159 deletions
Original file line numberDiff line numberDiff line change
@@ -1,162 +1,36 @@
1-
from kirin import ir as _ir
2-
from kirin.prelude import structural_no_opt as _structural_no_opt
3-
from kirin.lowering import wraps as _wraps
4-
51
from . import stmts as stmts, types as types, rewrite as rewrite
2+
from .stdlib import (
3+
ch as ch,
4+
cx as cx,
5+
cy as cy,
6+
cz as cz,
7+
rx as rx,
8+
ry as ry,
9+
rz as rz,
10+
cphase as cphase,
11+
)
612
from .traits import Unitary as Unitary, MaybeUnitary as MaybeUnitary
713
from ._dialect import dialect as dialect
8-
9-
10-
@_wraps(stmts.Kron)
11-
def kron(lhs: types.Op, rhs: types.Op) -> types.Op: ...
12-
13-
14-
@_wraps(stmts.Mult)
15-
def mult(lhs: types.Op, rhs: types.Op) -> types.Op: ...
16-
17-
18-
@_wraps(stmts.Scale)
19-
def scale(op: types.Op, factor: complex) -> types.Op: ...
20-
21-
22-
@_wraps(stmts.Adjoint)
23-
def adjoint(op: types.Op) -> types.Op: ...
24-
25-
26-
@_wraps(stmts.Control)
27-
def control(op: types.Op, *, n_controls: int) -> types.Op:
28-
"""
29-
Create a controlled operator.
30-
31-
Note, that when considering atom loss, the operator will not be applied if
32-
any of the controls has been lost.
33-
34-
Args:
35-
operator: The operator to apply under the control.
36-
n_controls: The number qubits to be used as control.
37-
38-
Returns:
39-
Operator
40-
"""
41-
...
42-
43-
44-
@_wraps(stmts.Identity)
45-
def identity(*, sites: int) -> types.Op: ...
46-
47-
48-
@_wraps(stmts.Rot)
49-
def rot(axis: types.Op, angle: float) -> types.Op: ...
50-
51-
52-
@_wraps(stmts.ShiftOp)
53-
def shift(theta: float) -> types.Op: ...
54-
55-
56-
@_wraps(stmts.PhaseOp)
57-
def phase(theta: float) -> types.Op: ...
58-
59-
60-
@_wraps(stmts.X)
61-
def x() -> types.Op: ...
62-
63-
64-
@_wraps(stmts.Y)
65-
def y() -> types.Op: ...
66-
67-
68-
@_wraps(stmts.Z)
69-
def z() -> types.Op: ...
70-
71-
72-
@_wraps(stmts.H)
73-
def h() -> types.Op: ...
74-
75-
76-
@_wraps(stmts.S)
77-
def s() -> types.Op: ...
78-
79-
80-
@_wraps(stmts.T)
81-
def t() -> types.Op: ...
82-
83-
84-
@_wraps(stmts.P0)
85-
def p0() -> types.Op: ...
86-
87-
88-
@_wraps(stmts.P1)
89-
def p1() -> types.Op: ...
90-
91-
92-
@_wraps(stmts.Sn)
93-
def spin_n() -> types.Op: ...
94-
95-
96-
@_wraps(stmts.Sp)
97-
def spin_p() -> types.Op: ...
98-
99-
100-
@_wraps(stmts.U3)
101-
def u(theta: float, phi: float, lam: float) -> types.Op: ...
102-
103-
104-
@_wraps(stmts.PauliString)
105-
def pauli_string(*, string: str) -> types.Op: ...
106-
107-
108-
# stdlibs
109-
@_ir.dialect_group(_structural_no_opt.add(dialect))
110-
def op(self):
111-
def run_pass(method):
112-
pass
113-
114-
return run_pass
115-
116-
117-
@op
118-
def rx(theta: float) -> types.Op:
119-
"""Rotation X gate."""
120-
return rot(x(), theta)
121-
122-
123-
@op
124-
def ry(theta: float) -> types.Op:
125-
"""Rotation Y gate."""
126-
return rot(y(), theta)
127-
128-
129-
@op
130-
def rz(theta: float) -> types.Op:
131-
"""Rotation Z gate."""
132-
return rot(z(), theta)
133-
134-
135-
@op
136-
def cx() -> types.Op:
137-
"""Controlled X gate."""
138-
return control(x(), n_controls=1)
139-
140-
141-
@op
142-
def cy() -> types.Op:
143-
"""Controlled Y gate."""
144-
return control(y(), n_controls=1)
145-
146-
147-
@op
148-
def cz() -> types.Op:
149-
"""Control Z gate."""
150-
return control(z(), n_controls=1)
151-
152-
153-
@op
154-
def ch() -> types.Op:
155-
"""Control H gate."""
156-
return control(h(), n_controls=1)
157-
158-
159-
@op
160-
def cphase(theta: float) -> types.Op:
161-
"""Control Phase gate."""
162-
return control(phase(theta), n_controls=1)
14+
from ._wrapper import (
15+
h as h,
16+
s as s,
17+
t as t,
18+
u as u,
19+
x as x,
20+
y as y,
21+
z as z,
22+
p0 as p0,
23+
p1 as p1,
24+
rot as rot,
25+
kron as kron,
26+
mult as mult,
27+
phase as phase,
28+
scale as scale,
29+
shift as shift,
30+
spin_n as spin_n,
31+
spin_p as spin_p,
32+
adjoint as adjoint,
33+
control as control,
34+
identity as identity,
35+
pauli_string as pauli_string,
36+
)

0 commit comments

Comments
 (0)