11from __future__ import annotations
22
3- from typing import IO , Generic , TypeVar
3+ import sys
4+ from typing import IO , Generic , TypeVar , cast
45from contextlib import contextmanager
5- from dataclasses import field , dataclass
6+ from dataclasses import dataclass
67
78from kirin import ir , interp
8- from kirin .idtable import IdTable
9- from kirin .worklist import WorkList
109
1110from .abc import EmitABC , EmitFrame
1211
1312IO_t = TypeVar ("IO_t" , bound = IO )
1413
1514
16- @dataclass
17- class SymbolTable (IdTable [ir .Statement ]):
18-
19- def add (self , value : ir .Statement ) -> str :
20- id = self .next_id
21- if (trait := value .get_trait (ir .SymbolOpInterface )) is not None :
22- value_name = trait .get_sym_name (value ).unwrap ()
23- curr_ind = self .name_count .get (value_name , 0 )
24- suffix = f"_{ curr_ind } " if curr_ind != 0 else ""
25- self .name_count [value_name ] = curr_ind + 1
26- name = self .prefix + value_name + suffix
27- self .table [value ] = name
28- else :
29- name = f"{ self .prefix } { self .prefix_if_none } { id } "
30- self .next_id += 1
31- self .table [value ] = name
32- return name
33-
34- def __getitem__ (self , value : ir .Statement ) -> str :
35- if value in self .table :
36- return self .table [value ]
37- raise KeyError (f"Symbol { value } not found in SymbolTable" )
38-
39- def get (self , value : ir .Statement , default : str | None = None ) -> str | None :
40- if value in self .table :
41- return self .table [value ]
42- return default
43-
44-
4515@dataclass
4616class JuliaFrame (EmitFrame [str ], Generic [IO_t ]):
47- io : IO_t
48- ssa : IdTable [ir .SSAValue ] = field (
49- default_factory = lambda : IdTable [ir .SSAValue ](prefix = "ssa_" )
50- )
51- block : IdTable [ir .Block ] = field (
52- default_factory = lambda : IdTable [ir .Block ](prefix = "block_" )
53- )
54- _indent : int = 0
17+ io : IO_t = cast (IO_t , sys .stdout )
5518
5619 def write (self , value ):
5720 self .io .write (value )
@@ -79,35 +42,16 @@ class Julia(EmitABC[JuliaFrame, str], Generic[IO_t]):
7942
8043 # some states
8144 io : IO_t
82- callables : SymbolTable = field (init = False )
83- callable_to_emit : WorkList [ir .Statement ] = field (init = False )
8445
8546 def initialize (self ):
8647 super ().initialize ()
87- self .callables = SymbolTable (prefix = "_callable_" )
88- self .callable_to_emit = WorkList ()
8948 return self
9049
9150 def initialize_frame (
9251 self , node : ir .Statement , * , has_parent_access : bool = False
9352 ) -> JuliaFrame :
9453 return JuliaFrame (node , self .io , has_parent_access = has_parent_access )
9554
96- def run (self , node : ir .Method | ir .Statement ):
97- if isinstance (node , ir .Method ):
98- node = node .code
99-
100- with self .eval_context ():
101- self .callables .add (node )
102- self .callable_to_emit .append (node )
103- while self .callable_to_emit :
104- callable = self .callable_to_emit .pop ()
105- if callable is None :
106- break
107- self .eval (callable )
108- self .io .flush ()
109- return
110-
11155 def frame_call (
11256 self , frame : JuliaFrame , node : ir .Statement , * args : str , ** kwargs : str
11357 ) -> str :
@@ -118,3 +62,7 @@ def get_attribute(self, frame: JuliaFrame, node: ir.Attribute) -> str:
11862 if method is None :
11963 raise ValueError (f"Method not found for node: { node } " )
12064 return method (self , frame , node )
65+
66+ def reset (self ):
67+ self .io .truncate (0 )
68+ self .io .seek (0 )
0 commit comments