Skip to content

Commit 88db9e9

Browse files
committed
Merge origin/david/571-kirin-upgrade-branch into dl/codegen
2 parents f5d3ba1 + dc9a5bb commit 88db9e9

File tree

18 files changed

+387
-408
lines changed

18 files changed

+387
-408
lines changed

.github/workflows/ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ jobs:
4343
token: ${{ secrets.CODECOV_TOKEN }} # required
4444
- name: Archive code coverage results
4545
if: matrix.python-version == '3.12'
46-
uses: actions/upload-artifact@v4
46+
uses: actions/upload-artifact@v5
4747
with:
4848
name: code-coverage-report
4949
path: coverage.xml
@@ -55,7 +55,7 @@ jobs:
5555
steps:
5656
- uses: actions/checkout@v5
5757
- name: download covearge
58-
uses: actions/download-artifact@v5
58+
uses: actions/download-artifact@v6
5959
with:
6060
name: code-coverage-report
6161
- name: check coverage

.github/workflows/release.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ jobs:
1919
- name: Build distribution 📦
2020
run: uv build
2121
- name: Store the distribution packages
22-
uses: actions/upload-artifact@v4
22+
uses: actions/upload-artifact@v5
2323
with:
2424
name: python-package-distributions
2525
path: dist/
@@ -39,7 +39,7 @@ jobs:
3939

4040
steps:
4141
- name: Download all the dists
42-
uses: actions/download-artifact@v5
42+
uses: actions/download-artifact@v6
4343
with:
4444
name: python-package-distributions
4545
path: dist/
@@ -60,12 +60,12 @@ jobs:
6060

6161
steps:
6262
- name: Download all the dists
63-
uses: actions/download-artifact@v5
63+
uses: actions/download-artifact@v6
6464
with:
6565
name: python-package-distributions
6666
path: dist/
6767
- name: Sign the dists with Sigstore
68-
uses: sigstore/gh-action-sigstore-python@v3.0.1
68+
uses: sigstore/gh-action-sigstore-python@v3.1.0
6969
with:
7070
inputs: >-
7171
./dist/*.tar.gz

src/bloqade/cirq_utils/emit/base.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,9 +172,6 @@ class EmitCirq(EmitABC[EmitCirqFrame, cirq.Circuit]):
172172
dialects: ir.DialectGroup = field(default_factory=_default_kernel)
173173
void = cirq.Circuit()
174174
qubits: Sequence[cirq.Qid] | None = None
175-
_cached_invokes: dict[int, cirq.FrozenCircuit] = field(
176-
init=False, default_factory=dict
177-
)
178175

179176
def initialize(self) -> Self:
180177
return super().initialize()
Lines changed: 4 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from itertools import chain
2-
from dataclasses import field, dataclass
32

4-
from kirin import ir, passes, rewrite
3+
from kirin import ir, rewrite
54
from kirin.dialects import py, func
65
from kirin.rewrite.abc import RewriteRule, RewriteResult
7-
from kirin.passes.callgraph import CallGraphPass, ReplaceMethods
86
from kirin.analysis.callgraph import CallGraph
97

108
from bloqade.native import kernel, broadcast
119
from bloqade.squin.gate import stmts, dialect as gate_dialect
10+
from bloqade.rewrite.passes import CallGraphPass, UpdateDialectsOnCallGraph
1211

1312

1413
class GateRule(RewriteRule):
@@ -46,63 +45,6 @@ def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
4645
return RewriteResult(has_done_something=True)
4746

4847

49-
@dataclass
50-
class UpdateDialectsOnCallGraph(passes.Pass):
51-
"""Update All dialects on the call graph to a new set of dialects given to this pass.
52-
53-
Usage:
54-
pass_ = UpdateDialectsOnCallGraph(rule=rule, dialects=new_dialects)
55-
pass_(some_method)
56-
57-
Note: This pass does not update the dialects of the input method, but copies
58-
all other methods invoked within it before updating their dialects.
59-
60-
"""
61-
62-
fold_pass: passes.Fold = field(init=False)
63-
64-
def __post_init__(self):
65-
self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)
66-
67-
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
68-
mt_map = {}
69-
70-
cg = CallGraph(mt)
71-
72-
all_methods = set(sum(map(tuple, cg.defs.values()), ()))
73-
for original_mt in all_methods:
74-
if original_mt is mt:
75-
new_mt = original_mt
76-
else:
77-
new_mt = original_mt.similar(self.dialects)
78-
mt_map[original_mt] = new_mt
79-
80-
result = RewriteResult()
81-
82-
for _, new_mt in mt_map.items():
83-
result = (
84-
rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code).join(result)
85-
)
86-
self.fold_pass(new_mt)
87-
88-
return result
89-
90-
91-
@dataclass
92-
class SquinToNativePass(passes.Pass):
93-
94-
call_graph_pass: CallGraphPass = field(init=False)
95-
96-
def __post_init__(self):
97-
rule = rewrite.Walk(GateRule())
98-
self.call_graph_pass = CallGraphPass(
99-
self.dialects, rule, no_raise=self.no_raise
100-
)
101-
102-
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
103-
return self.call_graph_pass.unsafe_run(mt)
104-
105-
10648
class SquinToNative:
10749
"""A Target that converts Squin gates to native gates."""
10850

@@ -126,11 +68,10 @@ def emit(self, mt: ir.Method, *, no_raise=True) -> ir.Method:
12668

12769
out = mt.similar(new_dialects)
12870
UpdateDialectsOnCallGraph(new_dialects, no_raise=no_raise)(out)
129-
SquinToNativePass(new_dialects, no_raise=no_raise)(out)
71+
CallGraphPass(new_dialects, rewrite.Walk(GateRule()), no_raise=no_raise)(out)
13072
# verify all kernels in the callgraph
13173
new_callgraph = CallGraph(out)
132-
all_kernels = (ker for kers in new_callgraph.defs.values() for ker in kers)
133-
for ker in all_kernels:
74+
for ker in new_callgraph.edges.keys():
13475
ker.verify()
13576

13677
return out
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,7 @@
1+
from .callgraph import (
2+
CallGraphPass as CallGraphPass,
3+
ReplaceMethods as ReplaceMethods,
4+
UpdateDialectsOnCallGraph as UpdateDialectsOnCallGraph,
5+
)
16
from .aggressive_unroll import AggressiveUnroll as AggressiveUnroll
27
from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
from dataclasses import field, dataclass
2+
3+
from kirin import ir, passes, rewrite
4+
from kirin.analysis import CallGraph
5+
from kirin.rewrite.abc import RewriteRule, RewriteResult
6+
from kirin.dialects.func.stmts import Invoke
7+
8+
9+
@dataclass
10+
class ReplaceMethods(RewriteRule):
11+
new_symbols: dict[ir.Method, ir.Method]
12+
13+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
14+
if (
15+
not isinstance(node, Invoke)
16+
or (new_callee := self.new_symbols.get(node.callee)) is None
17+
):
18+
return RewriteResult()
19+
20+
node.replace_by(
21+
Invoke(
22+
inputs=node.inputs,
23+
callee=new_callee,
24+
purity=node.purity,
25+
)
26+
)
27+
28+
return RewriteResult(has_done_something=True)
29+
30+
31+
@dataclass
32+
class UpdateDialectsOnCallGraph(passes.Pass):
33+
"""Update All dialects on the call graph to a new set of dialects given to this pass.
34+
35+
Usage:
36+
pass_ = UpdateDialectsOnCallGraph(rule=rule, dialects=new_dialects)
37+
pass_(some_method)
38+
39+
Note: This pass does not update the dialects of the input method, but copies
40+
all other methods invoked within it before updating their dialects.
41+
42+
"""
43+
44+
fold_pass: passes.Fold = field(init=False)
45+
46+
def __post_init__(self):
47+
self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)
48+
49+
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
50+
mt_map = {}
51+
52+
cg = CallGraph(mt)
53+
54+
all_methods = set(sum(map(tuple, cg.defs.values()), ()))
55+
for original_mt in all_methods:
56+
if original_mt is mt:
57+
new_mt = original_mt
58+
else:
59+
new_mt = original_mt.similar(self.dialects)
60+
mt_map[original_mt] = new_mt
61+
62+
result = RewriteResult()
63+
64+
for _, new_mt in mt_map.items():
65+
result = (
66+
rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code).join(result)
67+
)
68+
self.fold_pass(new_mt)
69+
70+
return result
71+
72+
73+
@dataclass
74+
class CallGraphPass(passes.Pass):
75+
"""Copy all functions in the call graph and apply a rule to each of them.
76+
77+
78+
Usage:
79+
rule = Walk(SomeRewriteRule())
80+
pass_ = CallGraphPass(rule=rule, dialects=...)
81+
pass_(some_method)
82+
83+
Note: This pass modifies the input method in place, but copies
84+
all methods invoked within it before applying the rule to them.
85+
86+
"""
87+
88+
rule: RewriteRule
89+
"""The rule to apply to each function in the call graph."""
90+
91+
fold_pass: passes.Fold = field(init=False)
92+
93+
def __post_init__(self):
94+
self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)
95+
96+
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
97+
result = RewriteResult()
98+
mt_map = {}
99+
100+
cg = CallGraph(mt)
101+
102+
all_methods = set(cg.edges.keys())
103+
for original_mt in all_methods:
104+
if original_mt is mt:
105+
new_mt = original_mt
106+
else:
107+
new_mt = original_mt.similar()
108+
result = self.rule.rewrite(new_mt.code).join(result)
109+
mt_map[original_mt] = new_mt
110+
111+
if result.has_done_something:
112+
for _, new_mt in mt_map.items():
113+
rewrite.Walk(ReplaceMethods(mt_map)).rewrite(new_mt.code)
114+
self.fold_pass(new_mt)
115+
116+
return result

src/bloqade/squin/rewrite/U3_to_clifford.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,14 @@ class Sdag(ir.Statement):
1515
pass
1616

1717

18+
class SqrtXdag(ir.Statement):
19+
pass
20+
21+
22+
class SqrtYdag(ir.Statement):
23+
pass
24+
25+
1826
# (theta, phi, lam)
1927
U3_HALF_PI_ANGLE_TO_GATES: dict[
2028
tuple[int, int, int], list[type[ir.Statement]] | list[None]
@@ -27,21 +35,21 @@ class Sdag(ir.Statement):
2735
(1, 0, 1): [gate.stmts.S, gate.stmts.SqrtY],
2836
(1, 0, 2): [gate.stmts.H],
2937
(1, 0, 3): [Sdag, gate.stmts.SqrtY],
30-
(1, 1, 0): [gate.stmts.SqrtY, gate.stmts.S],
31-
(1, 1, 1): [gate.stmts.S, gate.stmts.SqrtY, gate.stmts.S],
32-
(1, 1, 2): [gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.S],
33-
(1, 1, 3): [Sdag, gate.stmts.SqrtY, gate.stmts.S],
34-
(1, 2, 0): [gate.stmts.SqrtY, gate.stmts.Z],
35-
(1, 2, 1): [gate.stmts.S, gate.stmts.SqrtY, gate.stmts.Z],
36-
(1, 2, 2): [gate.stmts.Z, gate.stmts.SqrtY, gate.stmts.Z],
37-
(1, 2, 3): [Sdag, gate.stmts.SqrtY, gate.stmts.Z],
38-
(1, 3, 0): [gate.stmts.SqrtY, Sdag],
39-
(1, 3, 1): [gate.stmts.S, gate.stmts.SqrtY, Sdag],
40-
(1, 3, 2): [gate.stmts.Z, gate.stmts.SqrtY, Sdag],
41-
(1, 3, 3): [Sdag, gate.stmts.SqrtY, Sdag],
38+
(1, 1, 0): [gate.stmts.S, SqrtXdag],
39+
(1, 1, 1): [gate.stmts.Z, SqrtXdag],
40+
(1, 1, 2): [Sdag, SqrtXdag],
41+
(1, 1, 3): [SqrtXdag],
42+
(1, 2, 0): [gate.stmts.Z, SqrtYdag],
43+
(1, 2, 1): [Sdag, SqrtYdag],
44+
(1, 2, 2): [SqrtYdag],
45+
(1, 2, 3): [gate.stmts.S, SqrtYdag],
46+
(1, 3, 0): [Sdag, gate.stmts.SqrtX],
47+
(1, 3, 1): [gate.stmts.SqrtX],
48+
(1, 3, 2): [gate.stmts.S, gate.stmts.SqrtX],
49+
(1, 3, 3): [gate.stmts.Z, gate.stmts.SqrtX],
4250
(2, 0, 0): [gate.stmts.Y],
4351
(2, 0, 1): [gate.stmts.S, gate.stmts.Y],
44-
(2, 0, 2): [gate.stmts.Z, gate.stmts.Y],
52+
(2, 0, 2): [gate.stmts.X],
4553
(2, 0, 3): [Sdag, gate.stmts.Y],
4654
}
4755

@@ -106,6 +114,10 @@ def rewrite_U3(self, node: gate.stmts.U3) -> RewriteResult:
106114
for gate_stmt in gates:
107115
if gate_stmt is Sdag:
108116
new_stmt = gate.stmts.S(adjoint=True, qubits=node.qubits)
117+
elif gate_stmt is SqrtXdag:
118+
new_stmt = gate.stmts.SqrtX(adjoint=True, qubits=node.qubits)
119+
elif gate_stmt is SqrtYdag:
120+
new_stmt = gate.stmts.SqrtY(adjoint=True, qubits=node.qubits)
109121
else:
110122
new_stmt = gate_stmt(qubits=node.qubits)
111123
new_stmt.insert_before(node)

src/bloqade/stim/emit/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
from . import impls as impls
12
from .stim_str import FuncEmit as FuncEmit, EmitStimMain as EmitStimMain

src/bloqade/stim/emit/impls.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from kirin.emit import EmitStrFrame
2+
from kirin.interp import MethodTable, impl
3+
from kirin.dialects.debug import Info, dialect
4+
5+
from bloqade.stim.emit.stim_str import EmitStimMain
6+
7+
8+
@dialect.register(key="emit.stim")
9+
class EmitStimDebugMethods(MethodTable):
10+
11+
@impl(Info)
12+
def info(self, emit: EmitStimMain, frame: EmitStrFrame, stmt: Info):
13+
14+
msg: str = frame.get(stmt.msg)
15+
emit.writeln(frame, f"# {msg}")
16+
17+
return ()

test/cirq_utils/test_cirq_to_squin.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -402,15 +402,6 @@ def multi_arg(n: int, p: float):
402402
print(circuit)
403403

404404

405-
if __name__ == "__main__":
406-
test_kernel_with_args()
407-
408-
409-
@pytest.mark.xfail
410-
def test_amplitude_damping():
411-
test_circuit(amplitude_damping)
412-
413-
414405
def test_trotter():
415406

416407
# NOTE: stolen from jonathan's tutorial

0 commit comments

Comments
 (0)