Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 73 additions & 4 deletions src/kirin/emit/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,99 @@

from abc import ABC, abstractmethod
from typing import TypeVar
from dataclasses import dataclass
from contextlib import contextmanager
from dataclasses import field, dataclass

from kirin import ir
from kirin.interp import Frame, abc
from kirin.idtable import IdTable
from kirin.worklist import WorkList

TargetType = TypeVar("TargetType")


@dataclass
class EmitFrame(Frame[TargetType]):
pass
ssa: IdTable[ir.SSAValue] = field(
default_factory=lambda: IdTable[ir.SSAValue](prefix="ssa_"),
init=False,
)
block: IdTable[ir.Block] = field(
default_factory=lambda: IdTable[ir.Block](prefix="block_"),
init=False,
)
_indent: int = field(default=0, init=False)

@contextmanager
def indent(self):
self._indent += 1
try:
yield
finally:
self._indent -= 1


CodeGenFrameType = TypeVar("CodeGenFrameType", bound=EmitFrame)


@dataclass
class EmitTable(IdTable[ir.Statement]):

def add(self, value: ir.Statement) -> str:
id = self.next_id
if (trait := value.get_trait(ir.SymbolOpInterface)) is not None:
value_name = trait.get_sym_name(value).unwrap()
curr_ind = self.name_count.get(value_name, 0)
suffix = f"_{curr_ind}" if curr_ind != 0 else ""
self.name_count[value_name] = curr_ind + 1
name = self.prefix + value_name + suffix
self.table[value] = name
else:
name = f"{self.prefix}{self.prefix_if_none}{id}"
self.next_id += 1
self.table[value] = name
return name

def __getitem__(self, value: ir.Statement) -> str:
if value in self.table:
return self.table[value]
raise KeyError(f"Symbol {value} not found in SymbolTable")

def get(self, value: ir.Statement, default: str | None = None) -> str | None:
if value in self.table:
return self.table[value]
return default


@dataclass
class EmitABC(abc.InterpreterABC[CodeGenFrameType, TargetType], ABC):
callables: EmitTable = field(init=False)
callable_to_emit: WorkList[ir.Statement] = field(init=False)

def __init_subclass__(cls) -> None:
super().__init_subclass__()
cls.callables = EmitTable(prefix="")
cls.callable_to_emit = WorkList()
for each in getattr(cls, "keys", ()):
if not each.startswith("emit."):
raise ValueError(f"Key {each} cannot start with 'emit.'")
raise ValueError(f"Key {each} does not start with 'emit.'")

def run(self, node: ir.Method | ir.Statement):
self.reset()
if isinstance(node, ir.Method):
node = node.code

with self.eval_context():
self.callables.add(node)
self.callable_to_emit.append(node)
while self.callable_to_emit:
callable = self.callable_to_emit.pop()
if callable is None:
break
self.eval(callable)
return

@abstractmethod
def run(self, node: ir.Method | ir.Statement): ...
def reset(self):
"""Reset any per-run state in the emitter."""
pass
68 changes: 8 additions & 60 deletions src/kirin/emit/julia.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,20 @@
from __future__ import annotations

from typing import IO, Generic, TypeVar
import sys
from typing import IO, Generic, TypeVar, cast
from contextlib import contextmanager
from dataclasses import field, dataclass
from dataclasses import dataclass

from kirin import ir, interp
from kirin.idtable import IdTable
from kirin.worklist import WorkList

from .abc import EmitABC, EmitFrame

IO_t = TypeVar("IO_t", bound=IO)


@dataclass
class SymbolTable(IdTable[ir.Statement]):

def add(self, value: ir.Statement) -> str:
id = self.next_id
if (trait := value.get_trait(ir.SymbolOpInterface)) is not None:
value_name = trait.get_sym_name(value).unwrap()
curr_ind = self.name_count.get(value_name, 0)
suffix = f"_{curr_ind}" if curr_ind != 0 else ""
self.name_count[value_name] = curr_ind + 1
name = self.prefix + value_name + suffix
self.table[value] = name
else:
name = f"{self.prefix}{self.prefix_if_none}{id}"
self.next_id += 1
self.table[value] = name
return name

def __getitem__(self, value: ir.Statement) -> str:
if value in self.table:
return self.table[value]
raise KeyError(f"Symbol {value} not found in SymbolTable")

def get(self, value: ir.Statement, default: str | None = None) -> str | None:
if value in self.table:
return self.table[value]
return default


@dataclass
class JuliaFrame(EmitFrame[str], Generic[IO_t]):
io: IO_t
ssa: IdTable[ir.SSAValue] = field(
default_factory=lambda: IdTable[ir.SSAValue](prefix="ssa_")
)
block: IdTable[ir.Block] = field(
default_factory=lambda: IdTable[ir.Block](prefix="block_")
)
_indent: int = 0
io: IO_t = cast(IO_t, sys.stdout)

def write(self, value):
self.io.write(value)
Expand Down Expand Up @@ -79,35 +42,16 @@ class Julia(EmitABC[JuliaFrame, str], Generic[IO_t]):

# some states
io: IO_t
callables: SymbolTable = field(init=False)
callable_to_emit: WorkList[ir.Statement] = field(init=False)

def initialize(self):
super().initialize()
self.callables = SymbolTable(prefix="_callable_")
self.callable_to_emit = WorkList()
return self

def initialize_frame(
self, node: ir.Statement, *, has_parent_access: bool = False
) -> JuliaFrame:
return JuliaFrame(node, self.io, has_parent_access=has_parent_access)

def run(self, node: ir.Method | ir.Statement):
if isinstance(node, ir.Method):
node = node.code

with self.eval_context():
self.callables.add(node)
self.callable_to_emit.append(node)
while self.callable_to_emit:
callable = self.callable_to_emit.pop()
if callable is None:
break
self.eval(callable)
self.io.flush()
return

def frame_call(
self, frame: JuliaFrame, node: ir.Statement, *args: str, **kwargs: str
) -> str:
Expand All @@ -118,3 +62,7 @@ def get_attribute(self, frame: JuliaFrame, node: ir.Attribute) -> str:
if method is None:
raise ValueError(f"Method not found for node: {node}")
return method(self, frame, node)

def reset(self):
self.io.truncate(0)
self.io.seek(0)
6 changes: 3 additions & 3 deletions test/emit/julia_like.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
function _callable_julia_like(ssa_x, ssa_y)
function julia_like(ssa_x, ssa_y)
@label block_0
ssa_0 = 0:1:ssa_x
local ssa_y_1
Expand Down Expand Up @@ -27,11 +27,11 @@ function _callable_julia_like(ssa_x, ssa_y)
ssa_y_2 = ssa_y_3
end
ssa_5 = (ssa_x + ssa_y_1)
ssa_6 = _callable_some_arith(ssa_5, 4.0)
ssa_6 = some_arith(ssa_5, 4.0)
return ssa_6
end

function _callable_some_arith(ssa_x, ssa_y)
function some_arith(ssa_x, ssa_y)
@label block_0
ssa_0 = (ssa_x + ssa_y)
return ssa_0
Expand Down