Skip to content

Commit 2992835

Browse files
authored
Moved common fields from concrete class to abstract Emit class (#555)
Simplified the code gen classes. Moved the symbol table to `EmitABC`. Moved `ssa` and `block` fields into base `EmitABC` class for reusability. Also moved `run` method. Changes made alongside QuEraComputing/bloqade-circuit#555
1 parent b727b55 commit 2992835

File tree

3 files changed

+84
-67
lines changed

3 files changed

+84
-67
lines changed

src/kirin/emit/abc.py

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,30 +2,99 @@
22

33
from abc import ABC, abstractmethod
44
from typing import TypeVar
5-
from dataclasses import dataclass
5+
from contextlib import contextmanager
6+
from dataclasses import field, dataclass
67

78
from kirin import ir
89
from kirin.interp import Frame, abc
10+
from kirin.idtable import IdTable
11+
from kirin.worklist import WorkList
912

1013
TargetType = TypeVar("TargetType")
1114

1215

1316
@dataclass
1417
class EmitFrame(Frame[TargetType]):
15-
pass
18+
ssa: IdTable[ir.SSAValue] = field(
19+
default_factory=lambda: IdTable[ir.SSAValue](prefix="ssa_"),
20+
init=False,
21+
)
22+
block: IdTable[ir.Block] = field(
23+
default_factory=lambda: IdTable[ir.Block](prefix="block_"),
24+
init=False,
25+
)
26+
_indent: int = field(default=0, init=False)
27+
28+
@contextmanager
29+
def indent(self):
30+
self._indent += 1
31+
try:
32+
yield
33+
finally:
34+
self._indent -= 1
1635

1736

1837
CodeGenFrameType = TypeVar("CodeGenFrameType", bound=EmitFrame)
1938

2039

40+
@dataclass
41+
class EmitTable(IdTable[ir.Statement]):
42+
43+
def add(self, value: ir.Statement) -> str:
44+
id = self.next_id
45+
if (trait := value.get_trait(ir.SymbolOpInterface)) is not None:
46+
value_name = trait.get_sym_name(value).unwrap()
47+
curr_ind = self.name_count.get(value_name, 0)
48+
suffix = f"_{curr_ind}" if curr_ind != 0 else ""
49+
self.name_count[value_name] = curr_ind + 1
50+
name = self.prefix + value_name + suffix
51+
self.table[value] = name
52+
else:
53+
name = f"{self.prefix}{self.prefix_if_none}{id}"
54+
self.next_id += 1
55+
self.table[value] = name
56+
return name
57+
58+
def __getitem__(self, value: ir.Statement) -> str:
59+
if value in self.table:
60+
return self.table[value]
61+
raise KeyError(f"Symbol {value} not found in SymbolTable")
62+
63+
def get(self, value: ir.Statement, default: str | None = None) -> str | None:
64+
if value in self.table:
65+
return self.table[value]
66+
return default
67+
68+
2169
@dataclass
2270
class EmitABC(abc.InterpreterABC[CodeGenFrameType, TargetType], ABC):
71+
callables: EmitTable = field(init=False)
72+
callable_to_emit: WorkList[ir.Statement] = field(init=False)
2373

2474
def __init_subclass__(cls) -> None:
2575
super().__init_subclass__()
76+
cls.callables = EmitTable(prefix="")
77+
cls.callable_to_emit = WorkList()
2678
for each in getattr(cls, "keys", ()):
2779
if not each.startswith("emit."):
28-
raise ValueError(f"Key {each} cannot start with 'emit.'")
80+
raise ValueError(f"Key {each} does not start with 'emit.'")
81+
82+
def run(self, node: ir.Method | ir.Statement):
83+
self.reset()
84+
if isinstance(node, ir.Method):
85+
node = node.code
86+
87+
with self.eval_context():
88+
self.callables.add(node)
89+
self.callable_to_emit.append(node)
90+
while self.callable_to_emit:
91+
callable = self.callable_to_emit.pop()
92+
if callable is None:
93+
break
94+
self.eval(callable)
95+
return
2996

3097
@abstractmethod
31-
def run(self, node: ir.Method | ir.Statement): ...
98+
def reset(self):
99+
"""Reset any per-run state in the emitter."""
100+
pass

src/kirin/emit/julia.py

Lines changed: 8 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,20 @@
11
from __future__ import annotations
22

3-
from typing import IO, Generic, TypeVar
3+
import sys
4+
from typing import IO, Generic, TypeVar, cast
45
from contextlib import contextmanager
5-
from dataclasses import field, dataclass
6+
from dataclasses import dataclass
67

78
from kirin import ir, interp
8-
from kirin.idtable import IdTable
9-
from kirin.worklist import WorkList
109

1110
from .abc import EmitABC, EmitFrame
1211

1312
IO_t = TypeVar("IO_t", bound=IO)
1413

1514

16-
@dataclass
17-
class SymbolTable(IdTable[ir.Statement]):
18-
19-
def add(self, value: ir.Statement) -> str:
20-
id = self.next_id
21-
if (trait := value.get_trait(ir.SymbolOpInterface)) is not None:
22-
value_name = trait.get_sym_name(value).unwrap()
23-
curr_ind = self.name_count.get(value_name, 0)
24-
suffix = f"_{curr_ind}" if curr_ind != 0 else ""
25-
self.name_count[value_name] = curr_ind + 1
26-
name = self.prefix + value_name + suffix
27-
self.table[value] = name
28-
else:
29-
name = f"{self.prefix}{self.prefix_if_none}{id}"
30-
self.next_id += 1
31-
self.table[value] = name
32-
return name
33-
34-
def __getitem__(self, value: ir.Statement) -> str:
35-
if value in self.table:
36-
return self.table[value]
37-
raise KeyError(f"Symbol {value} not found in SymbolTable")
38-
39-
def get(self, value: ir.Statement, default: str | None = None) -> str | None:
40-
if value in self.table:
41-
return self.table[value]
42-
return default
43-
44-
4515
@dataclass
4616
class JuliaFrame(EmitFrame[str], Generic[IO_t]):
47-
io: IO_t
48-
ssa: IdTable[ir.SSAValue] = field(
49-
default_factory=lambda: IdTable[ir.SSAValue](prefix="ssa_")
50-
)
51-
block: IdTable[ir.Block] = field(
52-
default_factory=lambda: IdTable[ir.Block](prefix="block_")
53-
)
54-
_indent: int = 0
17+
io: IO_t = cast(IO_t, sys.stdout)
5518

5619
def write(self, value):
5720
self.io.write(value)
@@ -79,35 +42,16 @@ class Julia(EmitABC[JuliaFrame, str], Generic[IO_t]):
7942

8043
# some states
8144
io: IO_t
82-
callables: SymbolTable = field(init=False)
83-
callable_to_emit: WorkList[ir.Statement] = field(init=False)
8445

8546
def initialize(self):
8647
super().initialize()
87-
self.callables = SymbolTable(prefix="_callable_")
88-
self.callable_to_emit = WorkList()
8948
return self
9049

9150
def initialize_frame(
9251
self, node: ir.Statement, *, has_parent_access: bool = False
9352
) -> JuliaFrame:
9453
return JuliaFrame(node, self.io, has_parent_access=has_parent_access)
9554

96-
def run(self, node: ir.Method | ir.Statement):
97-
if isinstance(node, ir.Method):
98-
node = node.code
99-
100-
with self.eval_context():
101-
self.callables.add(node)
102-
self.callable_to_emit.append(node)
103-
while self.callable_to_emit:
104-
callable = self.callable_to_emit.pop()
105-
if callable is None:
106-
break
107-
self.eval(callable)
108-
self.io.flush()
109-
return
110-
11155
def frame_call(
11256
self, frame: JuliaFrame, node: ir.Statement, *args: str, **kwargs: str
11357
) -> str:
@@ -118,3 +62,7 @@ def get_attribute(self, frame: JuliaFrame, node: ir.Attribute) -> str:
11862
if method is None:
11963
raise ValueError(f"Method not found for node: {node}")
12064
return method(self, frame, node)
65+
66+
def reset(self):
67+
self.io.truncate(0)
68+
self.io.seek(0)

test/emit/julia_like.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function _callable_julia_like(ssa_x, ssa_y)
1+
function julia_like(ssa_x, ssa_y)
22
@label block_0
33
ssa_0 = 0:1:ssa_x
44
local ssa_y_1
@@ -27,11 +27,11 @@ function _callable_julia_like(ssa_x, ssa_y)
2727
ssa_y_2 = ssa_y_3
2828
end
2929
ssa_5 = (ssa_x + ssa_y_1)
30-
ssa_6 = _callable_some_arith(ssa_5, 4.0)
30+
ssa_6 = some_arith(ssa_5, 4.0)
3131
return ssa_6
3232
end
3333

34-
function _callable_some_arith(ssa_x, ssa_y)
34+
function some_arith(ssa_x, ssa_y)
3535
@label block_0
3636
ssa_0 = (ssa_x + ssa_y)
3737
return ssa_0

0 commit comments

Comments
 (0)