Skip to content

Commit f719541

Browse files
authored
wrap constants after const propagation (#137)
this PR insert a rewrite after const folding, this new rule will insert the inference result of constant propagation into the IR type annotation if the value is a constant or partial constant.
1 parent db071a7 commit f719541

File tree

17 files changed

+128
-60
lines changed

17 files changed

+128
-60
lines changed

src/kirin/analysis/const/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
NotPure as NotPure,
88
Unknown as Unknown,
99
JointResult as JointResult,
10+
PartialConst as PartialConst,
1011
PartialTuple as PartialTuple,
1112
PurityBottom as PurityBottom,
1213
PartialLambda as PartialLambda,

src/kirin/analysis/const/lattice.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@ def is_equal(self, other: Result) -> bool:
6262
return False
6363

6464

65+
@dataclass
66+
class PartialConst(Result):
67+
pass
68+
69+
6570
@final
6671
class PartialTupleMeta(LatticeMeta):
6772
def __call__(cls, data: tuple[Result, ...]):
@@ -72,7 +77,7 @@ def __call__(cls, data: tuple[Result, ...]):
7277

7378
@final
7479
@dataclass
75-
class PartialTuple(Result, metaclass=PartialTupleMeta):
80+
class PartialTuple(PartialConst, metaclass=PartialTupleMeta):
7681
data: tuple[Result, ...]
7782

7883
def join(self, other: Result) -> Result:
@@ -119,7 +124,7 @@ def is_subseteq_Value(self, other: Value) -> bool:
119124

120125
@final
121126
@dataclass
122-
class PartialLambda(Result):
127+
class PartialLambda(PartialConst):
123128
argnames: list[str]
124129
code: ir.Statement
125130
captured: tuple[Result, ...]

src/kirin/analysis/typeinfer/analysis.py

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
from typing import TypeGuard
2+
13
from kirin import ir, types, interp
4+
from kirin.analysis import const
25
from kirin.interp.impl import Signature
36
from kirin.analysis.forward import Forward, ForwardFrame
47

@@ -17,14 +20,16 @@ def build_signature(
1720
) -> Signature:
1821
_args = ()
1922
for x in frame.get_values(stmt.args):
20-
if isinstance(x, types.Hinted):
21-
_args += (x.type,)
22-
elif isinstance(x, types.Generic):
23-
_args += (x.body,)
24-
else:
25-
_args += (x,)
23+
_args += (self._unwrap(x),)
2624
return Signature(stmt.__class__, _args)
2725

26+
def _unwrap(self, value: types.TypeAttribute) -> types.TypeAttribute:
27+
if isinstance(value, types.Hinted):
28+
return self._unwrap(value.type)
29+
elif isinstance(value, types.Generic):
30+
return value.body
31+
return value
32+
2833
def eval_stmt(
2934
self, frame: ForwardFrame[types.TypeAttribute, None], stmt: ir.Statement
3035
) -> interp.StatementResult[types.TypeAttribute]:
@@ -46,3 +51,19 @@ def run_method(
4651
method.code, (types.Hinted(types.PyClass(ir.Method), method),) + args
4752
)
4853
return types.Bottom
54+
55+
def is_const(
56+
self, value: types.TypeAttribute
57+
) -> TypeGuard[types.Hinted[const.Value]]:
58+
return isinstance(value, types.Hinted) and isinstance(value.data, const.Value)
59+
60+
def is_partial_const(
61+
self, value: types.TypeAttribute
62+
) -> TypeGuard[
63+
types.Hinted[const.Value]
64+
| types.Hinted[const.PartialTuple]
65+
| types.Hinted[const.PartialLambda]
66+
]:
67+
return isinstance(value, types.Hinted) and isinstance(
68+
value.data, (const.Value, const.PartialConst)
69+
)

src/kirin/dialects/fcf/typeinfer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,10 @@ def fold(
3737
stmt: Foldl | Foldr,
3838
values: tuple[ir.types.TypeAttribute, ...],
3939
):
40-
if not isinstance(values[0], ir.types.Hinted):
40+
if not interp.is_const(values[0]):
4141
return (stmt.result.type,) # give up on dynamic calls
4242

43-
fn: ir.Method = values[0].data
43+
fn: ir.Method = values[0].data.data
4444
coll: ir.types.TypeAttribute = values[1]
4545
init: ir.types.TypeAttribute = values[2]
4646

@@ -71,10 +71,10 @@ def map_list(
7171
stmt: Map,
7272
):
7373
fn_value = frame.get(stmt.fn)
74-
if not isinstance(fn_value, ir.types.Hinted):
74+
if not interp.is_const(fn_value):
7575
return (ir.types.List[ir.types.Any],) # give up on dynamic calls
7676

77-
fn: ir.Method = fn_value.data
77+
fn: ir.Method = fn_value.data.data
7878
coll: ir.types.TypeAttribute = frame.get(stmt.coll)
7979
if isinstance(coll, ir.types.Generic) and coll.is_subseteq(ir.types.List):
8080
elem = interp.eval(fn, (coll.vars[0],)).value
@@ -92,10 +92,10 @@ def map_range(
9292
stmt: Map,
9393
):
9494
fn_value = frame.get(stmt.fn)
95-
if not isinstance(fn_value, ir.types.Hinted):
95+
if not interp.is_const(fn_value):
9696
return (ir.types.List,) # give up on dynamic calls
9797

98-
fn: ir.Method = fn_value.data
98+
fn: ir.Method = fn_value.data.data
9999
elem = interp.eval(fn, (ir.types.Int,)).value
100100
# fn errors forward the error
101101
if isinstance(elem, Err):
@@ -114,10 +114,10 @@ def scan(
114114
init = frame.get(stmt.init)
115115
coll = frame.get(stmt.coll)
116116

117-
if not isinstance(fn_value, ir.types.Hinted):
117+
if not interp.is_const(fn_value):
118118
return (ir.types.Tuple[init, ir.types.List[ir.types.Any]],)
119119

120-
fn: ir.Method = fn_value.data
120+
fn: ir.Method = fn_value.data.data
121121
if isinstance(coll, ir.types.Generic) and coll.is_subseteq(ir.types.List):
122122
ret = interp.eval(fn, (init, coll.vars[0])).value
123123
if isinstance(ret, Err):

src/kirin/dialects/func/typeinfer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ def return_(self, interp: TypeInference, frame: Frame, stmt: Return) -> ReturnVa
3030
def call(self, interp: TypeInference, frame: Frame, stmt: Call):
3131
# give up on dynamic method calls
3232
callee = frame.get(stmt.callee)
33-
if not isinstance(callee, ir.types.Hinted):
33+
if not interp.is_const(callee):
3434
return (stmt.result.type,)
3535

36-
mt: ir.Method = callee.data
36+
mt: ir.Method = callee.data.data
3737
return self._invoke_method(
3838
interp,
3939
mt,

src/kirin/dialects/py/stmts/_stmts/constant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self, value: T | data.PyAttr[T]) -> None:
2323
value = data.PyAttr(value)
2424
super().__init__(
2525
properties={"value": value},
26-
result_types=(types.Hinted(value.type, value.data),),
26+
result_types=(value.type,),
2727
)
2828

2929
def print_impl(self, printer: Printer) -> None:

src/kirin/dialects/py/stmts/typeinfer.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from kirin.ir import types
22
from kirin.interp import Frame, MethodTable, StatementResult, impl
3+
from kirin.analysis.typeinfer import TypeInference
34

45
from . import _stmts as py
56
from .dialect import dialect
@@ -175,12 +176,12 @@ def mat_mult(self, interp, frame, stmt) -> StatementResult[types.TypeAttribute]:
175176
@impl(py.GetItem)
176177
def getitem(
177178
self,
178-
interp,
179+
interp: TypeInference,
179180
frame: Frame[types.TypeAttribute],
180181
stmt: py.GetItem,
181182
) -> StatementResult[types.TypeAttribute]:
182183
obj = frame.get(stmt.obj)
183-
if isinstance(obj, types.Hinted): # unwrap const
184+
if interp.is_const(obj): # unwrap const
184185
obj = obj.type
185186
index: types.TypeAttribute = frame.get(stmt.index)
186187
# TODO: replace this when we can multiple dispatch
@@ -221,30 +222,31 @@ def getitem_tuple(
221222

222223
def getitem_tuple_index(
223224
self,
224-
interp,
225+
interp: TypeInference,
225226
stmt: py.GetItem,
226227
obj: types.Generic,
227228
index: types.TypeAttribute,
228229
):
229-
if isinstance(index, types.Hinted): # const
230-
if obj.vararg and index.data >= len(obj.vars):
230+
if interp.is_const(index) and index.type.is_subseteq(types.Int):
231+
index_: int = index.data.data
232+
if obj.vararg and index_ >= len(obj.vars):
231233
return (obj.vararg.typ,)
232-
elif index.data < len(obj.vars):
233-
return (obj.vars[index.data],)
234+
elif index_ < len(obj.vars):
235+
return (obj.vars[index_],)
234236
else:
235237
return (types.Bottom,)
236238
else:
237239
return (self.getitem_tuple_union(obj),)
238240

239241
def getitem_tuple_slice(
240242
self,
241-
interp,
243+
interp: TypeInference,
242244
stmt: py.GetItem,
243245
obj: types.Generic,
244246
index: types.TypeAttribute,
245247
):
246-
if isinstance(index, types.Hinted):
247-
data: slice = index.data
248+
if interp.is_const(index):
249+
data: slice = index.data.data
248250
if obj.vararg and data.stop >= len(obj.vars):
249251
return (
250252
types.Union(
@@ -315,17 +317,15 @@ def new_list(
315317

316318
@impl(py.Slice)
317319
def slice(
318-
self, interp, frame: Frame[types.TypeAttribute], stmt: py.Slice
320+
self, interp: TypeInference, frame: Frame[types.TypeAttribute], stmt: py.Slice
319321
) -> StatementResult[types.TypeAttribute]:
320322
start, stop, step = frame.get_values(stmt.args)
321-
if (
322-
isinstance(start, types.Hinted)
323-
and isinstance(stop, types.Hinted)
324-
and isinstance(step, types.Hinted)
325-
and isinstance(stmt.result.type, types.TypeAttribute)
326-
):
323+
if interp.is_const(start) and interp.is_const(stop) and interp.is_const(step):
327324
return (
328-
types.Hinted(stmt.result.type, slice(start.data, stop.data, step.data)),
325+
types.Hinted(
326+
stmt.result.type,
327+
slice(start.data.data, stop.data.data, step.data.data),
328+
),
329329
)
330330

331331
return (stmt.result.type,)

src/kirin/emit/abc.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class EmitABC(interp.BaseInterpreter[FrameType, ValueType]):
3737
def emit_type_TypeVar(self, attr: ir.types.TypeVar) -> ValueType: ...
3838
def emit_type_Vararg(self, attr: ir.types.Vararg) -> ValueType: ...
3939
def emit_type_Generic(self, attr: ir.types.Generic) -> ValueType: ...
40-
def emit_type_Const(self, attr: ir.types.Hinted) -> ValueType: ...
40+
def emit_type_Hinted(self, attr: ir.types.Hinted) -> ValueType: ...
4141
def emit_type_PyClass(self, attr: ir.types.PyClass) -> ValueType: ...
4242
def emit_attribute_fallback(self, attr: ir.Attribute) -> ValueType: ...
4343
def emit_stmt_begin(self, frame: FrameType, stmt: ir.Statement) -> None: ...

src/kirin/ir/_types.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class _TypeAttribute(Attribute):
1616
def is_subseteq_Union(self, other: Union) -> bool: ...
1717
def is_subseteq_Literal(self, other: Literal) -> bool: ...
1818
def is_subseteq_TypeVar(self, other: TypeVar) -> bool: ...
19-
def is_subseteq_Const(self, other: Hinted) -> bool: ...
19+
def is_subseteq_Hinted(self, other: Hinted) -> bool: ...
2020
def is_subseteq_PyClass(self, other: PyClass) -> bool: ...
2121
def is_subseteq_Generic(self, other: Generic) -> bool: ...
2222
def is_subseteq_fallback(self, other: TypeAttribute) -> bool: ...

src/kirin/ir/types.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,9 @@ def join(self, other: "TypeAttribute") -> "TypeAttribute":
6868
return Union(self, other)
6969
return AnyType() # don't know how to join
7070

71+
def is_subseteq_Hinted(self, other: "Hinted") -> bool:
72+
return self.is_subseteq(other.type)
73+
7174
def print_impl(self, printer: Printer) -> None:
7275
printer.print_name(self, prefix="!")
7376

@@ -162,9 +165,6 @@ def is_subseteq_Generic(self, other: "Generic") -> bool:
162165
def is_subseteq_TypeVar(self, other: "TypeVar") -> bool:
163166
return self.is_subseteq(other.bound)
164167

165-
def is_subseteq_Const(self, other: "Hinted") -> bool:
166-
return self.is_subseteq(other.type)
167-
168168
def __hash__(self) -> int:
169169
return hash((PyClass, self.typ))
170170

0 commit comments

Comments
 (0)