Skip to content

Commit d7ad7ef

Browse files
authored
add module dialect to support (de)serialization (#151)
adding the module dialect to support (de)serialization so that functions can be compiled into a serializable format (without `Method` object in the IR). To support runing this dialect, interpreter now contains a symbol table. Note that the symbol table does not implements a namespace, developers need to mangle the function name when rewrite your function into a serializable `module` dialect.
1 parent 083fa13 commit d7ad7ef

File tree

5 files changed

+231
-41
lines changed

5 files changed

+231
-41
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from typing import TYPE_CHECKING
2+
3+
if TYPE_CHECKING:
4+
from kirin.ir import Statement
5+
from kirin.print.printer import Printer
6+
7+
8+
def pprint_calllike(
9+
invoke_or_call: "Statement", callee: str, printer: "Printer"
10+
) -> None:
11+
with printer.rich(style="red"):
12+
printer.print_name(invoke_or_call)
13+
printer.plain_print(" ")
14+
15+
n_total = len(invoke_or_call.args)
16+
# if (callee := getattr(invoke_or_call, "callee", None)) is None:
17+
# raise ValueError(f"{invoke_or_call} does not have a callee")
18+
19+
# if isinstance(callee, SSAValue):
20+
# printer.plain_print(printer.state.ssa_id[callee])
21+
# elif isinstance(callee, Method):
22+
# printer.plain_print(callee.sym_name)
23+
# elif isinstance(callee, str):
24+
# printer.plain_print(callee)
25+
# else:
26+
# raise ValueError(f"Unknown callee type {type(callee)}")
27+
28+
printer.plain_print(callee)
29+
if (inputs := getattr(invoke_or_call, "inputs", None)) is None:
30+
raise ValueError(f"{invoke_or_call} does not have inputs")
31+
32+
if not isinstance(inputs, tuple):
33+
raise ValueError(f"inputs of {invoke_or_call} is not a tuple")
34+
35+
if (kwargs := getattr(invoke_or_call, "kwargs", None)) is None:
36+
raise ValueError(f"{invoke_or_call} does not have kwargs")
37+
38+
if not isinstance(kwargs, tuple):
39+
raise ValueError(f"kwargs of {invoke_or_call} is not a tuple")
40+
41+
positional = inputs[: n_total - len(kwargs)]
42+
kwargs = dict(
43+
zip(
44+
kwargs,
45+
inputs[n_total - len(kwargs) :],
46+
)
47+
)
48+
49+
printer.plain_print("(")
50+
printer.print_seq(positional)
51+
if kwargs and positional:
52+
printer.plain_print(", ")
53+
printer.print_mapping(kwargs, delim=", ")
54+
printer.plain_print(")")
55+
56+
with printer.rich(style="black"):
57+
printer.plain_print(" : ")
58+
printer.print_seq(
59+
[result.type for result in invoke_or_call._results],
60+
delim=", ",
61+
)

src/kirin/dialects/func/stmts.py

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
from typing import Union
2-
31
from kirin.ir import (
42
Pure,
53
Method,
@@ -23,43 +21,7 @@
2321
from kirin.dialects.func.attrs import Signature, MethodType
2422
from kirin.dialects.func.dialect import dialect
2523

26-
27-
def _print_invoke_or_call(
28-
invoke_or_call: Union["Invoke", "Call"], printer: Printer
29-
) -> None:
30-
with printer.rich(style="red"):
31-
printer.print_name(invoke_or_call)
32-
printer.plain_print(" ")
33-
34-
n_total = len(invoke_or_call.args)
35-
callee = invoke_or_call.callee
36-
if isinstance(callee, SSAValue):
37-
printer.plain_print(printer.state.ssa_id[callee])
38-
else:
39-
printer.plain_print(callee.sym_name)
40-
41-
inputs = invoke_or_call.inputs
42-
positional = inputs[: n_total - len(invoke_or_call.kwargs)]
43-
kwargs = dict(
44-
zip(
45-
invoke_or_call.kwargs,
46-
inputs[n_total - len(invoke_or_call.kwargs) :],
47-
)
48-
)
49-
50-
printer.plain_print("(")
51-
printer.print_seq(positional)
52-
if kwargs and positional:
53-
printer.plain_print(", ")
54-
printer.print_mapping(kwargs, delim=", ")
55-
printer.plain_print(")")
56-
57-
with printer.rich(style="black"):
58-
printer.plain_print(" : ")
59-
printer.print_seq(
60-
[result.type for result in invoke_or_call._results],
61-
delim=", ",
62-
)
24+
from .._pprint_helper import pprint_calllike
6325

6426

6527
class FuncOpCallableInterface(CallableStmtInterface["Function"]):
@@ -116,7 +78,7 @@ class Call(Statement):
11678
result: ResultValue = info.result()
11779

11880
def print_impl(self, printer: Printer) -> None:
119-
_print_invoke_or_call(self, printer)
81+
pprint_calllike(self, printer.state.ssa_id[self.callee], printer)
12082

12183

12284
@statement(dialect=dialect)
@@ -242,7 +204,7 @@ class Invoke(Statement):
242204
result: ResultValue = info.result()
243205

244206
def print_impl(self, printer: Printer) -> None:
245-
_print_invoke_or_call(self, printer)
207+
pprint_calllike(self, self.callee.sym_name, printer)
246208

247209
def verify(self) -> None:
248210
if self.kwargs:

src/kirin/dialects/module.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""Module dialect provides a simple module
2+
that is roughly a list of function statements.
3+
4+
This dialect provides the dialect necessary for compiling a function into
5+
lower-level IR with all its callee functions.
6+
"""
7+
8+
from kirin import ir, interp
9+
from kirin.decl import info, statement
10+
from kirin.print import Printer
11+
from kirin.analysis import TypeInference
12+
from kirin.exceptions import InterpreterError, VerificationError
13+
14+
from ._pprint_helper import pprint_calllike
15+
16+
dialect = ir.Dialect("module")
17+
18+
19+
@statement(dialect=dialect)
20+
class Module(ir.Statement):
21+
traits = frozenset(
22+
{ir.IsolatedFromAbove(), ir.SymbolTable(), ir.SymbolOpInterface()}
23+
)
24+
sym_name: str = info.attribute(property=True)
25+
entry: str = info.attribute(property=True)
26+
body: ir.Region = info.region(multi=False)
27+
28+
29+
@statement(dialect=dialect)
30+
class Invoke(ir.Statement):
31+
"""A special statement that represents
32+
a function calling functions by symbol name.
33+
34+
Note:
35+
This statement is here for completeness, for interpretation,
36+
it is recommended to rewrite this statement into a `func.Invoke`
37+
after looking up the symbol table.
38+
"""
39+
40+
callee: str = info.attribute(property=True)
41+
inputs: tuple[ir.SSAValue, ...] = info.argument()
42+
kwargs: tuple[str, ...] = info.attribute(property=True)
43+
result: ir.ResultValue = info.result()
44+
45+
def print_impl(self, printer: Printer) -> None:
46+
pprint_calllike(self, self.callee, printer)
47+
48+
def verify(self) -> None:
49+
if self.kwargs:
50+
for name in self.kwargs:
51+
if name not in self.callee:
52+
raise VerificationError(
53+
self,
54+
f"method {self.callee} does not have argument {name}",
55+
)
56+
elif len(self.callee) - 1 != len(self.args):
57+
raise VerificationError(
58+
self,
59+
f"expected {len(self.callee)} arguments, got {len(self.args)}",
60+
)
61+
62+
63+
@dialect.register
64+
class Concrete(interp.MethodTable):
65+
66+
@interp.impl(Module)
67+
def interp_Module(
68+
self, interp: interp.Interpreter, frame: interp.Frame, stmt: Module
69+
):
70+
for stmt_ in stmt.body.blocks[0].stmts:
71+
if (trait := stmt.get_trait(ir.SymbolOpInterface)) is not None:
72+
interp.symbol_table[trait.get_sym_name(stmt_).data] = stmt_
73+
return ()
74+
75+
@interp.impl(Invoke)
76+
def interp_Invoke(
77+
self, interpreter: interp.Interpreter, frame: interp.Frame, stmt: Invoke
78+
):
79+
callee = interpreter.symbol_table.get(stmt.callee)
80+
if callee is None:
81+
raise InterpreterError(f"symbol {stmt.callee} not found")
82+
83+
trait = callee.get_trait(ir.CallableStmtInterface)
84+
if trait is None:
85+
raise InterpreterError(
86+
f"{stmt.callee} is not callable, got {callee.__class__.__name__}"
87+
)
88+
89+
body = trait.get_callable_region(callee)
90+
mt = ir.Method(
91+
mod=None,
92+
py_func=None,
93+
sym_name=stmt.callee,
94+
arg_names=[
95+
arg.name or str(idx) for idx, arg in enumerate(body.blocks[0].args)
96+
],
97+
dialects=interpreter.dialects,
98+
code=stmt,
99+
)
100+
return interpreter.run_method(mt, frame.get_values(stmt.inputs))
101+
102+
103+
@dialect.register(key="typeinfer")
104+
class TypeInfer(interp.MethodTable):
105+
106+
@interp.impl(Module)
107+
def typeinfer_Module(
108+
self, interp: TypeInference, frame: interp.Frame, stmt: Module
109+
):
110+
for stmt_ in stmt.body.blocks[0].stmts:
111+
if (trait := stmt.get_trait(ir.SymbolOpInterface)) is not None:
112+
interp.symbol_table[trait.get_sym_name(stmt_).data] = stmt_
113+
return ()
114+
115+
@interp.impl(Invoke)
116+
def typeinfer_Invoke(
117+
self, interp: TypeInference, frame: interp.Frame, stmt: Invoke
118+
):
119+
callee = interp.symbol_table.get(stmt.callee)
120+
if callee is None:
121+
return (ir.types.Bottom,)
122+
123+
trait = callee.get_trait(ir.CallableStmtInterface)
124+
if trait is None:
125+
return (ir.types.Bottom,)
126+
127+
body = trait.get_callable_region(callee)
128+
mt = ir.Method(
129+
mod=None,
130+
py_func=None,
131+
sym_name=stmt.callee,
132+
arg_names=[
133+
arg.name or str(idx) for idx, arg in enumerate(body.blocks[0].args)
134+
],
135+
dialects=interp.dialects,
136+
code=stmt,
137+
)
138+
interp.run_method(mt, mt.arg_types)
139+
return tuple(result.type for result in callee.results)

src/kirin/interp/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ def __init__(
6666
self.bottom = bottom
6767

6868
self.registry = self.dialects.registry.interpreter(keys=self.keys)
69+
self.symbol_table: dict[str, Statement] = {}
6970
self.state: InterpreterState[FrameType] = InterpreterState()
7071
self.fuel = fuel
7172
self.max_depth = max_depth

src/kirin/symbol_table.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from typing import Generic, TypeVar
2+
from dataclasses import field, dataclass
3+
4+
T = TypeVar("T")
5+
6+
7+
@dataclass
8+
class SymbolTable(Generic[T]):
9+
names: dict[str, T] = field(default_factory=dict)
10+
"""The table that maps names to values."""
11+
prefix: str = field(default="", kw_only=True)
12+
name_count: dict[str, int] = field(default_factory=dict, kw_only=True)
13+
"""The count of names that have been requested."""
14+
15+
def __getitem__(self, name: str) -> T:
16+
return self.names[name]
17+
18+
def __contains__(self, name: str) -> bool:
19+
return name in self.names
20+
21+
def __setitem__(self, name: str, value: T) -> None:
22+
count = self.name_count.setdefault(name, 0)
23+
self.name_count[name] = count + 1
24+
self.names[f"{self.prefix}_{name}_{count}"] = value
25+
26+
def __delitem__(self, name: str) -> None:
27+
del self.names[name]

0 commit comments

Comments
 (0)