diff --git a/src/kirin/emit/abc.py b/src/kirin/emit/abc.py index 05b534fc5..6f6f18ab8 100644 --- a/src/kirin/emit/abc.py +++ b/src/kirin/emit/abc.py @@ -2,30 +2,99 @@ from abc import ABC, abstractmethod from typing import TypeVar -from dataclasses import dataclass +from contextlib import contextmanager +from dataclasses import field, dataclass from kirin import ir from kirin.interp import Frame, abc +from kirin.idtable import IdTable +from kirin.worklist import WorkList TargetType = TypeVar("TargetType") @dataclass class EmitFrame(Frame[TargetType]): - pass + ssa: IdTable[ir.SSAValue] = field( + default_factory=lambda: IdTable[ir.SSAValue](prefix="ssa_"), + init=False, + ) + block: IdTable[ir.Block] = field( + default_factory=lambda: IdTable[ir.Block](prefix="block_"), + init=False, + ) + _indent: int = field(default=0, init=False) + + @contextmanager + def indent(self): + self._indent += 1 + try: + yield + finally: + self._indent -= 1 CodeGenFrameType = TypeVar("CodeGenFrameType", bound=EmitFrame) +@dataclass +class EmitTable(IdTable[ir.Statement]): + + def add(self, value: ir.Statement) -> str: + id = self.next_id + if (trait := value.get_trait(ir.SymbolOpInterface)) is not None: + value_name = trait.get_sym_name(value).unwrap() + curr_ind = self.name_count.get(value_name, 0) + suffix = f"_{curr_ind}" if curr_ind != 0 else "" + self.name_count[value_name] = curr_ind + 1 + name = self.prefix + value_name + suffix + self.table[value] = name + else: + name = f"{self.prefix}{self.prefix_if_none}{id}" + self.next_id += 1 + self.table[value] = name + return name + + def __getitem__(self, value: ir.Statement) -> str: + if value in self.table: + return self.table[value] + raise KeyError(f"Symbol {value} not found in SymbolTable") + + def get(self, value: ir.Statement, default: str | None = None) -> str | None: + if value in self.table: + return self.table[value] + return default + + @dataclass class EmitABC(abc.InterpreterABC[CodeGenFrameType, TargetType], ABC): + callables: EmitTable = field(init=False) + callable_to_emit: WorkList[ir.Statement] = field(init=False) def __init_subclass__(cls) -> None: super().__init_subclass__() + cls.callables = EmitTable(prefix="") + cls.callable_to_emit = WorkList() for each in getattr(cls, "keys", ()): if not each.startswith("emit."): - raise ValueError(f"Key {each} cannot start with 'emit.'") + raise ValueError(f"Key {each} does not start with 'emit.'") + + def run(self, node: ir.Method | ir.Statement): + self.reset() + if isinstance(node, ir.Method): + node = node.code + + with self.eval_context(): + self.callables.add(node) + self.callable_to_emit.append(node) + while self.callable_to_emit: + callable = self.callable_to_emit.pop() + if callable is None: + break + self.eval(callable) + return @abstractmethod - def run(self, node: ir.Method | ir.Statement): ... + def reset(self): + """Reset any per-run state in the emitter.""" + pass diff --git a/src/kirin/emit/julia.py b/src/kirin/emit/julia.py index 859c03b62..61973e913 100644 --- a/src/kirin/emit/julia.py +++ b/src/kirin/emit/julia.py @@ -1,57 +1,20 @@ from __future__ import annotations -from typing import IO, Generic, TypeVar +import sys +from typing import IO, Generic, TypeVar, cast from contextlib import contextmanager -from dataclasses import field, dataclass +from dataclasses import dataclass from kirin import ir, interp -from kirin.idtable import IdTable -from kirin.worklist import WorkList from .abc import EmitABC, EmitFrame IO_t = TypeVar("IO_t", bound=IO) -@dataclass -class SymbolTable(IdTable[ir.Statement]): - - def add(self, value: ir.Statement) -> str: - id = self.next_id - if (trait := value.get_trait(ir.SymbolOpInterface)) is not None: - value_name = trait.get_sym_name(value).unwrap() - curr_ind = self.name_count.get(value_name, 0) - suffix = f"_{curr_ind}" if curr_ind != 0 else "" - self.name_count[value_name] = curr_ind + 1 - name = self.prefix + value_name + suffix - self.table[value] = name - else: - name = f"{self.prefix}{self.prefix_if_none}{id}" - self.next_id += 1 - self.table[value] = name - return name - - def __getitem__(self, value: ir.Statement) -> str: - if value in self.table: - return self.table[value] - raise KeyError(f"Symbol {value} not found in SymbolTable") - - def get(self, value: ir.Statement, default: str | None = None) -> str | None: - if value in self.table: - return self.table[value] - return default - - @dataclass class JuliaFrame(EmitFrame[str], Generic[IO_t]): - io: IO_t - ssa: IdTable[ir.SSAValue] = field( - default_factory=lambda: IdTable[ir.SSAValue](prefix="ssa_") - ) - block: IdTable[ir.Block] = field( - default_factory=lambda: IdTable[ir.Block](prefix="block_") - ) - _indent: int = 0 + io: IO_t = cast(IO_t, sys.stdout) def write(self, value): self.io.write(value) @@ -79,13 +42,9 @@ class Julia(EmitABC[JuliaFrame, str], Generic[IO_t]): # some states io: IO_t - callables: SymbolTable = field(init=False) - callable_to_emit: WorkList[ir.Statement] = field(init=False) def initialize(self): super().initialize() - self.callables = SymbolTable(prefix="_callable_") - self.callable_to_emit = WorkList() return self def initialize_frame( @@ -93,21 +52,6 @@ def initialize_frame( ) -> JuliaFrame: return JuliaFrame(node, self.io, has_parent_access=has_parent_access) - def run(self, node: ir.Method | ir.Statement): - if isinstance(node, ir.Method): - node = node.code - - with self.eval_context(): - self.callables.add(node) - self.callable_to_emit.append(node) - while self.callable_to_emit: - callable = self.callable_to_emit.pop() - if callable is None: - break - self.eval(callable) - self.io.flush() - return - def frame_call( self, frame: JuliaFrame, node: ir.Statement, *args: str, **kwargs: str ) -> str: @@ -118,3 +62,7 @@ def get_attribute(self, frame: JuliaFrame, node: ir.Attribute) -> str: if method is None: raise ValueError(f"Method not found for node: {node}") return method(self, frame, node) + + def reset(self): + self.io.truncate(0) + self.io.seek(0) diff --git a/test/emit/julia_like.jl b/test/emit/julia_like.jl index ec410108c..426c04226 100644 --- a/test/emit/julia_like.jl +++ b/test/emit/julia_like.jl @@ -1,4 +1,4 @@ -function _callable_julia_like(ssa_x, ssa_y) +function julia_like(ssa_x, ssa_y) @label block_0 ssa_0 = 0:1:ssa_x local ssa_y_1 @@ -27,11 +27,11 @@ function _callable_julia_like(ssa_x, ssa_y) ssa_y_2 = ssa_y_3 end ssa_5 = (ssa_x + ssa_y_1) - ssa_6 = _callable_some_arith(ssa_5, 4.0) + ssa_6 = some_arith(ssa_5, 4.0) return ssa_6 end -function _callable_some_arith(ssa_x, ssa_y) +function some_arith(ssa_x, ssa_y) @label block_0 ssa_0 = (ssa_x + ssa_y) return ssa_0