Skip to content

Commit 670d612

Browse files
authored
new code generator framework (#401)
This implements the new `EmitABC` based on the new interpreter framework. While it does its job - emitting some target format in either IO or str or whatever else. It does not address the usecase that lowers between two Kirin dialect groups, I'm thinking support this use case in the lowering framework instead of emit because technically this is about constructing the same Kirin IR based on another Kirin IR so we can just reuse the lowering tools for constructing the Kirin IR and provide a visitor of the IR. Otherwise it is hard to address cases like "replacing a statement with 3 statements"
1 parent dfcd096 commit 670d612

File tree

21 files changed

+488
-12
lines changed

21 files changed

+488
-12
lines changed

src/kirin/dialects/debug.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import rich
44

5-
from kirin import ir, decl, types, interp, lowering
5+
from kirin import ir, decl, emit, types, interp, lowering
66

77
dialect = ir.Dialect("debug")
88

@@ -66,3 +66,12 @@ def info(self, interp: interp.Interpreter, frame: interp.Frame, stmt: Info):
6666
rich.print(
6767
"[dim]└───────────────────────────────────────────────────────────────[/dim]"
6868
)
69+
70+
71+
@dialect.register(key="emit.julia")
72+
class JuliaEmit(interp.MethodTable):
73+
@interp.impl(Info)
74+
def info(self, emit: emit.Julia, frame: emit.JuliaFrame, stmt: Info):
75+
msg = frame.get(stmt.msg)
76+
inputs = " ".join(frame.get(input) for input in stmt.inputs)
77+
frame.write_line(f'@info "{msg[1:-1]}" {inputs}'.strip())

src/kirin/dialects/func/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,5 @@
1717
FuncOpCallableInterface as FuncOpCallableInterface,
1818
)
1919
from kirin.dialects.func._dialect import dialect as dialect
20+
21+
from . import _julia as _julia

src/kirin/dialects/func/_julia.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from __future__ import annotations
2+
3+
from typing import IO, TypeVar
4+
5+
from kirin import emit, interp
6+
7+
from .stmts import Invoke, Return, Function
8+
from ._dialect import dialect
9+
10+
11+
@dialect.register(key="emit.julia")
12+
class Julia(interp.MethodTable):
13+
14+
IO_t = TypeVar("IO_t", bound=IO)
15+
16+
@interp.impl(Return)
17+
def return_(
18+
self, emit: emit.Julia[IO_t], frame: emit.JuliaFrame[IO_t], node: Return
19+
):
20+
value = frame.get(node.value)
21+
frame.write_line(f"return {value}")
22+
23+
@interp.impl(Invoke)
24+
def invoke(
25+
self, emit: emit.Julia[IO_t], frame: emit.JuliaFrame[IO_t], node: Invoke
26+
):
27+
func_name = emit.callables.get(node.callee.code)
28+
if func_name is None:
29+
emit.callable_to_emit.append(node.callee.code)
30+
func_name = emit.callables.add(node.callee.code)
31+
32+
_, call_expr = emit.call(
33+
node.callee.code, func_name, *frame.get_values(node.args)
34+
)
35+
frame.write_line(f"{frame.ssa[node.result]} = {call_expr}")
36+
return (frame.ssa[node.result],)
37+
38+
@interp.impl(Function)
39+
def function(
40+
self, emit_: emit.Julia[IO_t], frame: emit.JuliaFrame[IO_t], node: Function
41+
):
42+
func_name = emit_.callables[node]
43+
frame.set(node.body.blocks[0].args[0], func_name)
44+
argnames_: list[str] = []
45+
for arg in node.body.blocks[0].args[1:]:
46+
frame.set(arg, name := frame.ssa[arg])
47+
argnames_.append(name)
48+
49+
argnames = ", ".join(argnames_)
50+
frame.write_line(f"function {func_name}({argnames})")
51+
with frame.indent():
52+
for block in node.body.blocks:
53+
frame.current_block = block
54+
frame.write_line(f"@label {frame.block[block]}")
55+
for arg in block.args:
56+
frame.set(arg, frame.ssa[arg])
57+
58+
for stmt in block.stmts:
59+
frame.current_stmt = stmt
60+
stmt_results = emit_.frame_eval(frame, stmt)
61+
62+
match stmt_results:
63+
case tuple():
64+
frame.set_values(stmt._results, stmt_results)
65+
case _:
66+
continue
67+
frame.write_line("end\n")

src/kirin/dialects/ilist/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77

88
from . import (
9+
_julia as _julia,
910
interp as interp,
1011
rewrite as rewrite,
1112
lowering as lowering,

src/kirin/dialects/ilist/_julia.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from __future__ import annotations
2+
3+
from kirin import emit, interp
4+
5+
from .stmts import Range
6+
from ._dialect import dialect
7+
8+
9+
@dialect.register(key="emit.julia")
10+
class JuliaMethodTable(interp.MethodTable):
11+
12+
@interp.impl(Range)
13+
def range(self, emit_: emit.Julia, frame: emit.JuliaFrame, node: Range):
14+
start = frame.get(node.start)
15+
stop = frame.get(node.stop)
16+
step = frame.get(node.step)
17+
frame.write_line(f"{frame.ssa[node.result]} = {start}:{step}:{stop}")
18+
return (frame.ssa[node.result],)

src/kirin/dialects/py/base.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
import ast
88

9-
from kirin import ir, lowering
9+
from kirin import ir, interp, lowering
1010

1111
dialect = ir.Dialect("py.base")
1212

@@ -28,3 +28,11 @@ def lower_Name(self, state: lowering.State, node: ast.Name) -> lowering.Result:
2828

2929
def lower_Expr(self, state: lowering.State, node: ast.Expr) -> lowering.Result:
3030
return state.parent.visit(state, node.value)
31+
32+
33+
@dialect.register(key="emit.julia")
34+
class PyAttrMethod(interp.MethodTable):
35+
36+
@interp.impl(ir.PyAttr)
37+
def py_attr(self, interp, frame: interp.Frame, node: ir.PyAttr):
38+
return repr(node.data)

src/kirin/dialects/py/binop/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@
1313
`Mod`, `Pow`, `LShift`, `RShift`, `BitOr`, `BitXor`, and `BitAnd` statements.
1414
"""
1515

16-
from . import interp as interp, lowering as lowering, typeinfer as typeinfer
16+
from . import (
17+
_julia as _julia,
18+
interp as interp,
19+
lowering as lowering,
20+
typeinfer as typeinfer,
21+
)
1722
from .stmts import * # noqa: F403
1823
from ._dialect import dialect as dialect
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from kirin import emit, interp
2+
3+
from .stmts import Add
4+
from ._dialect import dialect
5+
6+
7+
@dialect.register(key="emit.julia")
8+
class JuliaEmit(interp.MethodTable):
9+
@interp.impl(Add)
10+
def add(self, emit_: emit.Julia, frame: emit.JuliaFrame, node: Add):
11+
lhs = frame.get(node.lhs)
12+
rhs = frame.get(node.rhs)
13+
frame.write_line(f"{frame.ssa[node.result]} = ({lhs} + {rhs})")
14+
return (frame.ssa[node.result],)

src/kirin/dialects/py/cmp/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@
1111
`Gt`, `GtE`, `Is`, and `IsNot` statements.
1212
"""
1313

14-
from . import interp as interp, lowering as lowering
14+
from . import _julia as _julia, interp as interp, lowering as lowering
1515
from .stmts import * # noqa: F403
1616
from ._dialect import dialect as dialect
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
from kirin import emit, interp
2+
3+
from .stmts import Eq
4+
from ._dialect import dialect
5+
6+
7+
@dialect.register(key="emit.julia")
8+
class JuliaEmit(interp.MethodTable):
9+
@interp.impl(Eq)
10+
def add(self, emit_: emit.Julia, frame: emit.JuliaFrame, node: Eq):
11+
lhs = frame.get(node.lhs)
12+
rhs = frame.get(node.rhs)
13+
frame.write_line(f"{frame.ssa[node.result]} = ({lhs} == {rhs})")
14+
return (frame.ssa[node.result],)

0 commit comments

Comments
 (0)