
Commit 5f3d39e

Beer doc with Analysis and Codegen example (#224)
1 parent fe09644 commit 5f3d39e

3 files changed

+444
-0
lines changed

Lines changed: 307 additions & 0 deletions
@@ -0,0 +1,307 @@
## Beer price/fee analysis

In this section we discuss how to analyze a kirin program. We again use our `beer` dialect example.

### Goal

Let's consider the following program:
```python
@beer
def main2(x: int):

    bud = NewBeer(brand="budlight")
    heineken = NewBeer(brand="heineken")

    bud_pints = Pour(bud, 12 + x)
    heineken_pints = Pour(heineken, 10 + x)

    Drink(bud_pints)
    Drink(heineken_pints)
    Puke()

    Drink(bud_pints)
    Puke()

    Drink(bud_pints)
    Puke()

    return x
```

We would like to implement a forward dataflow analysis that walks through the program and collects the price information of each statement.

### Define Lattice

One of the important concepts in static analysis is the *lattice* (see [Wiki:Lattice](https://en.wikipedia.org/wiki/Lattice_(order)) and [Lecture Note On Static Analysis](https://studwww.itu.dk/~brabrand/static.pdf) for further details).
A lattice is a partially ordered set in which every pair of elements has a least upper bound (join) and a greatest lower bound (meet). A simple example is the type lattice.

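To make the ordering, join, and meet concrete, here is a minimal sketch of a flat lattice of integer constants in plain Python. It does not use kirin; the names `Flat`, `TOP`, `BOTTOM`, and `const` are hypothetical and chosen only for illustration.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Flat:
    """Element of a flat lattice: bottom <= every constant <= top."""

    kind: str  # "bottom", "const", or "top"
    value: Optional[int] = None

    def is_subseteq(self, other: "Flat") -> bool:
        # Bottom is below everything, everything is below top,
        # and a constant is only below itself.
        return self.kind == "bottom" or other.kind == "top" or self == other

    def join(self, other: "Flat") -> "Flat":
        # Least upper bound: comparable elements give the larger one,
        # unrelated constants go to top.
        if self.is_subseteq(other):
            return other
        if other.is_subseteq(self):
            return self
        return TOP

    def meet(self, other: "Flat") -> "Flat":
        # Greatest lower bound: the mirror image of join.
        if self.is_subseteq(other):
            return self
        if other.is_subseteq(self):
            return other
        return BOTTOM


TOP = Flat("top")
BOTTOM = Flat("bottom")


def const(value: int) -> Flat:
    return Flat("const", value)
```

Joining two different constants loses information and yields `TOP`, while meeting them yields `BOTTOM`. The `Item` lattice defined in this section follows the same shape, with `AnyItem` and `NotItem` in the roles of `TOP` and `BOTTOM`.
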
Let's now define our `Item` lattice for the price analysis.

First, a bounded lattice has top and bottom elements. In the type lattice, the top element is `Any` and the bottom element is the empty type, often written `Bottom` or `Never`.

Here, we define `AnyItem` as top and `NotItem` as bottom. In kirin, we can simply inherit from `BoundedLattice` in `kirin.lattice`. Kirin also provides simple mixins with default implementations of API methods such as `is_subseteq`, `join`, and `meet`, so you don't have to re-implement them.

```python
from dataclasses import dataclass
from typing import final

from kirin.lattice import (
    SingletonMeta,
    BoundedLattice,
    IsSubsetEqMixin,
    SimpleJoinMixin,
    SimpleMeetMixin,
)

@dataclass
class Item(
    IsSubsetEqMixin["Item"],
    SimpleJoinMixin["Item"],
    SimpleMeetMixin["Item"],
    BoundedLattice["Item"],
):

    @classmethod
    def top(cls) -> "Item":
        return AnyItem()

    @classmethod
    def bottom(cls) -> "Item":
        return NotItem()


@final
@dataclass
class NotItem(Item, metaclass=SingletonMeta):  # (1)!
    """The bottom of the lattice.

    Since the element has no fields, every instance is identical,
    so we make it a singleton by using the SingletonMeta metaclass.

    """

    def is_subseteq(self, other: Item) -> bool:
        # Bottom is below every element.
        return True


@final
@dataclass
class AnyItem(Item, metaclass=SingletonMeta):
    """The top of the lattice.

    Since the element has no fields, every instance is identical,
    so we make it a singleton by using the SingletonMeta metaclass.

    """

    def is_subseteq(self, other: Item) -> bool:
        # Top is only below itself.
        return isinstance(other, AnyItem)
```

1. Notice that since `NotItem` and `AnyItem` do not have any fields, we can mark them as singletons by using the `SingletonMeta` metaclass, which avoids creating duplicate instances.

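For readers unfamiliar with the pattern, a singleton metaclass can be sketched in a few lines of plain Python. This is only an illustration of the idea and is not kirin's `SingletonMeta`; the names `SingletonSketch`, `Bottom`, and `Top` are hypothetical.

```python
class SingletonSketch(type):
    """Hypothetical metaclass: each class that uses it gets exactly one instance."""

    _instances: dict = {}

    def __call__(cls, *args, **kwargs):
        # Create the instance on the first call, then keep returning the same object.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class Bottom(metaclass=SingletonSketch):
    pass


class Top(metaclass=SingletonSketch):
    pass
```

With this sketch, `Bottom() is Bottom()` holds, so identity comparison is enough to recognize the element, while `Bottom()` and `Top()` remain distinct objects.
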
Next, there are a few more lattice elements we want to define:

```python
@final
@dataclass
class ItemPints(Item):  # (1)!
    count: Item
    brand: str

    def is_subseteq(self, other: Item) -> bool:
        return (
            isinstance(other, ItemPints)
            and self.count == other.count
            and self.brand == other.brand
        )

@final
@dataclass
class AtLeastXItem(Item):  # (2)!
    data: int

    def is_subseteq(self, other: Item) -> bool:
        return isinstance(other, AtLeastXItem) and self.data == other.data


@final
@dataclass
class ConstIntItem(Item):  # (3)!
    data: int

    def is_subseteq(self, other: Item) -> bool:
        return isinstance(other, ConstIntItem) and self.data == other.data


@final
@dataclass
class ItemBeer(Item):  # (4)!
    brand: str

    def is_subseteq(self, other: Item) -> bool:
        return isinstance(other, ItemBeer) and self.brand == other.brand
```

1. `ItemPints` records the beer brand of the `Pints` as well as the pint count.
2. `AtLeastXItem` records that a value is a number that is at least `x`. The `data` field holds this lower bound.
3. `ConstIntItem` holds a concrete number.
4. `ItemBeer` records the brand of a `Beer`.


### Custom Forward Data Flow Analysis

Now that we have defined our lattice, let's see how to write an analysis pass and get the analysis results.

In kirin, an analysis pass is implemented with `AbstractInterpreter` (inspired by Julia). Kirin provides a simple forward dataflow analysis, `Forward`, which we will use.

Here our analysis wants to do two things:

1. Collect all the analysis results as a dictionary from `SSAValue` to lattice element.
2. Count how many times one pukes.

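Before looking at the kirin version, the two goals above can be sketched without any framework. The toy pass below walks a straight-line list of statements in order, stores one abstract value per result name, and counts `puke` statements. The tuple-based program encoding and the `MiniForward` class are hypothetical and exist only for this illustration.

```python
from dataclasses import dataclass, field


@dataclass
class MiniForward:
    """Toy forward pass over a straight-line list of (target, op, args) tuples."""

    results: dict = field(default_factory=dict)
    puke_count: int = 0

    def run(self, program):
        # Reset state so the same instance can be reused across runs.
        self.results = {}
        self.puke_count = 0
        for target, op, args in program:
            if op == "const":
                self.results[target] = ("const", args[0])
            elif op == "add":
                # Operands were evaluated earlier in the walk; this is what
                # makes the pass a forward analysis.
                lhs, rhs = (self.results[a] for a in args)
                self.results[target] = ("const", lhs[1] + rhs[1])
            elif op == "puke":
                self.puke_count += 1
        return self.results


program = [
    ("a", "const", (12,)),
    ("b", "const", (3,)),
    ("c", "add", ("a", "b")),
    (None, "puke", ()),
    (None, "puke", ()),
]
analysis = MiniForward()
results = analysis.run(program)
```

The reset at the start of `run` plays the same role as the `initialize` method in the kirin analysis: without it, a second run would keep counting on top of the first.
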
```python
from dataclasses import dataclass, field

# Assumes `ir`, `interp`, `Forward`, and `ForwardFrame` are imported from kirin,
# and that `latt` is the module containing the lattice classes defined above.


@dataclass
class FeeAnalysis(Forward[latt.Item]):  # (1)!
    keys = ["beer.fee"]  # (2)!
    lattice = latt.Item
    puke_count: int = field(init=False)

    def initialize(self):  # (3)!
        """Initialize the analysis pass.

        This method is called before the analysis pass starts.

        Note:
            1. One is *required* to call super().initialize() here,
                which clears all previous analysis results and symbol tables.
            2. Any additional initialization that belongs to the analysis should also be done here.
                For example, in this case, we initialize puke_count to 0.

        """
        super().initialize()
        self.puke_count = 0
        return self

    def eval_stmt_fallback(  # (4)!
        self, frame: ForwardFrame[latt.Item, None], stmt: ir.Statement
    ) -> tuple[latt.Item, ...] | interp.SpecialValue[latt.Item]:
        return ()

    def run_method(self, method: ir.Method, args: tuple[latt.Item, ...]) -> latt.Item:  # (5)!
        return self.run_callable(method.code, (self.lattice.bottom(),) + args)
```

1. Inherit from `Forward` with our custom lattice `Item`.
2. The keys for the `MethodTable`. Remember that in kirin all the implementation methods of an interpreter are implemented and registered in a `MethodTable`.
3. `AbstractInterpreter` has an abstract method `initialize`, which is called every time `run()` is called. We can override it to reset the variables, so that the same class instance can be reused.
4. This method implements the *fallback* used when interpreting a statement that has no implementation in the method table. Here we just return an empty tuple, because every statement that has a return value will be implemented, so only statements without a return value fall back.
5. This method defines and customizes how to run an `ir.Method`.

Click the + logo to see more details.

Now we want to implement how each statement gets run. This is the same as what we described for the concrete interpreter.

Note that each dialect can have multiple registered `MethodTable`s, distinguished by `key`. The interpreter uses `key` to find the corresponding `MethodTable`s for each dialect in a dialect group.

Here, we use `key="beer.fee"`.

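The idea of key-based lookup can be sketched as a two-level dictionary: the key selects a table, and the statement type selects the implementation inside it. The code below is a plain-Python illustration under that assumption, not kirin's registration mechanism; `REGISTRY`, `impl`, `dispatch`, and the stand-in `Puke` class are hypothetical.

```python
from collections import defaultdict

# Hypothetical registry: key -> {statement type -> implementation}.
REGISTRY: dict = defaultdict(dict)


def impl(key: str, stmt_type: type):
    """Register the decorated function as the handler for stmt_type under key."""

    def decorator(fn):
        REGISTRY[key][stmt_type] = fn
        return fn

    return decorator


class Puke:
    """Stand-in statement type for this sketch."""


@impl("concrete", Puke)
def puke_concrete(stmt):
    return "make a mess"


@impl("beer.fee", Puke)
def puke_fee(stmt):
    return "count one puke"


def dispatch(key: str, stmt):
    # The interpreter's key selects the table; the statement type selects the method.
    return REGISTRY[key][type(stmt)](stmt)
```

The same statement type can therefore behave differently under each interpreter: `dispatch("concrete", Puke())` and `dispatch("beer.fee", Puke())` reach different handlers.
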
First we need an implementation for the `Constant` statement in the `py.constant` dialect. If its value is an `int`, we return a `ConstIntItem` lattice element. If it is a `Beer`, we return an `ItemBeer`.

```python
@py.constant.dialect.register(key="beer.fee")
class PyConstMethodTable(interp.MethodTable):

    @interp.impl(py.constant.Constant)
    def const(
        self,
        interp: FeeAnalysis,
        frame: interp.Frame[latt.Item],
        stmt: py.constant.Constant,
    ):
        if isinstance(stmt.value, int):
            return (latt.ConstIntItem(data=stmt.value),)
        elif isinstance(stmt.value, Beer):
            return (latt.ItemBeer(brand=stmt.value.brand),)

        else:
            raise exceptions.InterpreterError(
                f"illegal constant type {type(stmt.value)}"
            )
```

Next, since we allow addition in the program, we also need to tell the abstract interpreter how to interpret the `binop.Add` statement, which is in the `py.binop` dialect.

```python
@binop.dialect.register(key="beer.fee")
class PyBinOpMethodTable(interp.MethodTable):

    @interp.impl(binop.Add)
    def add(
        self,
        interp: FeeAnalysis,
        frame: interp.Frame[latt.Item],
        stmt: binop.Add,
    ):
        left = frame.get(stmt.lhs)
        right = frame.get(stmt.rhs)

        # If either operand is only known by its lower bound, so is the sum.
        if isinstance(left, latt.AtLeastXItem) or isinstance(right, latt.AtLeastXItem):
            out = latt.AtLeastXItem(data=left.data + right.data)
        else:
            # Both operands are expected to be ConstIntItem here.
            out = latt.ConstIntItem(data=left.data + right.data)

        return (out,)
```
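
The addition rule can be tried out in isolation with a small stand-alone sketch. It mirrors the logic above using the hypothetical classes `Const`, `AtLeast`, and `Unknown` instead of kirin lattice elements. The `Unknown` branch is an addition for this sketch only: it covers operands that carry no `data`, a case the method above does not handle.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Const:
    data: int  # exact value


@dataclass(frozen=True)
class AtLeast:
    data: int  # lower bound


@dataclass(frozen=True)
class Unknown:
    pass  # nothing is known about the value


def abstract_add(left, right):
    known = (Const, AtLeast)
    # Assumption for this sketch: anything without numeric data adds to Unknown.
    if not (isinstance(left, known) and isinstance(right, known)):
        return Unknown()
    total = left.data + right.data
    # Two exact values give an exact sum; otherwise only the lower bound survives.
    if isinstance(left, Const) and isinstance(right, Const):
        return Const(total)
    return AtLeast(total)
```

For the program in this section, `12 + x` with `x` known to be at least 10 corresponds to `abstract_add(Const(12), AtLeast(10))`, which yields `AtLeast(22)`.
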

Finally, we need implementations for the statements of our beer dialect.

```python
@dialect.register(key="beer.fee")
class BeerMethodTable(interp.MethodTable):

    @interp.impl(NewBeer)
    def new_beer(
        self,
        interp: FeeAnalysis,
        frame: interp.Frame[latt.Item],
        stmt: NewBeer,
    ):
        return (latt.ItemBeer(brand=stmt.brand),)

    @interp.impl(Pour)
    def pour(
        self,
        interp: FeeAnalysis,
        frame: interp.Frame[latt.Item],
        stmt: Pour,
    ):
        # The charge depends on the beer brand, so carry the brand
        # along with the pint count.
        beer: latt.ItemBeer = frame.get(stmt.beverage)
        pint_count: latt.AtLeastXItem | latt.ConstIntItem = frame.get(stmt.amount)

        out = latt.ItemPints(count=pint_count, brand=beer.brand)

        return (out,)

    @interp.impl(Puke)
    def puke(
        self,
        interp: FeeAnalysis,
        frame: interp.Frame[latt.Item],
        stmt: Puke,
    ):
        # Puke produces no value; we only record that it happened.
        interp.puke_count += 1
        return ()
```

### Put it together

```python
fee_analysis = FeeAnalysis(main2.dialects)
# The argument x is only known to be at least 10.
results, expect = fee_analysis.run_analysis(main2, args=(latt.AtLeastXItem(data=10),))
print(results)
print(fee_analysis.puke_count)
```
