Skip to content

Commit cd53e5d

Browse files
committed
use callgraph pass instead of nesting rewrite rule applications
1 parent 6b9e52f commit cd53e5d

File tree

4 files changed

+61
-39
lines changed

4 files changed

+61
-39
lines changed
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from kirin import ir, passes
2+
from kirin.rewrite import Walk, Chain
3+
from kirin.dialects import func
4+
from kirin.rewrite.abc import RewriteRule, RewriteResult
5+
6+
from bloqade.rewrite.passes import CallGraphPass
7+
8+
from ..rewrite import qasm2 as qasm2_rule
9+
10+
11+
class QASM2GateFuncToKirinFunc(RewriteRule):
12+
13+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
14+
from bloqade.qasm2.dialects.expr.stmts import GateFunction
15+
16+
if not isinstance(node, GateFunction):
17+
return RewriteResult()
18+
19+
kirin_func = func.Function(
20+
sym_name=node.sym_name,
21+
signature=node.signature,
22+
body=node.body,
23+
)
24+
node.replace_by(kirin_func)
25+
26+
return RewriteResult(has_done_something=True)
27+
28+
29+
class QASM2GateFuncToSquinPass(passes.Pass):
30+
31+
def unsafe_run(self, mt: ir.Method) -> RewriteResult:
32+
convert_to_kirin_func = CallGraphPass(
33+
dialects=mt.dialects, rule=Walk(QASM2GateFuncToKirinFunc())
34+
)
35+
rewrite_result = convert_to_kirin_func(mt)
36+
37+
combined_qasm2_rules = Walk(
38+
Chain(
39+
qasm2_rule.QASM2ExprToSquin(),
40+
qasm2_rule.QASM2CoreToSquin(),
41+
qasm2_rule.QASM2UOPToSquin(),
42+
qasm2_rule.QASM2NoiseToSquin(),
43+
qasm2_rule.QASM2GlobParallelToSquin(),
44+
)
45+
)
46+
47+
body_conversion_pass = CallGraphPass(
48+
dialects=mt.dialects, rule=combined_qasm2_rules
49+
)
50+
rewrite_result = body_conversion_pass(mt).join(rewrite_result)
51+
52+
return rewrite_result

src/bloqade/squin/passes/qasm2_to_squin.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
QASM2UOPToSquin,
1212
QASM2CoreToSquin,
1313
QASM2ExprToSquin,
14-
QASM2FuncToSquin,
1514
QASM2NoiseToSquin,
1615
QASM2GlobParallelToSquin,
1716
)
1817

18+
from .qasm2_gate_func_to_squin import QASM2GateFuncToSquinPass
19+
1920

2021
@dataclass
2122
class QASM2ToSquin(Pass):
@@ -25,7 +26,6 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult:
2526
# rewrite all QASM2 to squin first
2627
rewrite_result = Walk(
2728
Chain(
28-
QASM2FuncToSquin(),
2929
QASM2ExprToSquin(),
3030
QASM2CoreToSquin(),
3131
QASM2UOPToSquin(),
@@ -34,6 +34,13 @@ def unsafe_run(self, mt: ir.Method) -> RewriteResult:
3434
)
3535
).rewrite(mt.code)
3636

37+
# go into subkernels
38+
rewrite_result = (
39+
QASM2GateFuncToSquinPass(dialects=mt.dialects)
40+
.unsafe_run(mt)
41+
.join(rewrite_result)
42+
)
43+
3744
# kernel should be entirely in squin dialect now
3845
mt.dialects = squin.kernel
3946

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
from .uop_to_squin import QASM2UOPToSquin as QASM2UOPToSquin
22
from .core_to_squin import QASM2CoreToSquin as QASM2CoreToSquin
33
from .expr_to_squin import QASM2ExprToSquin as QASM2ExprToSquin
4-
from .func_to_squin import QASM2FuncToSquin as QASM2FuncToSquin
54
from .noise_to_squin import QASM2NoiseToSquin as QASM2NoiseToSquin
65
from .glob_parallel_to_squin import QASM2GlobParallelToSquin as QASM2GlobParallelToSquin

src/bloqade/squin/rewrite/qasm2/func_to_squin.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

0 commit comments

Comments
 (0)