Skip to content

Commit 083fa13

Browse files
authored
remove eval_stmt add run_stmt_fallback (#149)
We have `eval_stmt` and `run_stmt` which has been two confusing APIs. The original intention was to separate the stage when registering the source info (which we don't have yet), and when to actually lookup the registery. I think this is not necessary now, to change how to update the source info just overload `run_stmt` otherwise update lookup behaviour with `lookup_registry` or define fallback via `run_stmt_fallback` cc: @kaihsin
1 parent 3c8793b commit 083fa13

File tree

3 files changed

+31
-17
lines changed

3 files changed

+31
-17
lines changed

src/kirin/analysis/const/prop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _try_eval_const_pure(
4848
try:
4949
_frame = self.interp.new_frame(frame.code)
5050
_frame.set_values(stmt.args, tuple(x.data for x in values))
51-
value = self.interp.eval_stmt(_frame, stmt)
51+
value = self.interp.run_stmt(_frame, stmt)
5252
if isinstance(value, tuple):
5353
return tuple(JointResult(Value(each), Pure()) for each in value)
5454
elif isinstance(value, interp.ReturnValue):
@@ -64,7 +64,7 @@ def _try_eval_const_pure(
6464
pass
6565
return (self.bottom,)
6666

67-
def eval_stmt(
67+
def run_stmt(
6868
self, frame: ForwardFrame[JointResult, ExtraFrameInfo], stmt: ir.Statement
6969
) -> interp.StatementResult[JointResult]:
7070
if stmt.has_trait(ir.ConstantLike):

src/kirin/analysis/typeinfer/analysis.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,9 @@ def _unwrap(self, value: types.TypeAttribute) -> types.TypeAttribute:
3030
return value.body
3131
return value
3232

33-
def eval_stmt(
33+
def run_stmt_fallback(
3434
self, frame: ForwardFrame[types.TypeAttribute, None], stmt: ir.Statement
35-
) -> interp.StatementResult[types.TypeAttribute]:
36-
method = self.lookup_registry(frame, stmt)
37-
if method is not None:
38-
return method(self, frame, stmt)
39-
35+
) -> tuple[types.TypeAttribute, ...] | interp.SpecialResult[types.TypeAttribute]:
4036
resolve = TypeResolution()
4137
for arg, value in zip(stmt.args, frame.get_values(stmt.args)):
4238
resolve.solve(arg.type, value)

src/kirin/interp/base.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,9 @@ def permute_values(
198198
return args
199199

200200
def run_stmt(self, frame: FrameType, stmt: Statement) -> StatementResult[ValueType]:
201-
"""Run a statement within the current frame
201+
"""Run a statement within the current frame. This is the entry
202+
point of running a statement. It will look up the statement implementation
203+
in the dialect registry, or optionally call a fallback implementation.
202204
203205
Args:
204206
frame: the current frame
@@ -208,8 +210,11 @@ def run_stmt(self, frame: FrameType, stmt: Statement) -> StatementResult[ValueTy
208210
StatementResult: the result of running the statement
209211
210212
Note:
211-
In the case of implementing the fallback, subclass this method,
212-
and filter the statement type you want to handle.
213+
Overload this method for the following reasons:
214+
- to change the source tracking information
215+
- to take control of how to run a statement
216+
- to change the implementation lookup behavior that cannot acheive
217+
by overloading [`lookup_registry`][kirin.interp.base.BaseInterpreter.lookup_registry]
213218
214219
Example:
215220
* implement an interpreter that only handles MyStmt:
@@ -225,19 +230,32 @@ def run_stmt(self, frame: FrameType, stmt: Statement) -> StatementResult[ValueTy
225230
226231
"""
227232
# TODO: update tracking information
228-
return self.eval_stmt(frame, stmt)
229-
230-
def eval_stmt(
231-
self, frame: FrameType, stmt: Statement
232-
) -> StatementResult[ValueType]:
233-
"simply evaluate a statement"
234233
method = self.lookup_registry(frame, stmt)
235234
if method is not None:
236235
try:
237236
return method(self, frame, stmt)
238237
except InterpreterError as e:
239238
return Err(e, self.state.frames)
240239

240+
return self.run_stmt_fallback(frame, stmt)
241+
242+
def run_stmt_fallback(
243+
self, frame: FrameType, stmt: Statement
244+
) -> StatementResult[ValueType]:
245+
"""The fallback implementation of statements.
246+
247+
This is called when no implementation is found for the statement.
248+
249+
Args:
250+
frame: the current frame
251+
stmt: the statement to run
252+
253+
Returns:
254+
StatementResult: the result of running the statement
255+
256+
Note:
257+
Overload this method to provide a fallback implementation for statements.
258+
"""
241259
# NOTE: not using f-string here because 3.10 and 3.11 have
242260
# parser bug that doesn't allow f-string in raise statement
243261
raise ValueError(

0 commit comments

Comments
 (0)