Skip to content

Commit 16d6f84

Browse files
authored
upgrade to kirin 0.17 (#178)
1 parent 6869301 commit 16d6f84

File tree

31 files changed

+188
-194
lines changed

31 files changed

+188
-194
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ requires-python = ">=3.10"
1313
dependencies = [
1414
"numpy>=1.22.0",
1515
"scipy>=1.13.1",
16-
"kirin-toolchain~=0.16.0",
16+
"kirin-toolchain~=0.17.0",
1717
"rich>=13.9.4",
1818
"pydantic>=1.3.0,<2.11.0",
1919
"pandas>=2.2.3",

src/bloqade/noise/native/rewrite.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
from dataclasses import dataclass
22

33
from kirin import ir
4-
from kirin.rewrite import abc, dce, walk, result, fixpoint
4+
from kirin.rewrite import abc, dce, walk, fixpoint
55
from kirin.passes.abc import Pass
66

77
from .stmts import PauliChannel, CZPauliChannel, AtomLossChannel
88
from ._dialect import dialect
99

1010

1111
class RemoveNoiseRewrite(abc.RewriteRule):
12-
def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
12+
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
1313
if isinstance(node, (AtomLossChannel, PauliChannel, CZPauliChannel)):
1414
node.delete()
15-
return result.RewriteResult(has_done_something=True)
15+
return abc.RewriteResult(has_done_something=True)
1616

17-
return result.RewriteResult()
17+
return abc.RewriteResult()
1818

1919

2020
@dataclass
2121
class RemoveNoisePass(Pass):
2222
name = "remove-noise"
2323

24-
def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
24+
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
2525
delete_walk = walk.Walk(RemoveNoiseRewrite())
2626
dce_walk = fixpoint.Fixpoint(walk.Walk(dce.DeadCodeElimination()))
2727

src/bloqade/pyqrack/target.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def run(
8080
"""
8181
fold = Fold(mt.dialects)
8282
fold(mt)
83-
return self._get_interp(mt).run(mt, args, kwargs).expect()
83+
return self._get_interp(mt).run(mt, args, kwargs)
8484

8585
def multi_run(
8686
self,
@@ -107,6 +107,6 @@ def multi_run(
107107
interpreter = self._get_interp(mt)
108108
batched_results = []
109109
for _ in range(_shots):
110-
batched_results.append(interpreter.run(mt, args, kwargs).expect())
110+
batched_results.append(interpreter.run(mt, args, kwargs))
111111

112112
return batched_results

src/bloqade/qasm2/_wrappers.py

Lines changed: 9 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -61,33 +61,25 @@ def reset(qarg: Qubit) -> None:
6161

6262

6363
@overload
64-
def measure(qarg: Qubit, cbit: Bit) -> None:
65-
"""
66-
Measure the qubit `qarg` and store the result in the classical bit `cbit`.
67-
68-
Args:
69-
qarg: The qubit to measure.
70-
cbit: The classical bit to store the result in.
71-
"""
72-
...
64+
def measure(qreg: QReg, creg: CReg) -> None: ...
7365

7466

7567
@overload
76-
def measure(qarg: QReg, carg: CReg) -> None:
68+
def measure(qarg: Qubit, cbit: Bit) -> None: ...
69+
70+
71+
@wraps(core.Measure)
72+
def measure(qarg, cbit) -> None:
7773
"""
78-
Measure each qubit in the quantum register `qarg` and store the result in the classical register `carg`.
74+
Measure the qubit `qarg` and store the result in the classical bit `cbit`.
7975
8076
Args:
81-
qarg: The quantum register to measure.
82-
carg: The classical bit to store the result in.
77+
qarg: The qubit to measure.
78+
cbit: The classical bit to store the result in.
8379
"""
8480
...
8581

8682

87-
@wraps(core.Measure)
88-
def measure(qarg, carg) -> None: ...
89-
90-
9183
@wraps(uop.CX)
9284
def cx(ctrl: Qubit, qarg: Qubit) -> None:
9385
"""

src/bloqade/qasm2/dialects/expr/_emit.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,17 @@ def emit_func(
1919
self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.GateFunction
2020
):
2121

22-
cparams, qparams = [], []
22+
args, cparams, qparams = [], [], []
2323
for arg in stmt.body.blocks[0].args[1:]:
24-
name = frame.get(arg)
24+
name = frame.get_typed(arg, ast.Name)
25+
args.append(name)
2526
if not isinstance(name, ast.Name):
2627
raise EmitError("expected ast.Name")
2728
if arg.type.is_subseteq(QubitType):
2829
qparams.append(name.id)
2930
else:
3031
cparams.append(name.id)
31-
emit.run_ssacfg_region(frame, stmt.body)
32+
emit.run_ssacfg_region(frame, stmt.body, args)
3233
emit.output = ast.Gate(
3334
name=stmt.sym_name,
3435
cparams=cparams,

src/bloqade/qasm2/emit/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ def initialize(self) -> Self:
3636
)
3737
return self
3838

39-
def new_frame(self, code: ir.Statement) -> EmitQASM2Frame:
40-
return EmitQASM2Frame.from_func_like(code)
39+
def initialize_frame(
40+
self, code: ir.Statement, *, has_parent_access: bool = False
41+
) -> EmitQASM2Frame[StmtType]:
42+
return EmitQASM2Frame(code, has_parent_access=has_parent_access)
4143

4244
def run_method(
4345
self, method: ir.Method, args: tuple[ast.Node | None, ...]

src/bloqade/qasm2/emit/gate.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,16 @@ def ignore(self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt):
9191
def emit_func(
9292
self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: func.Function
9393
):
94-
emit.run_ssacfg_region(frame, stmt.body)
94+
args_ssa = stmt.args
95+
print(stmt.args)
96+
emit.run_ssacfg_region(frame, stmt.body, frame.get_values(args_ssa))
97+
9598
cparams, qparams = [], []
96-
for arg in stmt.args:
99+
for arg in args_ssa:
97100
if arg.type.is_subseteq(QubitType):
98101
qparams.append(frame.get(arg))
99102
else:
100103
cparams.append(frame.get(arg))
104+
101105
emit.output = ast.Gate(stmt.sym_name, cparams, qparams, frame.body)
102106
return ()

src/bloqade/qasm2/emit/main.py

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def emit_func(
2424
):
2525
from bloqade.qasm2.dialects import glob, noise, parallel
2626

27-
emit.run_ssacfg_region(frame, stmt.body)
27+
emit.run_ssacfg_region(frame, stmt.body, ())
2828
if emit.dialects.data.intersection(
2929
(parallel.dialect, glob.dialect, noise.dialect)
3030
):
@@ -51,12 +51,14 @@ def emit_conditional_branch(
5151
self, emit: EmitQASM2Main, frame: EmitQASM2Frame, stmt: cf.ConditionalBranch
5252
):
5353
cond = emit.assert_node(ast.Cmp, frame.get(stmt.cond))
54-
body_frame = emit.new_frame(stmt)
55-
body_frame.entries.update(frame.entries)
56-
body_frame.set_values(
57-
stmt.then_successor.args, frame.get_values(stmt.then_arguments)
58-
)
59-
emit.emit_block(body_frame, stmt.then_successor)
54+
55+
with emit.new_frame(stmt) as body_frame:
56+
body_frame.entries.update(frame.entries)
57+
body_frame.set_values(
58+
stmt.then_successor.args, frame.get_values(stmt.then_arguments)
59+
)
60+
emit.emit_block(body_frame, stmt.then_successor)
61+
6062
frame.body.append(
6163
ast.IfStmt(
6264
cond,
@@ -91,15 +93,17 @@ def emit_if_else(
9193
)
9294

9395
cond = emit.assert_node(ast.Cmp, frame.get(stmt.cond))
94-
then_frame = emit.new_frame(stmt)
95-
then_frame.entries.update(frame.entries)
96-
emit.emit_block(then_frame, stmt.then_body.blocks[0])
97-
frame.body.append(
98-
ast.IfStmt(
99-
cond,
100-
body=then_frame.body, # type: ignore
96+
97+
with emit.new_frame(stmt) as then_frame:
98+
then_frame.entries.update(frame.entries)
99+
emit.emit_block(then_frame, stmt.then_body.blocks[0])
100+
frame.body.append(
101+
ast.IfStmt(
102+
cond,
103+
body=then_frame.body, # type: ignore
104+
)
101105
)
102-
)
106+
103107
term = stmt.then_body.blocks[0].last_stmt
104108
if isinstance(term, scf.Yield):
105109
return then_frame.get_values(term.values)

src/bloqade/qasm2/emit/target.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -101,9 +101,7 @@ def emit(self, entry: ir.Method) -> ast.MainProgram:
101101

102102
Py2QASM(entry.dialects)(entry)
103103
target_main = EmitQASM2Main(self.main_target)
104-
target_main.run(
105-
entry, tuple(ast.Name(name) for name in entry.arg_names[1:])
106-
).expect()
104+
target_main.run(entry, tuple(ast.Name(name) for name in entry.arg_names[1:]))
107105

108106
main_program = target_main.output
109107
assert main_program is not None, f"failed to emit {entry.sym_name}"
@@ -133,9 +131,7 @@ def emit(self, entry: ir.Method) -> ast.MainProgram:
133131

134132
Py2QASM(fn.dialects)(fn)
135133

136-
target_gate.run(
137-
fn, tuple(ast.Name(name) for name in fn.arg_names[1:])
138-
).expect()
134+
target_gate.run(fn, tuple(ast.Name(name) for name in fn.arg_names[1:]))
139135
assert target_gate.output is not None, f"failed to emit {fn.sym_name}"
140136
extra.append(target_gate.output)
141137

src/bloqade/qasm2/passes/fold.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from kirin.analysis import const
2020
from kirin.dialects import scf, ilist
2121
from kirin.ir.method import Method
22-
from kirin.rewrite.result import RewriteResult
22+
from kirin.rewrite.abc import RewriteResult
2323

2424
from bloqade.qasm2.dialects import expr
2525

0 commit comments

Comments
 (0)