From c54b118396f06fb6bcc706d1226d8909a038c743 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Tue, 28 Oct 2025 16:59:47 -0400 Subject: [PATCH 1/5] move Julia's symboltable to `EmitABC` to be reused. Rename to `EmitTable`. --- src/kirin/emit/abc.py | 34 +++++++++++++++++++++++++++++++++- src/kirin/emit/julia.py | 31 ------------------------------- 2 files changed, 33 insertions(+), 32 deletions(-) diff --git a/src/kirin/emit/abc.py b/src/kirin/emit/abc.py index 05b534fc5..cd10e038e 100644 --- a/src/kirin/emit/abc.py +++ b/src/kirin/emit/abc.py @@ -2,10 +2,11 @@ from abc import ABC, abstractmethod from typing import TypeVar -from dataclasses import dataclass +from dataclasses import field, dataclass from kirin import ir from kirin.interp import Frame, abc +from kirin.idtable import IdTable TargetType = TypeVar("TargetType") @@ -18,11 +19,42 @@ class EmitFrame(Frame[TargetType]): CodeGenFrameType = TypeVar("CodeGenFrameType", bound=EmitFrame) +@dataclass +class EmitTable(IdTable[ir.Statement]): + + def add(self, value: ir.Statement) -> str: + id = self.next_id + if (trait := value.get_trait(ir.SymbolOpInterface)) is not None: + value_name = trait.get_sym_name(value).unwrap() + curr_ind = self.name_count.get(value_name, 0) + suffix = f"_{curr_ind}" if curr_ind != 0 else "" + self.name_count[value_name] = curr_ind + 1 + name = self.prefix + value_name + suffix + self.table[value] = name + else: + name = f"{self.prefix}{self.prefix_if_none}{id}" + self.next_id += 1 + self.table[value] = name + return name + + def __getitem__(self, value: ir.Statement) -> str: + if value in self.table: + return self.table[value] + raise KeyError(f"Symbol {value} not found in SymbolTable") + + def get(self, value: ir.Statement, default: str | None = None) -> str | None: + if value in self.table: + return self.table[value] + return default + + @dataclass class EmitABC(abc.InterpreterABC[CodeGenFrameType, TargetType], ABC): + callables: EmitTable = field(init=False) def __init_subclass__(cls) -> None: super().__init_subclass__() + cls.callables = EmitTable(prefix="_callable_") for each in getattr(cls, "keys", ()): if not each.startswith("emit."): raise ValueError(f"Key {each} cannot start with 'emit.'") diff --git a/src/kirin/emit/julia.py b/src/kirin/emit/julia.py index 859c03b62..714f5894f 100644 --- a/src/kirin/emit/julia.py +++ b/src/kirin/emit/julia.py @@ -13,35 +13,6 @@ IO_t = TypeVar("IO_t", bound=IO) -@dataclass -class SymbolTable(IdTable[ir.Statement]): - - def add(self, value: ir.Statement) -> str: - id = self.next_id - if (trait := value.get_trait(ir.SymbolOpInterface)) is not None: - value_name = trait.get_sym_name(value).unwrap() - curr_ind = self.name_count.get(value_name, 0) - suffix = f"_{curr_ind}" if curr_ind != 0 else "" - self.name_count[value_name] = curr_ind + 1 - name = self.prefix + value_name + suffix - self.table[value] = name - else: - name = f"{self.prefix}{self.prefix_if_none}{id}" - self.next_id += 1 - self.table[value] = name - return name - - def __getitem__(self, value: ir.Statement) -> str: - if value in self.table: - return self.table[value] - raise KeyError(f"Symbol {value} not found in SymbolTable") - - def get(self, value: ir.Statement, default: str | None = None) -> str | None: - if value in self.table: - return self.table[value] - return default - - @dataclass class JuliaFrame(EmitFrame[str], Generic[IO_t]): io: IO_t @@ -79,12 +50,10 @@ class Julia(EmitABC[JuliaFrame, str], Generic[IO_t]): # some states io: IO_t - callables: SymbolTable = field(init=False) callable_to_emit: WorkList[ir.Statement] = field(init=False) def initialize(self): super().initialize() - self.callables = SymbolTable(prefix="_callable_") self.callable_to_emit = WorkList() return self From 84506c9ad3e00ec9d77bdd4d543cbb3919b2b48f Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Tue, 28 Oct 2025 17:28:06 -0400 Subject: [PATCH 2/5] Moved worklist and callables into base `EmitABC`. --- src/kirin/emit/abc.py | 7 +++++-- src/kirin/emit/julia.py | 3 --- test/emit/julia_like.jl | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/kirin/emit/abc.py b/src/kirin/emit/abc.py index cd10e038e..9c70625a7 100644 --- a/src/kirin/emit/abc.py +++ b/src/kirin/emit/abc.py @@ -7,6 +7,7 @@ from kirin import ir from kirin.interp import Frame, abc from kirin.idtable import IdTable +from kirin.worklist import WorkList TargetType = TypeVar("TargetType") @@ -51,13 +52,15 @@ def get(self, value: ir.Statement, default: str | None = None) -> str | None: @dataclass class EmitABC(abc.InterpreterABC[CodeGenFrameType, TargetType], ABC): callables: EmitTable = field(init=False) + callable_to_emit: WorkList[ir.Statement] = field(init=False) def __init_subclass__(cls) -> None: super().__init_subclass__() - cls.callables = EmitTable(prefix="_callable_") + cls.callables = EmitTable(prefix="") + cls.callable_to_emit = WorkList() for each in getattr(cls, "keys", ()): if not each.startswith("emit."): - raise ValueError(f"Key {each} cannot start with 'emit.'") + raise ValueError(f"Key {each} does not start with 'emit.'") @abstractmethod def run(self, node: ir.Method | ir.Statement): ... diff --git a/src/kirin/emit/julia.py b/src/kirin/emit/julia.py index 714f5894f..46acee01d 100644 --- a/src/kirin/emit/julia.py +++ b/src/kirin/emit/julia.py @@ -6,7 +6,6 @@ from kirin import ir, interp from kirin.idtable import IdTable -from kirin.worklist import WorkList from .abc import EmitABC, EmitFrame @@ -50,11 +49,9 @@ class Julia(EmitABC[JuliaFrame, str], Generic[IO_t]): # some states io: IO_t - callable_to_emit: WorkList[ir.Statement] = field(init=False) def initialize(self): super().initialize() - self.callable_to_emit = WorkList() return self def initialize_frame( diff --git a/test/emit/julia_like.jl b/test/emit/julia_like.jl index ec410108c..426c04226 100644 --- a/test/emit/julia_like.jl +++ b/test/emit/julia_like.jl @@ -1,4 +1,4 @@ -function _callable_julia_like(ssa_x, ssa_y) +function julia_like(ssa_x, ssa_y) @label block_0 ssa_0 = 0:1:ssa_x local ssa_y_1 @@ -27,11 +27,11 @@ function _callable_julia_like(ssa_x, ssa_y) ssa_y_2 = ssa_y_3 end ssa_5 = (ssa_x + ssa_y_1) - ssa_6 = _callable_some_arith(ssa_5, 4.0) + ssa_6 = some_arith(ssa_5, 4.0) return ssa_6 end -function _callable_some_arith(ssa_x, ssa_y) +function some_arith(ssa_x, ssa_y) @label block_0 ssa_0 = (ssa_x + ssa_y) return ssa_0 From 3fd6dfbc063559ce943c2705509c4ec80f9c4b59 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Tue, 28 Oct 2025 17:50:03 -0400 Subject: [PATCH 3/5] moved `run` method to `EmitABC`. --- src/kirin/emit/abc.py | 17 ++++++++++++++--- src/kirin/emit/julia.py | 15 --------------- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/src/kirin/emit/abc.py b/src/kirin/emit/abc.py index 9c70625a7..9fd84f071 100644 --- a/src/kirin/emit/abc.py +++ b/src/kirin/emit/abc.py @@ -1,6 +1,6 @@ from __future__ import annotations -from abc import ABC, abstractmethod +from abc import ABC from typing import TypeVar from dataclasses import field, dataclass @@ -62,5 +62,16 @@ def __init_subclass__(cls) -> None: if not each.startswith("emit."): raise ValueError(f"Key {each} does not start with 'emit.'") - @abstractmethod - def run(self, node: ir.Method | ir.Statement): ... + def run(self, node: ir.Method | ir.Statement): + if isinstance(node, ir.Method): + node = node.code + + with self.eval_context(): + self.callables.add(node) + self.callable_to_emit.append(node) + while self.callable_to_emit: + callable = self.callable_to_emit.pop() + if callable is None: + break + self.eval(callable) + return diff --git a/src/kirin/emit/julia.py b/src/kirin/emit/julia.py index 46acee01d..ea2af3a7d 100644 --- a/src/kirin/emit/julia.py +++ b/src/kirin/emit/julia.py @@ -59,21 +59,6 @@ def initialize_frame( ) -> JuliaFrame: return JuliaFrame(node, self.io, has_parent_access=has_parent_access) - def run(self, node: ir.Method | ir.Statement): - if isinstance(node, ir.Method): - node = node.code - - with self.eval_context(): - self.callables.add(node) - self.callable_to_emit.append(node) - while self.callable_to_emit: - callable = self.callable_to_emit.pop() - if callable is None: - break - self.eval(callable) - self.io.flush() - return - def frame_call( self, frame: JuliaFrame, node: ir.Statement, *args: str, **kwargs: str ) -> str: From ac089ee93919ca3fc395106f2c13c6025b50d109 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Wed, 29 Oct 2025 10:39:00 -0400 Subject: [PATCH 4/5] Moved common `ssa_idtable` and `block_idtable` into base frame class. --- src/kirin/emit/abc.py | 19 ++++++++++++++++++- src/kirin/emit/julia.py | 15 ++++----------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/src/kirin/emit/abc.py b/src/kirin/emit/abc.py index 9fd84f071..b8edf87ee 100644 --- a/src/kirin/emit/abc.py +++ b/src/kirin/emit/abc.py @@ -2,6 +2,7 @@ from abc import ABC from typing import TypeVar +from contextlib import contextmanager from dataclasses import field, dataclass from kirin import ir @@ -14,7 +15,23 @@ @dataclass class EmitFrame(Frame[TargetType]): - pass + ssa: IdTable[ir.SSAValue] = field( + default_factory=lambda: IdTable[ir.SSAValue](prefix="ssa_"), + init=False, + ) + block: IdTable[ir.Block] = field( + default_factory=lambda: IdTable[ir.Block](prefix="block_"), + init=False, + ) + _indent: int = field(default=0, init=False) + + @contextmanager + def indent(self): + self._indent += 1 + try: + yield + finally: + self._indent -= 1 CodeGenFrameType = TypeVar("CodeGenFrameType", bound=EmitFrame) diff --git a/src/kirin/emit/julia.py b/src/kirin/emit/julia.py index ea2af3a7d..0cc86d9ff 100644 --- a/src/kirin/emit/julia.py +++ b/src/kirin/emit/julia.py @@ -1,11 +1,11 @@ from __future__ import annotations -from typing import IO, Generic, TypeVar +import sys +from typing import IO, Generic, TypeVar, cast from contextlib import contextmanager -from dataclasses import field, dataclass +from dataclasses import dataclass from kirin import ir, interp -from kirin.idtable import IdTable from .abc import EmitABC, EmitFrame @@ -14,14 +14,7 @@ @dataclass class JuliaFrame(EmitFrame[str], Generic[IO_t]): - io: IO_t - ssa: IdTable[ir.SSAValue] = field( - default_factory=lambda: IdTable[ir.SSAValue](prefix="ssa_") - ) - block: IdTable[ir.Block] = field( - default_factory=lambda: IdTable[ir.Block](prefix="block_") - ) - _indent: int = 0 + io: IO_t = cast(IO_t, sys.stdout) def write(self, value): self.io.write(value) From 2383e40ecdc81e2f526885da05ea8f21add4a633 Mon Sep 17 00:00:00 2001 From: Dennis Liew Date: Wed, 29 Oct 2025 13:35:10 -0400 Subject: [PATCH 5/5] Add reset method to EmitABC and implement in Julia class --- src/kirin/emit/abc.py | 8 +++++++- src/kirin/emit/julia.py | 4 ++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/kirin/emit/abc.py b/src/kirin/emit/abc.py index b8edf87ee..6f6f18ab8 100644 --- a/src/kirin/emit/abc.py +++ b/src/kirin/emit/abc.py @@ -1,6 +1,6 @@ from __future__ import annotations -from abc import ABC +from abc import ABC, abstractmethod from typing import TypeVar from contextlib import contextmanager from dataclasses import field, dataclass @@ -80,6 +80,7 @@ def __init_subclass__(cls) -> None: raise ValueError(f"Key {each} does not start with 'emit.'") def run(self, node: ir.Method | ir.Statement): + self.reset() if isinstance(node, ir.Method): node = node.code @@ -92,3 +93,8 @@ def run(self, node: ir.Method | ir.Statement): break self.eval(callable) return + + @abstractmethod + def reset(self): + """Reset any per-run state in the emitter.""" + pass diff --git a/src/kirin/emit/julia.py b/src/kirin/emit/julia.py index 0cc86d9ff..61973e913 100644 --- a/src/kirin/emit/julia.py +++ b/src/kirin/emit/julia.py @@ -62,3 +62,7 @@ def get_attribute(self, frame: JuliaFrame, node: ir.Attribute) -> str: if method is None: raise ValueError(f"Method not found for node: {node}") return method(self, frame, node) + + def reset(self): + self.io.truncate(0) + self.io.seek(0)