Skip to content

Commit 1cd01ec

Browse files
authored
Fix some pyright issues (#163)
1 parent 1baaeab commit 1cd01ec

File tree

11 files changed

+82
-70
lines changed

11 files changed

+82
-70
lines changed

src/bloqade/analysis/address/impls.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,10 @@ def unwrap(
192192

193193
origin_qubit = frame.get(stmt.qubit)
194194

195-
return (AddressWire(origin_qubit=origin_qubit),)
195+
if isinstance(origin_qubit, AddressQubit):
196+
return (AddressWire(origin_qubit=origin_qubit),)
197+
else:
198+
return (Address.top(),)
196199

197200
@interp.impl(squin.wire.Apply)
198201
def apply(
@@ -201,14 +204,7 @@ def apply(
201204
frame: ForwardFrame[Address],
202205
stmt: squin.wire.Apply,
203206
):
204-
205-
origin_qubits = tuple(
206-
[frame.get(input_elem).origin_qubit for input_elem in stmt.inputs]
207-
)
208-
new_address_wires = tuple(
209-
[AddressWire(origin_qubit=origin_qubit) for origin_qubit in origin_qubits]
210-
)
211-
return new_address_wires
207+
return frame.get_values(stmt.inputs)
212208

213209

214210
@squin.qubit.dialect.register(key="qubit.address")

src/bloqade/noise/native/model.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,10 +102,9 @@ class MoveNoiseModelABC(abc.ABC):
102102
params: MoveNoiseParams = field(default_factory=MoveNoiseParams)
103103
"""Parameters for calculating move noise."""
104104

105-
@classmethod
106105
@abc.abstractmethod
107106
def parallel_cz_errors(
108-
cls, ctrls: List[int], qargs: List[int], rest: List[int]
107+
self, ctrls: List[int], qargs: List[int], rest: List[int]
109108
) -> Dict[Tuple[float, float, float, float], List[int]]:
110109
"""Takes a set of ctrls and qargs and returns a noise model for all qubits."""
111110
pass

src/bloqade/qasm2/parse/lowering.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,9 @@ def lower_literal(self, state: lowering.State[ast.Node], value) -> ir.SSAValue:
200200
elif isinstance(value, float):
201201
stmt = expr.ConstFloat(value=value)
202202
else:
203-
raise lowering.BuildError(f"Unsupported literal type {type(value)}")
204-
203+
raise lowering.BuildError(
204+
f"Expected value of type float or int, got {type(value)}."
205+
)
205206
state.current_frame.push(stmt)
206207
return stmt.result
207208

@@ -216,6 +217,8 @@ def visit_MainProgram(self, state: lowering.State[ast.Node], node: ast.MainProgr
216217
dialects = ["qasm2.core", "qasm2.uop", "qasm2.expr"]
217218
elif isinstance(node.header, ast.Kirin):
218219
dialects = node.header.dialects
220+
else:
221+
raise lowering.BuildError(f"Unexpected node header {node.header}")
219222

220223
for dialect in dialects:
221224
if dialect not in allowed:
@@ -412,6 +415,8 @@ def visit_Bit(self, state: lowering.State[ast.Node], node: ast.Bit):
412415
stmt = core.QRegGet(reg, addr.result)
413416
elif reg.type.is_subseteq(CRegType):
414417
stmt = core.CRegGet(reg, addr.result)
418+
else:
419+
raise lowering.BuildError(f"Unexpected register type {reg.type}")
415420
return state.current_frame.push(stmt).result
416421

417422
def visit_Call(self, state: lowering.State[ast.Node], node: ast.Call):

src/bloqade/qasm2/passes/fold.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from kirin.analysis import const
2020
from kirin.dialects import scf, ilist
2121
from kirin.ir.method import Method
22-
from kirin.rewrite.abc import RewriteResult
22+
from kirin.rewrite.result import RewriteResult
2323

2424
from bloqade.qasm2.dialects import expr
2525

src/bloqade/qasm2/rewrite/heuristic_noise.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, List, Tuple
1+
from typing import Dict, List, Tuple, cast
22
from dataclasses import field, dataclass
33

44
from kirin import ir
@@ -226,8 +226,12 @@ def rewrite_parallel_cz_gate(self, node: parallel.CZ):
226226
and isinstance(qargs, address.AddressTuple)
227227
and all(isinstance(addr, address.AddressQubit) for addr in qargs.data)
228228
):
229-
ctrl_qubits = list(map(lambda addr: addr.data, ctrls.data))
230-
qarg_qubits = list(map(lambda addr: addr.data, qargs.data))
229+
ctrl_qubits = list(
230+
map(lambda addr: cast(address.AddressQubit, addr).data, ctrls.data)
231+
)
232+
qarg_qubits = list(
233+
map(lambda addr: cast(address.AddressQubit, addr).data, qargs.data)
234+
)
231235
rest = sorted(
232236
set(self.qubit_ssa_value.keys()) - set(ctrl_qubits + qarg_qubits)
233237
)

src/bloqade/qasm2/rewrite/native_gates.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,15 @@ def _circuit_diagram_info_(self, args):
7070
return "*", "CU"
7171

7272

73-
def around(val):
73+
def around(val) -> float:
7474
return float(np.around(val, 14))
7575

7676

7777
def one_qubit_gate_to_u3_angles(op: cirq.Operation) -> tuple[float, float, float]:
7878
lam, theta, phi = ( # Z angle, Y angle, then Z angle
7979
cirq.deconstruct_single_qubit_matrix_into_angles(cirq.unitary(op))
8080
)
81-
return tuple(map(around, (theta, phi, lam)))
81+
return around(theta), around(phi), around(lam)
8282

8383

8484
@dataclass

src/bloqade/qasm2/rewrite/register.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from kirin import ir
22
from kirin.dialects import py
3-
from kirin.rewrite.abc import RewriteRule, RewriteResult
3+
from kirin.rewrite.abc import RewriteRule
4+
from kirin.rewrite.result import RewriteResult
45

56
from bloqade.qasm2.dialects import core
67

src/bloqade/qasm2/rewrite/uop_to_parallel.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
from dataclasses import field, dataclass
44

55
from kirin import ir
6-
from kirin.rewrite import abc as rewrite_abc
76
from kirin.dialects import py, ilist
7+
from kirin.rewrite.abc import RewriteRule
88
from kirin.analysis.const import lattice
9+
from kirin.rewrite.result import RewriteResult
910

1011
from bloqade.analysis import address
1112
from bloqade.qasm2.dialects import uop, core, parallel
@@ -14,7 +15,7 @@
1415

1516
class MergePolicyABC(abc.ABC):
1617
@abc.abstractmethod
17-
def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
18+
def __call__(self, node: ir.Statement) -> RewriteResult:
1819
pass
1920

2021
@classmethod
@@ -141,10 +142,10 @@ def from_analysis(
141142
group_numbers=group_numbers,
142143
)
143144

144-
def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
145+
def __call__(self, node: ir.Statement) -> RewriteResult:
145146

146147
if node not in self.group_numbers:
147-
return rewrite_abc.RewriteResult()
148+
return RewriteResult()
148149

149150
group_number = self.group_numbers[node]
150151
group = self.merge_groups[group_number]
@@ -157,9 +158,7 @@ def __call__(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
157158
if self.group_has_merged[group_number]:
158159
node.delete()
159160

160-
return rewrite_abc.RewriteResult(
161-
has_done_something=self.group_has_merged[group_number]
162-
)
161+
return RewriteResult(has_done_something=self.group_has_merged[group_number])
163162

164163
def move_and_collect_qubit_list(
165164
self, qargs: List[ir.SSAValue], node: ir.Statement
@@ -219,14 +218,14 @@ def rewrite_group_cz(self, node: ir.Statement, group: List[ir.Statement]):
219218
ctrls.append(stmt.ctrls)
220219
qargs.append(stmt.qargs)
221220
else:
222-
return rewrite_abc.RewriteResult(has_done_something=False)
221+
return RewriteResult(has_done_something=False)
223222

224223
ctrls_values = self.move_and_collect_qubit_list(ctrls, node)
225224
qargs_values = self.move_and_collect_qubit_list(qargs, node)
226225

227226
if ctrls_values is None or qargs_values is None:
228227
# give up if we cannot determine the address or cannot move the qubits
229-
return rewrite_abc.RewriteResult(has_done_something=False)
228+
return RewriteResult(has_done_something=False)
230229

231230
new_ctrls = ilist.New(values=ctrls_values)
232231
new_qargs = ilist.New(values=qargs_values)
@@ -238,7 +237,7 @@ def rewrite_group_cz(self, node: ir.Statement, group: List[ir.Statement]):
238237

239238
node.delete()
240239

241-
return rewrite_abc.RewriteResult(has_done_something=True)
240+
return RewriteResult(has_done_something=True)
242241

243242
def rewrite_group_U(self, node: ir.Statement, group: List[ir.Statement]):
244243
return self.rewrite_group_u(node, group)
@@ -252,13 +251,13 @@ def rewrite_group_u(self, node: ir.Statement, group: List[ir.Statement]):
252251
elif isinstance(stmt, parallel.UGate):
253252
qargs.append(stmt.qargs)
254253
else:
255-
return rewrite_abc.RewriteResult(has_done_something=False)
254+
return RewriteResult(has_done_something=False)
256255

257256
assert isinstance(node, (uop.UGate, parallel.UGate))
258257
qargs_values = self.move_and_collect_qubit_list(qargs, node)
259258

260259
if qargs_values is None:
261-
return rewrite_abc.RewriteResult(has_done_something=False)
260+
return RewriteResult(has_done_something=False)
262261

263262
new_qargs = ilist.New(values=qargs_values)
264263
new_gate = parallel.UGate(
@@ -271,7 +270,7 @@ def rewrite_group_u(self, node: ir.Statement, group: List[ir.Statement]):
271270
new_gate.insert_before(node)
272271
node.delete()
273272

274-
return rewrite_abc.RewriteResult(has_done_something=True)
273+
return RewriteResult(has_done_something=True)
275274

276275
def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]):
277276
qargs = []
@@ -282,14 +281,14 @@ def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]):
282281
elif isinstance(stmt, parallel.RZ):
283282
qargs.append(stmt.qargs)
284283
else:
285-
return rewrite_abc.RewriteResult(has_done_something=False)
284+
return RewriteResult(has_done_something=False)
286285

287286
assert isinstance(node, (uop.RZ, parallel.RZ))
288287

289288
qargs_values = self.move_and_collect_qubit_list(qargs, node)
290289

291290
if qargs_values is None:
292-
return rewrite_abc.RewriteResult(has_done_something=False)
291+
return RewriteResult(has_done_something=False)
293292

294293
new_qargs = ilist.New(values=qargs_values)
295294
new_gate = parallel.RZ(
@@ -300,7 +299,7 @@ def rewrite_group_rz(self, node: ir.Statement, group: List[ir.Statement]):
300299
new_gate.insert_before(node)
301300
node.delete()
302301

303-
return rewrite_abc.RewriteResult(has_done_something=True)
302+
return RewriteResult(has_done_something=True)
304303

305304
def rewrite_group_barrier(self, node: uop.Barrier, group: List[uop.Barrier]):
306305
qargs = []
@@ -310,13 +309,13 @@ def rewrite_group_barrier(self, node: uop.Barrier, group: List[uop.Barrier]):
310309
qargs_values = self.move_and_collect_qubit_list(qargs, node)
311310

312311
if qargs_values is None:
313-
return rewrite_abc.RewriteResult(has_done_something=False)
312+
return RewriteResult(has_done_something=False)
314313

315314
new_node = uop.Barrier(qargs=qargs_values)
316315
new_node.insert_before(node)
317316
node.delete()
318317

319-
return rewrite_abc.RewriteResult(has_done_something=True)
318+
return RewriteResult(has_done_something=True)
320319

321320

322321
class GreedyMixin(MergePolicyABC):
@@ -385,11 +384,11 @@ class SimpleOptimalMergePolicy(OptimalMixIn, SimpleMergePolicy):
385384

386385

387386
@dataclass
388-
class UOpToParallelRule(rewrite_abc.RewriteRule):
387+
class UOpToParallelRule(RewriteRule):
389388
merge_rewriters: Dict[ir.Block | None, MergePolicyABC]
390389

391-
def rewrite_Statement(self, node: ir.Statement) -> rewrite_abc.RewriteResult:
390+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
392391
merge_rewriter = self.merge_rewriters.get(
393-
node.parent_block, lambda _: rewrite_abc.RewriteResult()
392+
node.parent_block, lambda _: RewriteResult()
394393
)
395394
return merge_rewriter(node)

src/bloqade/qbraid/schema.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ class Operation(BaseModel, frozen=True, extra="forbid"):
99
op_type: str = Field(init=False)
1010

1111

12-
class CZ(Operation):
12+
class CZ(Operation, frozen=True):
1313
"""A CZ gate operation.
1414
1515
Fields:
@@ -22,7 +22,7 @@ class CZ(Operation):
2222
participants: Tuple[Union[Tuple[int], Tuple[int, int]], ...]
2323

2424

25-
class GlobalRz(Operation):
25+
class GlobalRz(Operation, frozen=True):
2626
"""GlobalRz operation.
2727
2828
Fields:
@@ -34,7 +34,7 @@ class GlobalRz(Operation):
3434
phi: float
3535

3636

37-
class GlobalW(Operation):
37+
class GlobalW(Operation, frozen=True):
3838
"""GlobalW operation.
3939
4040
Fields:
@@ -48,7 +48,7 @@ class GlobalW(Operation):
4848
phi: float
4949

5050

51-
class LocalRz(Operation):
51+
class LocalRz(Operation, frozen=True):
5252
"""LocalRz operation.
5353
5454
Fields:
@@ -63,7 +63,7 @@ class LocalRz(Operation):
6363
phi: float
6464

6565

66-
class LocalW(Operation):
66+
class LocalW(Operation, frozen=True):
6767
"""LocalW operation.
6868
6969
Fields:
@@ -80,7 +80,7 @@ class LocalW(Operation):
8080
phi: float
8181

8282

83-
class Measurement(Operation):
83+
class Measurement(Operation, frozen=True):
8484
"""Measurement operation.
8585
8686
Fields:
@@ -95,9 +95,7 @@ class Measurement(Operation):
9595
participants: Tuple[int, ...]
9696

9797

98-
OperationType = TypeVar(
99-
"OperationType", bound=Union[CZ, GlobalRz, GlobalW, LocalRz, LocalW, Measurement]
100-
)
98+
OperationType = CZ | GlobalRz | GlobalW | LocalRz | LocalW | Measurement
10199

102100

103101
class ErrorModel(BaseModel, frozen=True, extra="forbid"):
@@ -106,7 +104,7 @@ class ErrorModel(BaseModel, frozen=True, extra="forbid"):
106104
error_model_type: str = Field(init=False)
107105

108106

109-
class PauliErrorModel(ErrorModel):
107+
class PauliErrorModel(ErrorModel, frozen=True):
110108
"""Pauli error model.
111109
112110
Fields:
@@ -131,7 +129,7 @@ class ErrorOperation(BaseModel, Generic[ErrorModelType], frozen=True, extra="for
131129
survival_prob: Tuple[float, ...]
132130

133131

134-
class CZError(ErrorOperation[ErrorModelType]):
132+
class CZError(ErrorOperation[ErrorModelType], frozen=True):
135133
"""CZError operation.
136134
137135
Fields:
@@ -149,7 +147,7 @@ class CZError(ErrorOperation[ErrorModelType]):
149147
single_error: ErrorModelType
150148

151149

152-
class SingleQubitError(ErrorOperation[ErrorModelType]):
150+
class SingleQubitError(ErrorOperation[ErrorModelType], frozen=True):
153151
"""SingleQubitError operation.
154152
155153
Fields:

src/bloqade/squin/analysis/nsites/analysis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ def eval_stmt(self, frame: ForwardFrame, stmt: ir.Statement):
2323
if method is not None:
2424
return method(self, frame, stmt)
2525
elif stmt.has_trait(HasSites):
26-
has_sites_trait = stmt.get_trait(HasSites)
26+
has_sites_trait = stmt.get_present_trait(HasSites)
2727
sites = has_sites_trait.get_sites(stmt)
2828
return (NumberSites(sites=sites),)
2929
elif stmt.has_trait(FixedSites):
30-
sites_trait = stmt.get_trait(FixedSites)
30+
sites_trait = stmt.get_present_trait(FixedSites)
3131
return (NumberSites(sites=sites_trait.data),)
3232
else:
3333
return (NoSites(),)

0 commit comments

Comments
 (0)