Skip to content

Commit 576e902

Browse files
authored
support inference on generics (#129)
1 parent a3f27e7 commit 576e902

File tree

4 files changed

+110
-7
lines changed

4 files changed

+110
-7
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .solve import TypeResolution as TypeResolution
2+
from .analysis import TypeInference as TypeInference
Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from kirin.interp.impl import Signature
33
from kirin.analysis.forward import Forward, ForwardFrame
44

5+
from .solve import TypeResolution
6+
57

68
class TypeInference(Forward[types.TypeAttribute]):
79
keys = ["typeinfer"]
@@ -29,7 +31,11 @@ def eval_stmt(
2931
method = self.lookup_registry(frame, stmt)
3032
if method is not None:
3133
return method(self, frame, stmt)
32-
return tuple(result.type for result in stmt.results)
34+
35+
resolve = TypeResolution()
36+
for arg, value in zip(stmt.args, frame.get_values(stmt.args)):
37+
resolve.solve(arg.type, value)
38+
return tuple(resolve.substitute(result.type) for result in stmt.results)
3339

3440
def run_method(
3541
self, method: ir.Method, args: tuple[types.TypeAttribute, ...]
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from dataclasses import field, dataclass
2+
3+
from kirin.ir import types
4+
5+
6+
@dataclass
7+
class TypeResolutionResult:
8+
pass
9+
10+
11+
@dataclass
12+
class ResolutionOk(TypeResolutionResult):
13+
14+
def __bool__(self):
15+
return True
16+
17+
18+
Ok = ResolutionOk()
19+
20+
21+
@dataclass
22+
class ResolutionError(TypeResolutionResult):
23+
expr: types.TypeAttribute
24+
value: types.TypeAttribute
25+
26+
def __bool__(self):
27+
return False
28+
29+
def __str__(self):
30+
return f"expected {self.expr}, got {self.value}"
31+
32+
33+
@dataclass
34+
class TypeResolution:
35+
vars: dict[types.TypeVar, types.TypeAttribute] = field(default_factory=dict)
36+
37+
def substitute(self, typ: types.TypeAttribute) -> types.TypeAttribute:
38+
if isinstance(typ, types.TypeVar):
39+
return self.vars.get(typ, typ)
40+
elif isinstance(typ, types.Generic):
41+
return types.Generic(
42+
typ.body, *tuple(self.substitute(var) for var in typ.vars)
43+
)
44+
elif isinstance(typ, types.Union):
45+
return types.Union(self.substitute(t) for t in typ.types)
46+
return typ
47+
48+
def solve(
49+
self, annot: types.TypeAttribute, value: types.TypeAttribute
50+
) -> TypeResolutionResult:
51+
if isinstance(annot, types.TypeVar):
52+
return self.solve_TypeVar(annot, value)
53+
elif isinstance(annot, types.Generic):
54+
return self.solve_Generic(annot, value)
55+
elif isinstance(annot, types.Union):
56+
return self.solve_Union(annot, value)
57+
58+
if annot.is_subseteq(value):
59+
return Ok
60+
else:
61+
return ResolutionError(annot, value)
62+
63+
def solve_TypeVar(self, annot: types.TypeVar, value: types.TypeAttribute):
64+
if annot in self.vars:
65+
if value.is_subseteq(self.vars[annot]):
66+
self.vars[annot] = value
67+
elif self.vars[annot].is_subseteq(value):
68+
pass
69+
else:
70+
return ResolutionError(annot, value)
71+
else:
72+
self.vars[annot] = value
73+
return Ok
74+
75+
def solve_Generic(self, annot: types.Generic, value: types.TypeAttribute):
76+
if not isinstance(value, types.Generic):
77+
return ResolutionError(annot, value)
78+
79+
if not value.body.is_subseteq(annot.body):
80+
return ResolutionError(annot.body, value.body)
81+
82+
for var, val in zip(annot.vars, value.vars):
83+
result = self.solve(var, val)
84+
if not result:
85+
return result
86+
87+
if not annot.vararg:
88+
return Ok
89+
90+
for val in value.vars[len(annot.vars) :]:
91+
result = self.solve(annot.vararg.typ, val)
92+
if not result:
93+
return result
94+
return Ok
95+
96+
def solve_Union(self, annot: types.Union, value: types.TypeAttribute):
97+
for typ in annot.types:
98+
result = self.solve(typ, value)
99+
if result:
100+
return Ok
101+
return ResolutionError(annot, value)

src/kirin/dialects/py/stmts/typeinfer.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,6 @@ def alias(
2121
) -> StatementResult[types.TypeAttribute]:
2222
return (frame.get(stmt.value),) # just forward the type
2323

24-
@impl(py.NewTuple)
25-
def new_tuple(
26-
self, interp, frame: Frame[types.TypeAttribute], stmt: py.NewTuple
27-
) -> StatementResult[types.TypeAttribute]:
28-
return (types.Tuple.where(frame.get_values(stmt.args)),) # make 3.10 happy
29-
3024
@impl(py.Add, types.Float, types.Float)
3125
@impl(py.Add, types.Float, types.Int)
3226
@impl(py.Add, types.Int, types.Float)

0 commit comments

Comments
 (0)