Skip to content

Commit ca2d8c7

Browse files
weinbe58david-pl
authored andcommitted
Refactor QASM2Fold and CanonicalizeIList (#536)
Just doing some small housekeeping using the new passes/rewrite rules added in Kirin in the past few months.
1 parent 5badff3 commit ca2d8c7

File tree

7 files changed

+153
-169
lines changed

7 files changed

+153
-169
lines changed

src/bloqade/qasm2/passes/fold.py

Lines changed: 14 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,12 @@
11
from dataclasses import field, dataclass
22

33
from kirin import ir
4-
from kirin.passes import Pass, TypeInfer
5-
from kirin.rewrite import (
6-
Walk,
7-
Chain,
8-
Inline,
9-
Fixpoint,
10-
WrapConst,
11-
Call2Invoke,
12-
ConstantFold,
13-
CFGCompactify,
14-
InlineGetItem,
15-
InlineGetField,
16-
DeadCodeElimination,
17-
CommonSubexpressionElimination,
18-
)
19-
from kirin.analysis import const
20-
from kirin.dialects import scf, ilist
4+
from kirin.passes import Pass
215
from kirin.ir.method import Method
226
from kirin.rewrite.abc import RewriteResult
237

248
from bloqade.qasm2.dialects import expr
9+
from bloqade.rewrite.passes import AggressiveUnroll
2510

2611
from .unroll_if import UnrollIfs
2712

@@ -30,71 +15,27 @@
3015
class QASM2Fold(Pass):
3116
"""Fold pass for qasm2.extended"""
3217

33-
constprop: const.Propagate = field(init=False)
3418
inline_gate_subroutine: bool = True
3519
unroll_ifs: bool = True
20+
aggressive_unroll: AggressiveUnroll = field(init=False)
3621

3722
def __post_init__(self):
38-
self.constprop = const.Propagate(self.dialects)
39-
self.typeinfer = TypeInfer(self.dialects)
23+
def inline_simple(node: ir.Statement):
24+
if isinstance(node, expr.GateFunction):
25+
return self.inline_gate_subroutine
4026

41-
def unsafe_run(self, mt: Method) -> RewriteResult:
42-
result = RewriteResult()
43-
frame, _ = self.constprop.run_analysis(mt)
44-
result = Walk(WrapConst(frame)).rewrite(mt.code).join(result)
45-
rule = Chain(
46-
ConstantFold(),
47-
Call2Invoke(),
48-
InlineGetField(),
49-
InlineGetItem(),
50-
DeadCodeElimination(),
51-
CommonSubexpressionElimination(),
52-
)
53-
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
27+
return True
5428

55-
result = (
56-
Walk(
57-
Chain(
58-
scf.unroll.PickIfElse(),
59-
scf.unroll.ForLoop(),
60-
scf.trim.UnusedYield(),
61-
)
62-
)
63-
.rewrite(mt.code)
64-
.join(result)
29+
self.aggressive_unroll = AggressiveUnroll(
30+
self.dialects, inline_simple, no_raise=self.no_raise
6531
)
6632

67-
if self.unroll_ifs:
68-
UnrollIfs(mt.dialects).unsafe_run(mt).join(result)
69-
70-
# run typeinfer again after unroll etc. because we now insert
71-
# a lot of new nodes, which might have more precise types
72-
self.typeinfer.unsafe_run(mt)
73-
result = (
74-
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
75-
.rewrite(mt.code)
76-
.join(result)
77-
)
78-
79-
def inline_simple(node: ir.Statement):
80-
if isinstance(node, expr.GateFunction):
81-
return self.inline_gate_subroutine
33+
def unsafe_run(self, mt: Method) -> RewriteResult:
34+
result = RewriteResult()
8235

83-
if not isinstance(node.parent_stmt, (scf.For, scf.IfElse)):
84-
return True # always inline calls outside of loops and if-else
36+
if self.unroll_ifs:
37+
result = UnrollIfs(mt.dialects).unsafe_run(mt).join(result)
8538

86-
# inside loops and if-else, only inline simple functions, i.e. functions with a single block
87-
if (trait := node.get_trait(ir.CallableStmtInterface)) is None:
88-
return False # not a callable, don't inline to be safe
89-
region = trait.get_callable_region(node)
90-
return len(region.blocks) == 1
39+
result = self.aggressive_unroll.unsafe_run(mt).join(result)
9140

92-
result = (
93-
Walk(
94-
Inline(inline_simple),
95-
)
96-
.rewrite(mt.code)
97-
.join(result)
98-
)
99-
result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
10041
return result
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
1+
from .aggressive_unroll import AggressiveUnroll as AggressiveUnroll
12
from .canonicalize_ilist import CanonicalizeIList as CanonicalizeIList
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from typing import Callable
2+
from dataclasses import field, dataclass
3+
4+
from kirin import ir
5+
from kirin.passes import Pass, HintConst, TypeInfer
6+
from kirin.rewrite import (
7+
Walk,
8+
Chain,
9+
Inline,
10+
Fixpoint,
11+
Call2Invoke,
12+
ConstantFold,
13+
CFGCompactify,
14+
InlineGetItem,
15+
InlineGetField,
16+
DeadCodeElimination,
17+
)
18+
from kirin.dialects import scf, ilist
19+
from kirin.ir.method import Method
20+
from kirin.rewrite.abc import RewriteResult
21+
from kirin.rewrite.cse import CommonSubexpressionElimination
22+
from kirin.passes.aggressive import UnrollScf
23+
24+
25+
@dataclass
26+
class Fold(Pass):
27+
hint_const: HintConst = field(init=False)
28+
29+
def __post_init__(self):
30+
self.hint_const = HintConst(self.dialects, no_raise=self.no_raise)
31+
32+
def unsafe_run(self, mt: Method) -> RewriteResult:
33+
result = RewriteResult()
34+
result = self.hint_const.unsafe_run(mt).join(result)
35+
rule = Chain(
36+
ConstantFold(),
37+
Call2Invoke(),
38+
InlineGetField(),
39+
InlineGetItem(),
40+
ilist.rewrite.InlineGetItem(),
41+
ilist.rewrite.HintLen(),
42+
)
43+
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
44+
45+
return result
46+
47+
48+
@dataclass
49+
class AggressiveUnroll(Pass):
50+
"""A pass to unroll structured control flow"""
51+
52+
additional_inline_heuristic: Callable[[ir.Statement], bool] = lambda node: True
53+
54+
fold: Fold = field(init=False)
55+
typeinfer: TypeInfer = field(init=False)
56+
scf_unroll: UnrollScf = field(init=False)
57+
58+
def __post_init__(self):
59+
self.fold = Fold(self.dialects, no_raise=self.no_raise)
60+
self.typeinfer = TypeInfer(self.dialects, no_raise=self.no_raise)
61+
self.scf_unroll = UnrollScf(self.dialects, no_raise=self.no_raise)
62+
63+
def unsafe_run(self, mt: Method) -> RewriteResult:
64+
result = RewriteResult()
65+
result = self.scf_unroll.unsafe_run(mt).join(result)
66+
result = (
67+
Walk(Chain(ilist.rewrite.ConstList2IList(), ilist.rewrite.Unroll()))
68+
.rewrite(mt.code)
69+
.join(result)
70+
)
71+
result = self.typeinfer.unsafe_run(mt).join(result)
72+
result = self.fold.unsafe_run(mt).join(result)
73+
result = Walk(Inline(self.inline_heuristic)).rewrite(mt.code).join(result)
74+
result = Walk(Fixpoint(CFGCompactify())).rewrite(mt.code).join(result)
75+
76+
rule = Chain(
77+
CommonSubexpressionElimination(),
78+
DeadCodeElimination(),
79+
)
80+
result = Fixpoint(Walk(rule)).rewrite(mt.code).join(result)
81+
82+
return result
83+
84+
def inline_heuristic(self, node: ir.Statement) -> bool:
85+
"""The heuristic to decide whether to inline a function call or not.
86+
inside loops and if-else, only inline simple functions, i.e.
87+
functions with a single block
88+
"""
89+
return not isinstance(
90+
node.parent_stmt, (scf.For, scf.IfElse)
91+
) and self.additional_inline_heuristic(
92+
node
93+
) # always inline calls outside of loops and if-else
Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,32 @@
1-
from dataclasses import dataclass
1+
from dataclasses import field, dataclass
22

3-
from kirin import ir
4-
from kirin.passes import Pass
3+
from kirin import ir, passes
54
from kirin.rewrite import (
65
Walk,
76
Chain,
87
Fixpoint,
98
)
10-
from kirin.analysis import const
11-
12-
from ..rules.flatten_ilist import FlattenAddOpIList
13-
from ..rules.inline_getitem_ilist import InlineGetItemFromIList
9+
from kirin.dialects.ilist import rewrite
1410

1511

1612
@dataclass
17-
class CanonicalizeIList(Pass):
13+
class CanonicalizeIList(passes.Pass):
1814

19-
def unsafe_run(self, mt: ir.Method):
15+
fold_pass: passes.Fold = field(init=False)
2016

21-
cp_result_frame, _ = const.Propagate(dialects=mt.dialects).run_analysis(mt)
17+
def __post_init__(self):
18+
self.fold_pass = passes.Fold(self.dialects, no_raise=self.no_raise)
2219

23-
return Fixpoint(
24-
Chain(
25-
Walk(InlineGetItemFromIList(constprop_result=cp_result_frame.entries)),
26-
Walk(FlattenAddOpIList()),
20+
def unsafe_run(self, mt: ir.Method):
21+
result = Fixpoint(
22+
Walk(
23+
Chain(
24+
rewrite.InlineGetItem(),
25+
rewrite.FlattenAdd(),
26+
rewrite.HintLen(),
27+
)
2728
)
2829
).rewrite(mt.code)
30+
31+
result = self.fold_pass(mt).join(result)
32+
return result

src/bloqade/rewrite/rules/flatten_ilist.py

Lines changed: 0 additions & 51 deletions
This file was deleted.

src/bloqade/rewrite/rules/inline_getitem_ilist.py

Lines changed: 0 additions & 31 deletions
This file was deleted.

test/qasm2/emit/test_qasm2_emit.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,3 +288,30 @@ def nested_kernel():
288288

289289
target = qasm2.emit.QASM2()
290290
target.emit(nested_kernel)
291+
292+
293+
def test_loop_unroll():
294+
n_qubits = 4
295+
296+
@qasm2.extended
297+
def ghz_linear():
298+
q = qasm2.qreg(n_qubits)
299+
qasm2.h(q[0])
300+
for i in range(1, n_qubits):
301+
qasm2.cx(q[i - 1], q[i])
302+
303+
target = qasm2.emit.QASM2(
304+
allow_parallel=True,
305+
)
306+
qasm2_str = target.emit_str(ghz_linear)
307+
308+
assert qasm2_str == (
309+
"""KIRIN {func,lowering.call,lowering.func,py.ilist,qasm2.core,qasm2.expr,qasm2.indexing,qasm2.noise,qasm2.parallel,qasm2.uop,scf};
310+
include "qelib1.inc";
311+
qreg q[4];
312+
h q[0];
313+
CX q[0], q[1];
314+
CX q[1], q[2];
315+
CX q[2], q[3];
316+
"""
317+
)

0 commit comments

Comments
 (0)