Skip to content

Commit abad271

Browse files
authored
further simplify interpreter and use a more clarified name (#118)
This PR removes the fallback attached to each method table for associated dialects but instead the fallback is implemented inside the corresponding interpreter. This is because I don't see any use case that different dialect will use different fallback, and in fact no dialect is using this feature right now.
1 parent 5a3f9e2 commit abad271

File tree

22 files changed

+82
-104
lines changed

22 files changed

+82
-104
lines changed

example/beer/interp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
from stmts import Pour, Puke, Drink, NewBeer, RandomBranch
55
from dialect import dialect
66

7-
from kirin.interp import Successor, Interpreter, DialectInterpreter, impl
7+
from kirin.interp import Successor, Interpreter, MethodTable, impl
88

99

1010
@dialect.register
11-
class BeerInterpreter(DialectInterpreter):
11+
class BeerInterpreter(MethodTable):
1212

1313
@impl(NewBeer)
1414
def new_beer(self, interp: Interpreter, stmt: NewBeer, values: tuple):

example/beer/script.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1+
from stmts import Pour, Puke, Drink, NewBeer
12
from dialect import dialect
2-
from stmts import Drink, NewBeer, Pour, Puke
33

44
from interp import BeerInterpreter as BeerInterpreter
5+
from rewrite import RandomWalkBranch
56
from kirin.ir import dialect_group
67
from kirin.prelude import basic_no_opt
7-
from kirin.rewrite import Fixpoint, Walk
8-
from rewrite import RandomWalkBranch
8+
from kirin.rewrite import Walk, Fixpoint
99

1010

1111
# create our own beer dialect, it runs a random walk on the branches
@@ -38,6 +38,6 @@ def main(x):
3838
main.code.print()
3939
main(1) # execute the function
4040

41-
# for i in range(10):
42-
# print("iteration", i)
43-
# main(i) # now drink a random beer!
41+
for i in range(10):
42+
print("iteration", i)
43+
main(i) # now drink a random beer!

src/kirin/analysis/typeinfer.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from kirin import ir
1+
from kirin import ir, interp
22
from kirin.ir import types
33
from kirin.ir.method import Method
44
from kirin.ir.nodes.stmt import Statement
@@ -7,7 +7,7 @@
77

88

99
class TypeInference(Forward[types.TypeAttribute]):
10-
keys = ["typeinfer", "typeinfer.default"]
10+
keys = ["typeinfer", "empty"]
1111
lattice = types.TypeAttribute
1212

1313
def build_signature(self, stmt: Statement, args: tuple):
@@ -26,6 +26,14 @@ def build_signature(self, stmt: Statement, args: tuple):
2626
tuple(_args),
2727
)
2828

29+
def eval_stmt(
30+
self, stmt: Statement, args: tuple[types.TypeAttribute, ...]
31+
) -> interp.Result:
32+
method = self.lookup_registry(stmt, args)
33+
if method is not None:
34+
return method(self, stmt, args)
35+
return tuple(result.type for result in stmt.results)
36+
2937
def run_method_region(
3038
self, mt: Method, body: Region, args: tuple[types.TypeAttribute, ...]
3139
) -> types.TypeAttribute:

src/kirin/dialects/cf/constprop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
1-
from kirin.interp import Successor, DialectInterpreter, impl
1+
from kirin.interp import Successor, MethodTable, impl
22
from kirin.analysis import const
33
from kirin.dialects.cf.stmts import Assert, Branch, ConditionalBranch
44
from kirin.dialects.cf.dialect import dialect
55

66

77
@dialect.register(key="constprop")
8-
class DialectConstProp(DialectInterpreter):
8+
class DialectConstProp(MethodTable):
99

1010
@impl(Assert)
1111
def assert_stmt(self, interp: const.Propagate, stmt: Assert, values):

src/kirin/dialects/cf/interp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
from kirin.interp import Err, Successor, Interpreter, DialectInterpreter, impl
1+
from kirin.interp import Err, Successor, Interpreter, MethodTable, impl
22
from kirin.dialects.cf.stmts import Assert, Branch, ConditionalBranch
33
from kirin.dialects.cf.dialect import dialect
44

55

66
@dialect.register
7-
class CfInterpreter(DialectInterpreter):
7+
class CfInterpreter(MethodTable):
88

99
@impl(Assert)
1010
def assert_stmt(self, interp: Interpreter, stmt: Assert, values):

src/kirin/dialects/cf/typeinfer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
from kirin.ir import types
2-
from kirin.interp import Successor, DialectInterpreter, impl
2+
from kirin.interp import Successor, MethodTable, impl
33
from kirin.dialects.cf.stmts import Assert, Branch, ConditionalBranch
44
from kirin.analysis.typeinfer import TypeInference
55
from kirin.dialects.cf.dialect import dialect
66

77

88
@dialect.register(key="typeinfer")
9-
class TypeInfer(DialectInterpreter):
9+
class TypeInfer(MethodTable):
1010

1111
@impl(Assert)
1212
def assert_stmt(self, interp: TypeInference, stmt: Assert, values):

src/kirin/dialects/fcf/interp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from kirin import ir
2-
from kirin.interp import Interpreter, DialectInterpreter, impl
2+
from kirin.interp import Interpreter, MethodTable, impl
33
from kirin.dialects.fcf.stmts import Map, Scan, Foldl, Foldr
44
from kirin.dialects.fcf.dialect import dialect
55

66

77
@dialect.register
8-
class FCFInterpreter(DialectInterpreter):
8+
class FCFInterpreter(MethodTable):
99

1010
@impl(Foldl)
1111
def foldl(self, interp: Interpreter, stmt: Foldl, values: tuple):

src/kirin/dialects/fcf/typeinfer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from typing import Callable, Iterable
22

33
from kirin import ir
4-
from kirin.interp import DialectInterpreter, impl
4+
from kirin.interp import MethodTable, impl
55
from kirin.analysis.typeinfer import TypeInference
66
from kirin.dialects.fcf.stmts import Map, Scan, Foldl, Foldr
77
from kirin.dialects.fcf.dialect import dialect
88

99

1010
@dialect.register(key="typeinfer")
11-
class TypeInfer(DialectInterpreter):
11+
class TypeInfer(MethodTable):
1212

1313
@impl(Foldl)
1414
def foldl(

src/kirin/dialects/func/constprop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
from typing import Iterable
22

33
from kirin import ir
4-
from kirin.interp import Result, ReturnValue, DialectInterpreter, impl
4+
from kirin.interp import Result, MethodTable, ReturnValue, impl
55
from kirin.analysis import const
66
from kirin.dialects.func.stmts import Call, Invoke, Lambda, Return, GetField
77
from kirin.dialects.func.dialect import dialect
88

99

1010
@dialect.register(key="constprop")
11-
class DialectConstProp(DialectInterpreter):
11+
class DialectConstProp(MethodTable):
1212

1313
@impl(Return)
1414
def return_(

src/kirin/dialects/func/interp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from kirin.ir import Method
2-
from kirin.interp import ReturnValue, DialectInterpreter, impl, concrete
2+
from kirin.interp import MethodTable, ReturnValue, impl, concrete
33
from kirin.dialects.func.stmts import (
44
Call,
55
Invoke,
@@ -12,7 +12,7 @@
1212

1313

1414
@dialect.register
15-
class Interpreter(DialectInterpreter):
15+
class Interpreter(MethodTable):
1616

1717
@impl(Call)
1818
def call(self, interp: concrete.Interpreter, stmt: Call, values: tuple):

0 commit comments

Comments
 (0)