Skip to content

Commit dafd4e0

Browse files
committed
upgrade to kirin 0.17
1 parent cf93f91 commit dafd4e0

File tree

17 files changed

+126
-117
lines changed

17 files changed

+126
-117
lines changed

src/bloqade/noise/native/rewrite.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
11
from dataclasses import dataclass
22

33
from kirin import ir
4-
from kirin.rewrite import abc, dce, walk, result, fixpoint
4+
from kirin.rewrite import abc, dce, walk, fixpoint
55
from kirin.passes.abc import Pass
66

77
from .stmts import PauliChannel, CZPauliChannel, AtomLossChannel
88
from ._dialect import dialect
99

1010

1111
class RemoveNoiseRewrite(abc.RewriteRule):
12-
def rewrite_Statement(self, node: ir.Statement) -> result.RewriteResult:
12+
def rewrite_Statement(self, node: ir.Statement) -> abc.RewriteResult:
1313
if isinstance(node, (AtomLossChannel, PauliChannel, CZPauliChannel)):
1414
node.delete()
15-
return result.RewriteResult(has_done_something=True)
15+
return abc.RewriteResult(has_done_something=True)
1616

17-
return result.RewriteResult()
17+
return abc.RewriteResult()
1818

1919

2020
@dataclass
2121
class RemoveNoisePass(Pass):
2222
name = "remove-noise"
2323

24-
def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
24+
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
2525
delete_walk = walk.Walk(RemoveNoiseRewrite())
2626
dce_walk = fixpoint.Fixpoint(walk.Walk(dce.DeadCodeElimination()))
2727

src/bloqade/qasm2/_wrappers.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from typing import overload
2+
13
from kirin.lowering import wraps
24

35
from .types import Bit, CReg, QReg, Qubit
@@ -58,8 +60,16 @@ def reset(qarg: Qubit) -> None:
5860
...
5961

6062

63+
@overload
64+
def measure(qreg: QReg, creg: CReg) -> None: ...
65+
66+
67+
@overload
68+
def measure(qarg: Qubit, cbit: Bit) -> None: ...
69+
70+
6171
@wraps(core.Measure)
62-
def measure(qarg: Qubit, cbit: Bit) -> None:
72+
def measure(qarg, cbit) -> None:
6373
"""
6474
Measure the qubit `qarg` and store the result in the classical bit `cbit`.
6575

src/bloqade/qasm2/dialects/expr/_emit.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,17 @@ def emit_func(
1919
self, emit: EmitQASM2Gate, frame: EmitQASM2Frame, stmt: stmts.GateFunction
2020
):
2121

22-
cparams, qparams = [], []
22+
args, cparams, qparams = [], [], []
2323
for arg in stmt.body.blocks[0].args[1:]:
24-
name = frame.get(arg)
24+
name = frame.get_typed(arg, ast.Name)
25+
args.append(name)
2526
if not isinstance(name, ast.Name):
2627
raise EmitError("expected ast.Name")
2728
if arg.type.is_subseteq(QubitType):
2829
qparams.append(name.id)
2930
else:
3031
cparams.append(name.id)
31-
emit.run_ssacfg_region(frame, stmt.body)
32+
emit.run_ssacfg_region(frame, stmt.body, args)
3233
emit.output = ast.Gate(
3334
name=stmt.sym_name,
3435
cparams=cparams,

src/bloqade/qasm2/emit/base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,10 @@ def initialize(self) -> Self:
3636
)
3737
return self
3838

39-
def new_frame(self, code: ir.Statement) -> EmitQASM2Frame:
40-
return EmitQASM2Frame.from_func_like(code)
39+
def initialize_frame(
40+
self, code: ir.Statement, *, has_parent_access: bool = False
41+
) -> EmitQASM2Frame[StmtType]:
42+
return EmitQASM2Frame(code, has_parent_access=has_parent_access)
4143

4244
def run_method(
4345
self, method: ir.Method, args: tuple[ast.Node | None, ...]

src/bloqade/qasm2/emit/target.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,7 @@ def emit(self, entry: ir.Method) -> ast.MainProgram:
133133

134134
Py2QASM(fn.dialects)(fn)
135135

136-
target_gate.run(
137-
fn, tuple(ast.Name(name) for name in fn.arg_names[1:])
138-
).expect()
136+
target_gate.run(fn, tuple(ast.Name(name) for name in fn.arg_names[1:]))
139137
assert target_gate.output is not None, f"failed to emit {fn.sym_name}"
140138
extra.append(target_gate.output)
141139

src/bloqade/qasm2/passes/glob.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
"""
55

66
from kirin import ir
7-
from kirin.rewrite import cse, dce, walk, result
7+
from kirin.rewrite import abc, cse, dce, walk
88
from kirin.passes.abc import Pass
99
from kirin.passes.fold import Fold
1010
from kirin.rewrite.fixpoint import Fixpoint
@@ -54,7 +54,7 @@ def generate_rule(self, mt: ir.Method) -> GlobalToUOpRule:
5454
frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
5555
return GlobalToUOpRule(frame.entries)
5656

57-
def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
57+
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
5858
rewriter = walk.Walk(self.generate_rule(mt))
5959
result = rewriter.rewrite(mt.code)
6060

@@ -106,7 +106,7 @@ def generate_rule(self, mt: ir.Method) -> GlobalToParallelRule:
106106
frame, _ = address.AddressAnalysis(mt.dialects).run_analysis(mt)
107107
return GlobalToParallelRule(frame.entries)
108108

109-
def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
109+
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
110110
rewriter = walk.Walk(self.generate_rule(mt))
111111
result = rewriter.rewrite(mt.code)
112112

src/bloqade/qasm2/passes/noise.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
DeadCodeElimination,
1111
CommonSubexpressionElimination,
1212
)
13-
from kirin.rewrite.result import RewriteResult
13+
from kirin.rewrite.abc import RewriteResult
1414

1515
from bloqade.noise import native
1616
from bloqade.analysis import address

src/bloqade/qasm2/passes/parallel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
ConstantFold,
1717
DeadCodeElimination,
1818
CommonSubexpressionElimination,
19-
result,
19+
abc,
2020
)
2121
from kirin.analysis import const
2222

@@ -84,7 +84,7 @@ def generate_rule(self, mt: ir.Method) -> ParallelToUOpRule:
8484

8585
return ParallelToUOpRule(id_map=id_map, address_analysis=frame.entries)
8686

87-
def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
87+
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
8888
result = Walk(self.generate_rule(mt)).rewrite(mt.code)
8989
rule = Chain(
9090
ConstantFold(),
@@ -140,7 +140,7 @@ def test():
140140
def __post_init__(self):
141141
self.constprop = const.Propagate(self.dialects)
142142

143-
def unsafe_run(self, mt: ir.Method) -> result.RewriteResult:
143+
def unsafe_run(self, mt: ir.Method) -> abc.RewriteResult:
144144
result = Walk(RaiseRegisterRule()).rewrite(mt.code)
145145

146146
# do not run the parallelization because registers are not at the top

src/bloqade/qasm2/passes/py2qasm.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,7 @@
44
from kirin.passes import Pass
55
from kirin.rewrite import Walk, Fixpoint
66
from kirin.dialects import py, math
7-
from kirin.rewrite.abc import RewriteRule
8-
from kirin.rewrite.result import RewriteResult
7+
from kirin.rewrite.abc import RewriteRule, RewriteResult
98

109
from bloqade.qasm2.dialects import core, expr
1110

src/bloqade/qasm2/passes/qasm2py.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@
66
from kirin.passes import Pass
77
from kirin.rewrite import Walk, Fixpoint
88
from kirin.dialects import py, math
9-
from kirin.rewrite.abc import RewriteRule
10-
from kirin.rewrite.result import RewriteResult
9+
from kirin.rewrite.abc import RewriteRule, RewriteResult
1110

1211
from bloqade.qasm2.dialects import core, expr
1312

0 commit comments

Comments
 (0)