Skip to content

Commit db071a7

Browse files
authored
generalize Const -> Hinted (#136)
This PR generalize `ir.types.Const` to `ir.types.Hinted` allow more flexible data being stored in the type annotation to assist type inference, e.g inference results from const propagation.
1 parent 53c3024 commit db071a7

File tree

13 files changed

+66
-59
lines changed

13 files changed

+66
-59
lines changed

src/kirin/analysis/typeinfer/analysis.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ def build_signature(
1717
) -> Signature:
1818
_args = ()
1919
for x in frame.get_values(stmt.args):
20-
if isinstance(x, types.Const):
21-
_args += (x.typ,)
20+
if isinstance(x, types.Hinted):
21+
_args += (x.type,)
2222
elif isinstance(x, types.Generic):
2323
_args += (x.body,)
2424
else:
@@ -43,6 +43,6 @@ def run_method(
4343
if len(self.state.frames) < self.max_depth:
4444
# NOTE: widen method type here
4545
return self.run_callable(
46-
method.code, (types.Const(method, types.PyClass(ir.Method)),) + args
46+
method.code, (types.Hinted(types.PyClass(ir.Method), method),) + args
4747
)
4848
return types.Bottom

src/kirin/dialects/fcf/typeinfer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def fold(
3737
stmt: Foldl | Foldr,
3838
values: tuple[ir.types.TypeAttribute, ...],
3939
):
40-
if not isinstance(values[0], ir.types.Const):
40+
if not isinstance(values[0], ir.types.Hinted):
4141
return (stmt.result.type,) # give up on dynamic calls
4242

4343
fn: ir.Method = values[0].data
@@ -71,7 +71,7 @@ def map_list(
7171
stmt: Map,
7272
):
7373
fn_value = frame.get(stmt.fn)
74-
if not isinstance(fn_value, ir.types.Const):
74+
if not isinstance(fn_value, ir.types.Hinted):
7575
return (ir.types.List[ir.types.Any],) # give up on dynamic calls
7676

7777
fn: ir.Method = fn_value.data
@@ -92,7 +92,7 @@ def map_range(
9292
stmt: Map,
9393
):
9494
fn_value = frame.get(stmt.fn)
95-
if not isinstance(fn_value, ir.types.Const):
95+
if not isinstance(fn_value, ir.types.Hinted):
9696
return (ir.types.List,) # give up on dynamic calls
9797

9898
fn: ir.Method = fn_value.data
@@ -114,7 +114,7 @@ def scan(
114114
init = frame.get(stmt.init)
115115
coll = frame.get(stmt.coll)
116116

117-
if not isinstance(fn_value, ir.types.Const):
117+
if not isinstance(fn_value, ir.types.Hinted):
118118
return (ir.types.Tuple[init, ir.types.List[ir.types.Any]],)
119119

120120
fn: ir.Method = fn_value.data

src/kirin/dialects/func/typeinfer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def return_(self, interp: TypeInference, frame: Frame, stmt: Return) -> ReturnVa
3030
def call(self, interp: TypeInference, frame: Frame, stmt: Call):
3131
# give up on dynamic method calls
3232
callee = frame.get(stmt.callee)
33-
if not isinstance(callee, ir.types.Const):
33+
if not isinstance(callee, ir.types.Hinted):
3434
return (stmt.result.type,)
3535

3636
mt: ir.Method = callee.data

src/kirin/dialects/py/stmts/_stmts/constant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def __init__(self, value: T | data.PyAttr[T]) -> None:
2323
value = data.PyAttr(value)
2424
super().__init__(
2525
properties={"value": value},
26-
result_types=(types.Const(value.data, value.type),),
26+
result_types=(types.Hinted(value.type, value.data),),
2727
)
2828

2929
def print_impl(self, printer: Printer) -> None:

src/kirin/dialects/py/stmts/_stmts/slice.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,9 @@ def __init__(self, start: SSAValue, stop: SSAValue, step: SSAValue) -> None:
2626
if stop.type.is_subseteq(types.NoneType):
2727
result_type = types.Bottom
2828
else:
29-
result_type = types.Slice[types.widen_const(stop.type)]
29+
result_type = types.Slice[types.unwrap_hinted(stop.type)]
3030
else:
31-
result_type = types.Slice[types.widen_const(start.type)]
31+
result_type = types.Slice[types.unwrap_hinted(start.type)]
3232

3333
super().__init__(
3434
args=(start, stop, step),

src/kirin/dialects/py/stmts/typeinfer.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,8 @@ def getitem(
180180
stmt: py.GetItem,
181181
) -> StatementResult[types.TypeAttribute]:
182182
obj = frame.get(stmt.obj)
183-
if isinstance(obj, types.Const): # unwrap const
184-
obj = obj.typ
183+
if isinstance(obj, types.Hinted): # unwrap const
184+
obj = obj.type
185185
index: types.TypeAttribute = frame.get(stmt.index)
186186
# TODO: replace this when we can multiple dispatch
187187
if obj.is_subseteq(types.Tuple):
@@ -226,7 +226,7 @@ def getitem_tuple_index(
226226
obj: types.Generic,
227227
index: types.TypeAttribute,
228228
):
229-
if isinstance(index, types.Const): # const
229+
if isinstance(index, types.Hinted): # const
230230
if obj.vararg and index.data >= len(obj.vars):
231231
return (obj.vararg.typ,)
232232
elif index.data < len(obj.vars):
@@ -243,7 +243,7 @@ def getitem_tuple_slice(
243243
obj: types.Generic,
244244
index: types.TypeAttribute,
245245
):
246-
if isinstance(index, types.Const):
246+
if isinstance(index, types.Hinted):
247247
data: slice = index.data
248248
if obj.vararg and data.stop >= len(obj.vars):
249249
return (
@@ -319,13 +319,13 @@ def slice(
319319
) -> StatementResult[types.TypeAttribute]:
320320
start, stop, step = frame.get_values(stmt.args)
321321
if (
322-
isinstance(start, types.Const)
323-
and isinstance(stop, types.Const)
324-
and isinstance(step, types.Const)
322+
isinstance(start, types.Hinted)
323+
and isinstance(stop, types.Hinted)
324+
and isinstance(step, types.Hinted)
325325
and isinstance(stmt.result.type, types.TypeAttribute)
326326
):
327327
return (
328-
types.Const(slice(start.data, stop.data, step.data), stmt.result.type),
328+
types.Hinted(stmt.result.type, slice(start.data, stop.data, step.data)),
329329
)
330330

331331
return (stmt.result.type,)

src/kirin/emit/abc.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class EmitABC(interp.BaseInterpreter[FrameType, ValueType]):
3737
def emit_type_TypeVar(self, attr: ir.types.TypeVar) -> ValueType: ...
3838
def emit_type_Vararg(self, attr: ir.types.Vararg) -> ValueType: ...
3939
def emit_type_Generic(self, attr: ir.types.Generic) -> ValueType: ...
40-
def emit_type_Const(self, attr: ir.types.Const) -> ValueType: ...
40+
def emit_type_Const(self, attr: ir.types.Hinted) -> ValueType: ...
4141
def emit_type_PyClass(self, attr: ir.types.PyClass) -> ValueType: ...
4242
def emit_attribute_fallback(self, attr: ir.Attribute) -> ValueType: ...
4343
def emit_stmt_begin(self, frame: FrameType, stmt: ir.Statement) -> None: ...

src/kirin/ir/_types.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ from dataclasses import dataclass
22

33
from kirin.ir.attrs import Attribute
44
from kirin.ir.types import (
5-
Const,
65
Union,
6+
Hinted,
77
Generic,
88
Literal,
99
PyClass,
@@ -16,7 +16,7 @@ class _TypeAttribute(Attribute):
1616
def is_subseteq_Union(self, other: Union) -> bool: ...
1717
def is_subseteq_Literal(self, other: Literal) -> bool: ...
1818
def is_subseteq_TypeVar(self, other: TypeVar) -> bool: ...
19-
def is_subseteq_Const(self, other: Const) -> bool: ...
19+
def is_subseteq_Const(self, other: Hinted) -> bool: ...
2020
def is_subseteq_PyClass(self, other: PyClass) -> bool: ...
2121
def is_subseteq_Generic(self, other: Generic) -> bool: ...
2222
def is_subseteq_fallback(self, other: TypeAttribute) -> bool: ...

src/kirin/ir/types.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ def is_subseteq_Generic(self, other: "Generic") -> bool:
162162
def is_subseteq_TypeVar(self, other: "TypeVar") -> bool:
163163
return self.is_subseteq(other.bound)
164164

165-
def is_subseteq_Const(self, other: "Const") -> bool:
166-
return self.is_subseteq(other.typ)
165+
def is_subseteq_Const(self, other: "Hinted") -> bool:
166+
return self.is_subseteq(other.type)
167167

168168
def __hash__(self) -> int:
169169
return hash((PyClass, self.typ))
@@ -349,7 +349,7 @@ def __init__(
349349
self.body = body
350350
else:
351351
self.body = PyClass(body)
352-
self.vars, self.vararg = split_type_args(vars)
352+
self.vars, self.vararg = _split_type_args(vars)
353353

354354
def is_subseteq_Literal(self, other: Literal) -> bool:
355355
return False
@@ -408,7 +408,7 @@ def where(self, typ: TypeVarValue | tuple[TypeVarValue, ...]) -> "Generic":
408408
else:
409409
typs = (typ,)
410410

411-
args, vararg = split_type_args(typs)
411+
args, vararg = _split_type_args(typs)
412412
if self.vararg is None and vararg is None:
413413
assert len(args) <= len(
414414
self.vars
@@ -446,45 +446,52 @@ def where(self, typ: TypeVarValue | tuple[TypeVarValue, ...]) -> "Generic":
446446
raise TypeError("Type arguments do not match")
447447

448448

449-
ConstType = typing.TypeVar("ConstType")
449+
HintedData = typing.TypeVar("HintedData")
450450

451451

452452
@typing.final
453453
@dataclass
454-
class Const(TypeAttribute, typing.Generic[ConstType]):
455-
name = "Const"
456-
data: ConstType
457-
typ: TypeAttribute
454+
class Hinted(TypeAttribute, typing.Generic[HintedData]):
455+
"""Type wrapped with a hint.
456+
457+
`Hinted` is used to represent a type with additional data that can be used as
458+
a hint for type inference. The additional data is only used for specific type
459+
inference purposes, or improve certain type inference precision, it does not affect the
460+
order of types in the lattice.
461+
"""
458462

459-
def __init__(self, data: ConstType, typ: TypeAttribute | None = None):
463+
name = "Hinted"
464+
type: TypeAttribute
465+
data: HintedData
466+
467+
def __init__(self, type: TypeAttribute, data: HintedData):
460468
self.data = data
461-
if isinstance(typ, Const):
462-
typ = widen_const(typ)
463-
elif typ is None:
464-
typ = PyClass(type(data))
465-
self.typ = typ
469+
if isinstance(type, Hinted):
470+
type = type.type
471+
self.type = type
466472

467473
def is_equal(self, other: TypeAttribute) -> bool:
468474
return (
469-
isinstance(other, Const)
475+
isinstance(other, Hinted)
470476
and self.data == other.data
471-
and self.typ.is_equal(other.typ)
477+
and self.type.is_equal(other.type)
472478
)
473479

474480
def is_subseteq_fallback(self, other: TypeAttribute) -> bool:
475-
return self.typ.is_subseteq(other)
481+
return self.type.is_subseteq(other)
476482

477483
def __hash__(self) -> int:
478-
return hash(self.typ)
484+
return hash(self.type)
479485

480486
def print_impl(self, printer: Printer) -> None:
481487
printer.print_name(self, prefix="!")
482-
printer.plain_print("(", self.data, ", ")
483-
printer.print(self.typ)
488+
printer.plain_print("(")
489+
printer.print(self.type)
490+
printer.plain_print(", ", self.data)
484491
printer.plain_print(")")
485492

486493

487-
def split_type_args(
494+
def _split_type_args(
488495
args: tuple[TypeVarValue, ...]
489496
) -> tuple[tuple[TypeAttribute, ...], Vararg | None]:
490497
if args is None or len(args) == 0:
@@ -539,10 +546,10 @@ def hint2type(hint) -> TypeAttribute:
539546
return Generic(body, *params)
540547

541548

542-
def widen_const(typ: TypeAttribute) -> TypeAttribute:
543-
if isinstance(typ, Const):
544-
return typ.typ
545-
return typ
549+
def unwrap_hinted(hint: TypeAttribute) -> TypeAttribute:
550+
if isinstance(hint, Hinted):
551+
return hint.type
552+
return hint
546553

547554

548555
Any = AnyType()

test/analysis/dataflow/typeinfer/test_unstable.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,17 @@ def results_at(block_id, stmt_id):
2525
return stmt_at(block_id, stmt_id).results
2626

2727
assert [infer.results[result] for result in results_at(0, 0)] == [
28-
types.Const(1, types.Int)
28+
types.Hinted(types.Int, 1)
2929
]
3030
assert [infer.results[result] for result in results_at(0, 1)] == [types.Int]
3131
assert [infer.results[result] for result in results_at(0, 2)] == [
32-
types.Const(10, types.Int)
32+
types.Hinted(types.Int, 10)
3333
]
3434
assert [infer.results[result] for result in results_at(0, 3)] == [types.Bool]
3535

3636
assert [infer.results[result] for result in results_at(1, 0)] == [types.Int]
3737
assert [infer.results[result] for result in results_at(2, 0)] == [
38-
types.Const(1.2, types.Float)
38+
types.Hinted(types.Float, 1.2)
3939
]
4040
assert [infer.results[result] for result in results_at(2, 1)] == [types.Float]
4141

0 commit comments

Comments
 (0)