Skip to content

Commit c1d7c56

Browse files
authored
rework interp (part I) (#380)
this PR cleans up the interpreter registry and simplifies `MethodTable` and `impl`. We will do a rewrite of the interpreter interface in the following PR.
1 parent f15fc1f commit c1d7c56

File tree

18 files changed

+329
-232
lines changed

18 files changed

+329
-232
lines changed

src/kirin/analysis/typeinfer/analysis.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from kirin import ir, types, interp
44
from kirin.decl import fields
55
from kirin.analysis import const
6-
from kirin.interp.impl import Signature
76
from kirin.analysis.forward import Forward, ForwardFrame
87

98
from .solve import TypeResolution
@@ -41,15 +40,15 @@ def run_analysis(
4140
# value (which is a type) to determine the method dispatch.
4241
def build_signature(
4342
self, frame: ForwardFrame[types.TypeAttribute], stmt: ir.Statement
44-
) -> Signature:
43+
) -> interp.Signature:
4544
_args = ()
4645
for x in frame.get_values(stmt.args):
4746
# TODO: remove this after we have multiple dispatch...
4847
if isinstance(x, types.Generic):
4948
_args += (x.body,)
5049
else:
5150
_args += (x,)
52-
return Signature(stmt.__class__, _args)
51+
return interp.Signature(stmt.__class__, _args)
5352

5453
def eval_stmt_fallback(
5554
self, frame: ForwardFrame[types.TypeAttribute], stmt: ir.Statement

src/kirin/dialects/func/emit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def emit_function(
1919
):
2020
fn_args = stmt.body.blocks[0].args[1:]
2121
argnames = tuple(interp.ssa_id[arg] for arg in fn_args)
22-
argtypes = tuple(interp.emit_attribute(x.type) for x in fn_args)
22+
argtypes = tuple(interp.emit_attribute(frame, x.type) for x in fn_args)
2323
args = [f"{name}::{type}" for name, type in zip(argnames, argtypes)]
2424
interp.write(f"function {stmt.sym_name}({', '.join(args)})")
2525
frame.indent += 1

src/kirin/dialects/ilist/constprop.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ def detect_purity(
2323

2424
@impl(Map)
2525
@impl(ForEach)
26-
def one_args(self, interp_: const.Propagate, frame: const.Frame, stmt: Map):
26+
def one_args(
27+
self, interp_: const.Propagate, frame: const.Frame, stmt: ForEach | Map
28+
):
2729
fn, collection = frame.get(stmt.fn), frame.get(stmt.collection)
2830

2931
# 1. if the function is a constant method, and the method is pure, then the map is pure

src/kirin/dialects/py/constant.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,4 @@ class JuliaTable(interp.MethodTable):
8181

8282
@interp.impl(Constant)
8383
def emit_Constant(self, emit: EmitJulia, frame: EmitStrFrame, stmt: Constant):
84-
return (emit.emit_attribute(stmt.value),)
84+
return (emit.emit_attribute(frame, stmt.value),)

src/kirin/dialects/scf/typeinfer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def if_else_(
2121
frame.set(
2222
stmt.cond, frame.get(stmt.cond).meet(types.Bool)
2323
) # set cond backwards
24-
return super().if_else(self, interp_, frame, stmt)
24+
return super().if_else(interp_, frame, stmt)
2525

2626
@interp.impl(For)
2727
def for_loop(

src/kirin/emit/abc.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,16 @@ def run_ssacfg_region(
4545
frame.block_ref[succ.block] = block_header
4646
return ()
4747

48-
def emit_attribute(self, attr: ir.Attribute) -> ValueType:
48+
def emit_attribute(self, frame: FrameType, attr: ir.Attribute) -> ValueType:
4949
return getattr(
5050
self, f"emit_type_{type(attr).__name__}", self.emit_attribute_fallback
51-
)(attr)
51+
)(frame, attr)
5252

53-
def emit_attribute_fallback(self, attr: ir.Attribute) -> ValueType:
54-
if (method := self.registry.attributes.get(type(attr))) is not None:
55-
return method(self, attr)
53+
def emit_attribute_fallback(
54+
self, frame: FrameType, attr: ir.Attribute
55+
) -> ValueType:
56+
if (method := self.registry.get(interp.Signature(type(attr)))) is not None:
57+
return method(self, frame, attr)
5658
raise NotImplementedError(f"Attribute {type(attr)} not implemented")
5759

5860
def emit_stmt_begin(self, frame: FrameType, stmt: ir.Statement) -> None:

src/kirin/emit/abc.pyi

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,21 @@ class EmitABC(interp.BaseInterpreter[FrameType, ValueType]):
2424
def run_ssacfg_region(
2525
self, frame: FrameType, region: ir.Region, args: tuple[ValueType, ...]
2626
) -> tuple[ValueType, ...]: ...
27-
def emit_attribute(self, attr: ir.Attribute) -> ValueType: ...
28-
def emit_type_Any(self, attr: types.AnyType) -> ValueType: ...
29-
def emit_type_Bottom(self, attr: types.BottomType) -> ValueType: ...
30-
def emit_type_Literal(self, attr: types.Literal) -> ValueType: ...
31-
def emit_type_Union(self, attr: types.Union) -> ValueType: ...
32-
def emit_type_TypeVar(self, attr: types.TypeVar) -> ValueType: ...
33-
def emit_type_Vararg(self, attr: types.Vararg) -> ValueType: ...
34-
def emit_type_Generic(self, attr: types.Generic) -> ValueType: ...
35-
def emit_type_PyClass(self, attr: types.PyClass) -> ValueType: ...
36-
def emit_type_PyAttr(self, attr: ir.PyAttr) -> ValueType: ...
37-
def emit_attribute_fallback(self, attr: ir.Attribute) -> ValueType: ...
27+
def emit_attribute(self, frame: FrameType, attr: ir.Attribute) -> ValueType: ...
28+
def emit_type_Any(self, frame: FrameType, attr: types.AnyType) -> ValueType: ...
29+
def emit_type_Bottom(
30+
self, frame: FrameType, attr: types.BottomType
31+
) -> ValueType: ...
32+
def emit_type_Literal(self, frame: FrameType, attr: types.Literal) -> ValueType: ...
33+
def emit_type_Union(self, frame: FrameType, attr: types.Union) -> ValueType: ...
34+
def emit_type_TypeVar(self, frame: FrameType, attr: types.TypeVar) -> ValueType: ...
35+
def emit_type_Vararg(self, frame: FrameType, attr: types.Vararg) -> ValueType: ...
36+
def emit_type_Generic(self, frame: FrameType, attr: types.Generic) -> ValueType: ...
37+
def emit_type_PyClass(self, frame: FrameType, attr: types.PyClass) -> ValueType: ...
38+
def emit_type_PyAttr(self, frame: FrameType, attr: ir.PyAttr) -> ValueType: ...
39+
def emit_attribute_fallback(
40+
self, frame: FrameType, attr: ir.Attribute
41+
) -> ValueType: ...
3842
def emit_stmt_begin(self, frame: FrameType, stmt: ir.Statement) -> None: ...
3943
def emit_stmt_end(self, frame: FrameType, stmt: ir.Statement) -> None: ...
4044
def emit_block_begin(self, frame: FrameType, block: ir.Block) -> None: ...

src/kirin/emit/julia.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def emit_block_begin(self, frame: EmitStrFrame, block: Block) -> None:
2929
self.newline(frame)
3030
self.write(f"@label {block_id};")
3131

32-
def emit_type_PyClass(self, attr: PyClass) -> str:
32+
def emit_type_PyClass(self, frame: EmitStrFrame, attr: PyClass) -> str:
3333
return self.PYTYPE_MAP.get(attr.typ, "Any")
3434

3535
def write_assign(self, frame: EmitStrFrame, result: ir.SSAValue, *args):
@@ -54,7 +54,7 @@ def emit_binaryop(
5454
),
5555
)
5656

57-
def emit_type_PyAttr(self, attr: ir.PyAttr) -> str:
57+
def emit_type_PyAttr(self, frame: EmitStrFrame, attr: ir.PyAttr) -> str:
5858
if isinstance(attr.data, (int, float)):
5959
return repr(attr.data)
6060
elif isinstance(attr.data, str):

src/kirin/interp/__init__.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,9 @@
1616
"""
1717

1818
from .base import BaseInterpreter as BaseInterpreter
19-
from .impl import ImplDef as ImplDef, Signature as Signature, impl as impl
2019
from .frame import Frame as Frame, FrameABC as FrameABC
2120
from .state import InterpreterState as InterpreterState
22-
from .table import MethodTable as MethodTable
21+
from .table import Signature as Signature, MethodTable as MethodTable, impl as impl
2322
from .value import (
2423
Successor as Successor,
2524
YieldValue as YieldValue,
@@ -32,6 +31,11 @@
3231
AbstractInterpreter as AbstractInterpreter,
3332
)
3433
from .concrete import Interpreter as Interpreter
34+
from .undefined import (
35+
Undefined as Undefined,
36+
UndefinedType as UndefinedType,
37+
is_undefined as is_undefined,
38+
)
3539
from .exceptions import (
3640
InterpreterError as InterpreterError,
3741
FuelExhaustedError as FuelExhaustedError,

src/kirin/interp/base.py

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,9 @@
1+
from __future__ import annotations
2+
13
import sys
24
from abc import ABC, ABCMeta, abstractmethod
35
from enum import Enum
4-
from typing import (
5-
TYPE_CHECKING,
6-
Any,
7-
Generic,
8-
TypeVar,
9-
ClassVar,
10-
Optional,
11-
Sequence,
12-
Generator,
13-
)
6+
from typing import Any, Generic, TypeVar, ClassVar, Optional, Sequence, Generator
147
from contextlib import contextmanager
158
from dataclasses import field, dataclass
169

@@ -19,15 +12,12 @@
1912
from kirin.ir import Block, Method, Region, Statement, DialectGroup, traits
2013
from kirin.exception import KIRIN_INTERP_STATE
2114

22-
from .impl import Signature
2315
from .frame import FrameABC
2416
from .state import InterpreterState
17+
from .table import Signature, BoundedDef
2518
from .value import Successor, ReturnValue, SpecialValue, StatementResult
2619
from .exceptions import InterpreterError
2720

28-
if TYPE_CHECKING:
29-
from kirin.registry import StatementImpl, InterpreterRegistry
30-
3121
ValueType = TypeVar("ValueType")
3222
FrameType = TypeVar("FrameType", bound=FrameABC)
3323

@@ -57,9 +47,6 @@ class BaseInterpreter(ABC, Generic[FrameType, ValueType], metaclass=InterpreterM
5747
keys: ClassVar[list[str]]
5848
"""The name of the interpreter to select from dialects by order.
5949
"""
60-
Frame: ClassVar[type[FrameABC]] = field(init=False)
61-
"""The type of the frame to use for this interpreter.
62-
"""
6350
void: ValueType = field(init=False)
6451
"""What to return when the interpreter evaluates nothing.
6552
"""
@@ -80,7 +67,7 @@ class BaseInterpreter(ABC, Generic[FrameType, ValueType], metaclass=InterpreterM
8067
"""
8168

8269
# global states
83-
registry: "InterpreterRegistry" = field(init=False, compare=False)
70+
registry: dict[Signature, BoundedDef] = field(init=False, compare=False)
8471
"""The interpreter registry.
8572
"""
8673
symbol_table: dict[str, Statement] = field(init=False, compare=False)
@@ -433,7 +420,7 @@ def build_signature(self, frame: FrameType, stmt: Statement) -> "Signature":
433420

434421
def lookup_registry(
435422
self, frame: FrameType, stmt: Statement
436-
) -> Optional["StatementImpl[Self, FrameType]"]:
423+
) -> Optional[BoundedDef]:
437424
"""Lookup the statement implementation in the registry.
438425
439426
Args:
@@ -444,10 +431,10 @@ def lookup_registry(
444431
Optional[StatementImpl]: the statement implementation if found, None otherwise.
445432
"""
446433
sig = self.build_signature(frame, stmt)
447-
if sig in self.registry.statements:
448-
return self.registry.statements[sig]
449-
elif (class_sig := Signature(stmt.__class__)) in self.registry.statements:
450-
return self.registry.statements[class_sig]
434+
if sig in self.registry:
435+
return self.registry[sig]
436+
elif (method := self.registry.get(Signature(stmt.__class__))) is not None:
437+
return method
451438
return
452439

453440
@abstractmethod

0 commit comments

Comments
 (0)