Skip to content

Commit 49a3159

Browse files
weinbe58Roger-luo
andauthored
Inline len for IList (#316)
Some kinds of `IList` can't be folded but the length can be determined at compile time. This PR adds a rewrite that inlines the `len` if the length can be determined by type inference. This is preventing some kinds of bloqade programs from being folded to QASM2. The list of qubits can't be folded to a constant value because there is no concrete interpreter for `QRegGet`. ```python @qasm2.extended def _log_ghz(i_layer: int, qubits: ilist.IList[qasm2.Qubit, Any]): stride = len(qubits) // (2**i_layer) if stride == 0: return offset = stride // 2 for j in ilist.range(0, len(qubits), stride): qasm2.cx(ctrl=qubits[j], qarg=qubits[j + offset]) _log_ghz(i_layer + 1, qubits) @qasm2.extended def main(): q = qasm2.qreg(6) _log_ghz(0,[q[0], q[2], q[4], q[1], q[3], q[5]]) fold = QASM2Fold(qasm2.extended) fold(main) main.print() fold(main) ``` The second fold fails because of this. --------- Co-authored-by: Xiu-zhe (Roger) Luo <[email protected]>
1 parent 0b79d67 commit 49a3159

File tree

11 files changed

+201
-47
lines changed

11 files changed

+201
-47
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .list import List2IList as List2IList
22
from .const import ConstList2IList as ConstList2IList
33
from .unroll import Unroll as Unroll
4+
from .hint_len import HintLen as HintLen
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from kirin import ir, types
2+
from kirin.analysis import const
3+
from kirin.dialects import py
4+
from kirin.rewrite.abc import RewriteRule, RewriteResult
5+
from kirin.dialects.ilist.stmts import IListType
6+
7+
8+
class HintLen(RewriteRule):
9+
10+
def _get_collection_len(self, collection: ir.SSAValue):
11+
coll_type = collection.type
12+
13+
if not isinstance(coll_type, types.Generic):
14+
return None
15+
16+
if (
17+
coll_type.is_subseteq(IListType)
18+
and isinstance(coll_type.vars[1], types.Literal)
19+
and isinstance(coll_type.vars[1].data, int)
20+
):
21+
return coll_type.vars[1].data
22+
else:
23+
return None
24+
25+
def rewrite_Statement(self, node: ir.Statement) -> RewriteResult:
26+
if not isinstance(node, py.Len):
27+
return RewriteResult()
28+
29+
if (coll_len := self._get_collection_len(node.value)) is None:
30+
return RewriteResult()
31+
32+
node.result.hints["const"] = const.Value(coll_len)
33+
34+
return RewriteResult(has_done_something=True)

src/kirin/passes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
from kirin.passes.abc import Pass as Pass
22
from kirin.passes.fold import Fold as Fold
33
from kirin.passes.typeinfer import TypeInfer as TypeInfer
4+
5+
from .default import Default as Default
6+
from .hint_const import HintConst as HintConst

src/kirin/passes/abc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class Pass(ABC):
2424

2525
name: ClassVar[str]
2626
dialects: DialectGroup
27-
no_raise: bool = field(default=False, kw_only=True)
27+
no_raise: bool = field(default=True, kw_only=True)
2828

2929
def __call__(self, mt: Method) -> RewriteResult:
3030
result = self.unsafe_run(mt)
@@ -38,7 +38,7 @@ def fixpoint(self, mt: Method, max_iter: int = 32) -> RewriteResult:
3838
result = result_.join(result)
3939
if not result.has_done_something:
4040
break
41-
mt.code.verify()
41+
mt.verify()
4242
return result
4343

4444
@abstractmethod

src/kirin/passes/aggressive/fold.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,28 @@
66
Chain,
77
Inline,
88
Fixpoint,
9-
WrapConst,
109
Call2Invoke,
1110
ConstantFold,
1211
CFGCompactify,
1312
InlineGetItem,
1413
InlineGetField,
1514
DeadCodeElimination,
1615
)
17-
from kirin.analysis import const
1816
from kirin.ir.method import Method
1917
from kirin.rewrite.abc import RewriteResult
18+
from kirin.passes.hint_const import HintConst
2019

2120

2221
@dataclass
2322
class Fold(Pass):
24-
constprop: const.Propagate = field(init=False)
23+
hint_const: HintConst = field(init=False)
2524

2625
def __post_init__(self):
27-
self.constprop = const.Propagate(self.dialects)
26+
self.hint_const = HintConst(self.dialects)
27+
self.hint_const.no_raise = self.no_raise
2828

2929
def unsafe_run(self, mt: Method) -> RewriteResult:
30-
result = RewriteResult()
31-
frame, _ = self.constprop.run_analysis(mt, no_raise=self.no_raise)
32-
result = Walk(WrapConst(frame)).rewrite(mt.code).join(result)
30+
result = self.hint_const.unsafe_run(mt)
3331
rule = Chain(
3432
ConstantFold(),
3533
Call2Invoke(),

src/kirin/passes/default.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from dataclasses import field, dataclass
2+
3+
from kirin.dialects import ilist
4+
from kirin.ir.method import Method
5+
from kirin.passes.fold import Fold
6+
from kirin.rewrite.abc import RewriteResult
7+
from kirin.passes.aggressive import Fold as AggressiveFold
8+
9+
from .abc import Pass
10+
from .typeinfer import TypeInfer
11+
from .hint_const import HintConst
12+
13+
14+
@dataclass
15+
class Default(Pass):
16+
verify: bool = field(default=False, kw_only=True)
17+
fold: bool = field(default=True, kw_only=True)
18+
aggressive: bool = field(default=False, kw_only=True)
19+
typeinfer: bool = field(default=True, kw_only=True)
20+
21+
hint_const_pass: HintConst = field(init=False)
22+
typeinfer_pass: TypeInfer = field(init=False)
23+
ilist_desugar: ilist.IListDesugar = field(init=False)
24+
fold_pass: Pass = field(init=False)
25+
26+
def __post_init__(self):
27+
# TODO: cleanup no_raise
28+
self.ilist_desugar = ilist.IListDesugar(self.dialects, no_raise=self.no_raise)
29+
self.typeinfer_pass = TypeInfer(self.dialects, no_raise=self.no_raise)
30+
self.hint_const_pass = HintConst(self.dialects, no_raise=self.no_raise)
31+
if self.aggressive:
32+
self.fold_pass = AggressiveFold(self.dialects, no_raise=self.no_raise)
33+
else:
34+
self.fold_pass = Fold(self.dialects, no_raise=self.no_raise)
35+
36+
def unsafe_run(self, mt: Method) -> RewriteResult:
37+
if self.verify:
38+
mt.verify()
39+
40+
result = self.ilist_desugar.fixpoint(mt)
41+
if self.typeinfer:
42+
self.typeinfer_pass(mt).join(result)
43+
44+
if self.fold:
45+
if self.aggressive:
46+
self.fold_pass.fixpoint(mt).join(result)
47+
else:
48+
self.fold_pass(mt).join(result)
49+
return result

src/kirin/passes/fold.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,32 @@
1-
from dataclasses import dataclass
1+
from dataclasses import field, dataclass
22

33
from kirin.ir import Method, SSACFGRegion
44
from kirin.rewrite import (
55
Walk,
66
Chain,
77
Fixpoint,
8-
WrapConst,
98
Call2Invoke,
109
ConstantFold,
1110
CFGCompactify,
1211
InlineGetItem,
1312
DeadCodeElimination,
1413
)
15-
from kirin.analysis import const
1614
from kirin.passes.abc import Pass
1715
from kirin.rewrite.abc import RewriteResult
1816

17+
from .hint_const import HintConst
18+
1919

2020
@dataclass
2121
class Fold(Pass):
22+
hint_const: HintConst = field(init=False)
23+
24+
def __post_init__(self):
25+
self.hint_const = HintConst(self.dialects)
26+
self.hint_const.no_raise = self.no_raise
2227

2328
def unsafe_run(self, mt: Method) -> RewriteResult:
24-
constprop = const.Propagate(self.dialects)
25-
frame, _ = constprop.run_analysis(mt, no_raise=self.no_raise)
26-
result = Walk(WrapConst(frame)).rewrite(mt.code)
29+
result = self.hint_const.unsafe_run(mt)
2730
result = (
2831
Fixpoint(
2932
Walk(

src/kirin/passes/hint_const.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
from dataclasses import dataclass
2+
3+
from kirin.ir import Method
4+
from kirin.rewrite import Walk, WrapConst
5+
from kirin.analysis import const
6+
from kirin.passes.abc import Pass
7+
from kirin.rewrite.abc import RewriteResult
8+
9+
10+
@dataclass
11+
class HintConst(Pass):
12+
13+
def unsafe_run(self, mt: Method) -> RewriteResult:
14+
constprop = const.Propagate(self.dialects)
15+
frame, _ = constprop.run_analysis(mt, no_raise=self.no_raise)
16+
return Walk(WrapConst(frame)).rewrite(mt.code)

src/kirin/passes/typeinfer.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from dataclasses import dataclass
1+
from dataclasses import field, dataclass
22

33
from kirin.ir import Method, HasSignature
44
from kirin.rewrite import Walk, Chain
@@ -8,23 +8,35 @@
88
from kirin.analysis.typeinfer import TypeInference
99
from kirin.rewrite.apply_type import ApplyType
1010
from kirin.rewrite.type_assert import InlineTypeAssert
11+
from kirin.dialects.ilist.rewrite import HintLen
12+
13+
from .hint_const import HintConst
1114

1215

1316
@dataclass
1417
class TypeInfer(Pass):
18+
hint_const: HintConst = field(init=False)
1519

1620
def __post_init__(self):
1721
self.infer = TypeInference(self.dialects)
22+
self.hint_const = HintConst(self.dialects)
23+
self.hint_const.no_raise = self.no_raise
1824

1925
def unsafe_run(self, mt: Method) -> RewriteResult:
26+
result = self.hint_const.unsafe_run(mt)
2027
frame, return_type = self.infer.run_analysis(
2128
mt, mt.arg_types, no_raise=self.no_raise
2229
)
2330
if trait := mt.code.get_trait(HasSignature):
2431
trait.set_signature(mt.code, Signature(mt.arg_types, return_type))
2532

26-
result = Chain(
27-
Walk(ApplyType(frame.entries)), Walk(InlineTypeAssert())
28-
).rewrite(mt.code)
33+
result = (
34+
Chain(
35+
Walk(ApplyType(frame.entries)),
36+
Walk(Chain(InlineTypeAssert(), HintLen())),
37+
)
38+
.rewrite(mt.code)
39+
.join(result)
40+
)
2941
mt.inferred = True
3042
return result

src/kirin/prelude.py

Lines changed: 37 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from typing_extensions import Doc, Annotated
77

88
from kirin.ir import Method, dialect_group
9-
from kirin.passes import aggressive
9+
from kirin.passes import Default
1010
from kirin.dialects import cf, scf, func, math, ilist, lowering
1111
from kirin.dialects.py import (
1212
cmp,
@@ -26,8 +26,6 @@
2626
iterable,
2727
assertion,
2828
)
29-
from kirin.passes.fold import Fold
30-
from kirin.passes.typeinfer import TypeInfer
3129

3230

3331
@dialect_group(
@@ -142,10 +140,6 @@ def main(x: int) -> int:
142140
main.print() # main is a Method!
143141
```
144142
"""
145-
fold_pass = Fold(self)
146-
ilist_desugar = ilist.IListDesugar(self)
147-
aggressive_fold_pass = aggressive.Fold(self)
148-
typeinfer_pass = TypeInfer(self)
149143

150144
def run_pass(
151145
mt: Annotated[Method, Doc("The method to run pass on.")],
@@ -163,20 +157,17 @@ def run_pass(
163157
aggressive: Annotated[
164158
bool, Doc("run aggressive folding passes if `fold=True`")
165159
] = False,
160+
no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True,
166161
) -> None:
167-
if verify:
168-
mt.verify()
169-
170-
ilist_desugar.fixpoint(mt)
171-
172-
if fold:
173-
if aggressive:
174-
aggressive_fold_pass.fixpoint(mt)
175-
else:
176-
fold_pass(mt)
177-
178-
if typeinfer:
179-
typeinfer_pass(mt)
162+
default_pass = Default(
163+
self,
164+
verify=verify,
165+
fold=fold,
166+
aggressive=aggressive,
167+
typeinfer=typeinfer,
168+
no_raise=no_raise,
169+
)
170+
default_pass.fixpoint(mt)
180171

181172
return run_pass
182173

@@ -221,16 +212,34 @@ def run_pass(method: Method) -> None:
221212
)
222213
)
223214
def structural(self):
224-
"""Structural kernel without optimization passes."""
225-
typeinfer_pass = TypeInfer(self)
215+
"""Structural kernel with optimization passes."""
226216

227217
def run_pass(
228-
method: Method, *, verify: bool = True, typeinfer: bool = True
218+
mt: Annotated[Method, Doc("The method to run pass on.")],
219+
*,
220+
verify: Annotated[
221+
bool, Doc("run `verify` before running passes, default is `True`")
222+
] = True,
223+
typeinfer: Annotated[
224+
bool,
225+
Doc(
226+
"run type inference and apply the inferred type to IR, default `False`"
227+
),
228+
] = False,
229+
fold: Annotated[bool, Doc("run folding passes")] = True,
230+
aggressive: Annotated[
231+
bool, Doc("run aggressive folding passes if `fold=True`")
232+
] = False,
233+
no_raise: Annotated[bool, Doc("do not raise exception during analysis")] = True,
229234
) -> None:
230-
if verify:
231-
method.verify()
232-
233-
if typeinfer:
234-
typeinfer_pass(method)
235+
default_pass = Default(
236+
self,
237+
verify=verify,
238+
fold=fold,
239+
aggressive=aggressive,
240+
typeinfer=typeinfer,
241+
no_raise=no_raise,
242+
)
243+
default_pass.fixpoint(mt)
235244

236245
return run_pass

0 commit comments

Comments
 (0)