Skip to content

Commit 5fdb0d8

Browse files
authored
simple example of rewrite with insert of multiple stmts + codegen (#193)
This PR address #145 TODO: modify/add doc
1 parent d89b4f3 commit 5fdb0d8

File tree

6 files changed

+167
-35
lines changed

6 files changed

+167
-35
lines changed

example/beer/emit.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
from io import StringIO
2+
from dataclasses import field, dataclass
3+
4+
import stmts
5+
from group import beer
6+
from dialect import dialect
7+
8+
from kirin import ir, interp
9+
from lattice import Item, ItemPints, AtLeastXItem, ConstIntItem
10+
from kirin.emit import EmitStr, EmitStrFrame
11+
from kirin.dialects import func
12+
from kirin.emit.exceptions import EmitError
13+
14+
15+
def default_menu_price():
16+
return {
17+
"budlight": 3.0,
18+
"heineken": 4.0,
19+
"tsingdao": 2.0,
20+
}
21+
22+
23+
@dataclass
24+
class EmitReceptMain(EmitStr):
25+
keys = ["emit.recept"]
26+
dialects: ir.DialectGroup = field(default=beer)
27+
file: StringIO = field(default_factory=StringIO)
28+
menu_price: dict[str, float] = field(default_factory=default_menu_price)
29+
recept_analysis_result: dict[ir.SSAValue, Item] = field(default_factory=dict)
30+
31+
def initialize(self):
32+
super().initialize()
33+
self.file.truncate(0)
34+
self.file.seek(0)
35+
return self
36+
37+
def eval_stmt_fallback(
38+
self, frame: EmitStrFrame, stmt: ir.Statement
39+
) -> tuple[str, ...]:
40+
return (stmt.name,)
41+
42+
def emit_block(self, frame: EmitStrFrame, block: ir.Block) -> str | None:
43+
for stmt in block.stmts:
44+
result = self.eval_stmt(frame, stmt)
45+
if isinstance(result, tuple):
46+
frame.set_values(stmt.results, result)
47+
return None
48+
49+
def get_output(self) -> str:
50+
self.file.seek(0)
51+
return "\n".join(
52+
[
53+
"item \tamount \t price",
54+
"-----------------------------------",
55+
self.file.read(),
56+
]
57+
)
58+
59+
60+
@func.dialect.register(key="emit.recept")
61+
class FuncEmit(interp.MethodTable):
62+
63+
@interp.impl(func.Function)
64+
def emit_func(self, emit: EmitReceptMain, frame: EmitStrFrame, stmt: func.Function):
65+
_ = emit.run_ssacfg_region(frame, stmt.body)
66+
return ()
67+
68+
69+
@dialect.register(key="emit.recept")
70+
class BeerEmit(interp.MethodTable):
71+
72+
@interp.impl(stmts.Pour)
73+
def emit_pour(self, emit: EmitReceptMain, frame: EmitStrFrame, stmt: stmts.Pour):
74+
pints_item: ItemPints = emit.recept_analysis_result[stmt.result]
75+
76+
amount_str = ""
77+
price_str = ""
78+
if isinstance(pints_item.count, AtLeastXItem):
79+
amount_str = f">={pints_item.count.data}"
80+
price_str = (
81+
f" >=${emit.menu_price[pints_item.brand] * pints_item.count.data}"
82+
)
83+
elif isinstance(pints_item.count, ConstIntItem):
84+
amount_str = f" {pints_item.count.data}"
85+
price_str = (
86+
f" ${emit.menu_price[pints_item.brand] * pints_item.count.data}"
87+
)
88+
else:
89+
raise EmitError("invalid analysis result.")
90+
91+
emit.writeln(frame, f"{pints_item.brand}\t{amount_str}\t{price_str}")
92+
93+
return ()

example/beer/group.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from dialect import dialect
2+
3+
from rewrite import RandomWalkBranch
4+
from kirin.ir import dialect_group
5+
from kirin.prelude import basic_no_opt
6+
from kirin.rewrite import Walk, Fixpoint
7+
from kirin.passes.fold import Fold
8+
9+
10+
# create our own beer dialect, it runs a random walk on the branches
11+
@dialect_group(basic_no_opt.add(dialect))
12+
def beer(self):
13+
14+
fold_pass = Fold(self)
15+
16+
def run_pass(mt, *, fold=True):
17+
Fixpoint(Walk(RandomWalkBranch())).rewrite(mt.code)
18+
19+
# add const fold
20+
if fold:
21+
fold_pass(mt)
22+
23+
return run_pass

example/beer/recept.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,6 @@ def add(
8989
@dialect.register(key="beer.fee")
9090
class BeerMethodTable(interp.MethodTable):
9191

92-
menu_price: dict[str, float] = {
93-
"budlight": 1.0,
94-
"heineken": 2.0,
95-
"tsingdao": 3.0,
96-
}
97-
9892
@interp.impl(NewBeer)
9993
def new_beer(
10094
self,

example/beer/rewrite.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from dataclasses import dataclass
22

3-
from stmts import RandomBranch
3+
from stmts import Puke, Drink, NewBeer, RandomBranch
44

55
from kirin.dialects import cf
66
from kirin.rewrite.abc import RewriteRule, RewriteResult
@@ -23,3 +23,21 @@ def rewrite_Statement(self, node: Statement) -> RewriteResult:
2323
)
2424
)
2525
return RewriteResult(has_done_something=True)
26+
27+
28+
@dataclass
29+
class NewBeerAndPukeOnDrink(RewriteRule):
30+
# sometimes someone get drunk, so they keep getting new beer and puke after they drink
31+
def rewrite_Statement(self, node: Statement) -> RewriteResult:
32+
if not isinstance(node, Drink):
33+
return RewriteResult()
34+
35+
# 1. create new stmts:
36+
new_beer_stmt = NewBeer(brand="saporo")
37+
puke_stmt = Puke()
38+
39+
# 2. put them in the ir
40+
new_beer_stmt.insert_after(node)
41+
puke_stmt.insert_after(new_beer_stmt)
42+
43+
return RewriteResult(has_done_something=True)

example/beer/script.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,12 @@
1+
from group import beer
12
from stmts import Pour, Puke, Drink, NewBeer
23
from recept import FeeAnalysis
3-
from dialect import dialect
44

5+
from emit import EmitReceptMain
56
from interp import BeerMethods as BeerMethods
67
from lattice import AtLeastXItem
7-
from rewrite import RandomWalkBranch
8-
from kirin.ir import dialect_group
9-
from kirin.prelude import basic_no_opt
10-
from kirin.rewrite import Walk, Fixpoint
11-
from kirin.passes.fold import Fold
12-
13-
14-
# create our own beer dialect, it runs a random walk on the branches
15-
@dialect_group(basic_no_opt.add(dialect))
16-
def beer(self):
17-
18-
fold_pass = Fold(self)
19-
20-
def run_pass(mt, *, fold=True):
21-
Fixpoint(Walk(RandomWalkBranch())).rewrite(mt.code)
22-
23-
# add const fold
24-
if fold:
25-
fold_pass(mt)
26-
27-
return run_pass
28-
29-
30-
# we are going to get drunk!
31-
# add our beer dialect to the default dialect (builtin, cf, func, ...)
8+
from rewrite import NewBeerAndPukeOnDrink
9+
from kirin.rewrite import Walk
3210

3311

3412
# type: ignore
@@ -63,7 +41,26 @@ def some_closure(beer, amount):
6341
# main(i) # now drink a random beer!
6442

6543

66-
# simple analysis example:
44+
# 2. simple rewrite:
45+
@beer
46+
def main3():
47+
48+
bud = NewBeer(brand="budlight")
49+
heineken = NewBeer(brand="heineken")
50+
51+
bud_pints = Pour(bud, 2)
52+
heineken_pints = Pour(heineken, 10)
53+
54+
Drink(bud_pints)
55+
Drink(heineken_pints)
56+
57+
58+
main3.print()
59+
Walk(NewBeerAndPukeOnDrink()).rewrite(main3.code)
60+
main3.print()
61+
62+
63+
# 3. simple analysis example:
6764
@beer
6865
def main2(x: int):
6966

@@ -91,3 +88,10 @@ def main2(x: int):
9188
print(results)
9289
print(fee_analysis.puke_count)
9390
main2.print(analysis=results)
91+
92+
93+
emitter = EmitReceptMain()
94+
emitter.recept_analysis_result = results
95+
96+
emitter.run(main2, ("",))
97+
print(emitter.get_output())

src/kirin/interp/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __post_init__(self) -> None:
8585
self.registry = self.dialects.registry.interpreter(keys=self.keys)
8686

8787
def initialize(self) -> Self:
88-
"""Initialize the interpreter global states. This method is called before
88+
"""Initialize the interpreter global states. This method is called right upon
8989
calling [`run`][kirin.interp.base.BaseInterpreter.run] to initialize the
9090
interpreter global states.
9191

0 commit comments

Comments
 (0)