Skip to content

Commit 0b0e58e

Browse files
committed
add predecessor and union types to lattice
1 parent d133226 commit 0b0e58e

File tree

3 files changed

+78
-5
lines changed

3 files changed

+78
-5
lines changed

src/kirin/analysis/const/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
Value as Value,
1414
Bottom as Bottom,
1515
Result as Result,
16+
Union as Union,
17+
Predecessor as Predecessor,
1618
Unknown as Unknown,
1719
PartialConst as PartialConst,
1820
PartialTuple as PartialTuple,

src/kirin/analysis/const/lattice.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,3 +264,69 @@ def is_structurally_equal(
264264
for x, y in zip(self.captured, other.captured)
265265
)
266266
)
267+
268+
@final
269+
@dataclass
270+
class Predecessor(Result):
271+
"""Predecessor block in CFG."""
272+
273+
block: ir.Block
274+
value: Result
275+
276+
def __hash__(self) -> int:
277+
return id(self)
278+
279+
def is_subseteq(self, other: Result) -> bool:
280+
if isinstance(other, Predecessor):
281+
return self.value.is_subseteq(other.value)
282+
else:
283+
return self.value.is_subseteq(other)
284+
285+
def join(self, other: Result) -> Result:
286+
if isinstance(other, Predecessor):
287+
if self.is_subseteq(other):
288+
return other.value
289+
elif other.is_subseteq(self):
290+
return self.value
291+
else:
292+
return Union(predecessors=frozenset({self, other}))
293+
elif isinstance(other, Union):
294+
return other.join(self)
295+
else:
296+
return self.value.join(other)
297+
298+
def meet(self, other: Result) -> Result:
299+
if isinstance(other, Predecessor):
300+
if self.is_subseteq(other):
301+
return self.value
302+
elif other.is_subseteq(self):
303+
return other.value
304+
else:
305+
return self.bottom()
306+
elif isinstance(other, Union):
307+
return other.meet(self)
308+
else:
309+
return self.value.meet(other)
310+
311+
@final
312+
@dataclass
313+
class Union(Result):
314+
315+
predecessors: frozenset[Predecessor]
316+
317+
def join(self, other: Result) -> Result:
318+
if isinstance(other, Union):
319+
union_preds = self.predecessors.union(other.predecessors)
320+
return Union(predecessors=union_preds)
321+
elif isinstance(other, Predecessor):
322+
union_preds = self.predecessors.union({other})
323+
return Union(predecessors=union_preds)
324+
325+
def meet(self, other: Result) -> Result:
326+
if isinstance(other, Union):
327+
common_preds = self.predecessors.intersection(other.predecessors)
328+
return Union(predecessors=common_preds)
329+
elif isinstance(other, Predecessor):
330+
common_preds = self.predecessors.intersection({other})
331+
return Union(predecessors=common_preds)
332+

src/kirin/dialects/cf/constprop.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@ class ConstPropMethodTable(MethodTable):
1010
@impl(Branch)
1111
def branch(self, interp: const.Propagate, frame: const.Frame, stmt: Branch):
1212
interp.state.current_frame.worklist.append(
13-
Successor(stmt.successor, *frame.get_values(stmt.arguments))
13+
Successor(stmt.successor, *(const.Predecessor(stmt.parent_block, arg)
14+
for arg in frame.get_values(stmt.arguments)))
1415
)
1516
return ()
1617

@@ -25,10 +26,12 @@ def conditional_branch(
2526
cond = frame.get(stmt.cond)
2627
if isinstance(cond, const.Value):
2728
else_successor = Successor(
28-
stmt.else_successor, *frame.get_values(stmt.else_arguments)
29+
stmt.else_successor, *(const.Predecessor(stmt.parent_block, arg)
30+
for arg in frame.get_values(stmt.else_arguments))
2931
)
3032
then_successor = Successor(
31-
stmt.then_successor, *frame.get_values(stmt.then_arguments)
33+
stmt.then_successor, *(const.Predecessor(stmt.parent_block, arg)
34+
for arg in frame.get_values(stmt.then_arguments))
3235
)
3336
if cond.data:
3437
frame.worklist.append(then_successor)
@@ -37,13 +40,15 @@ def conditional_branch(
3740
else:
3841
frame.entries[stmt.cond] = const.Value(True)
3942
then_successor = Successor(
40-
stmt.then_successor, *frame.get_values(stmt.then_arguments)
43+
stmt.then_successor, *(const.Predecessor(stmt.parent_block, arg)
44+
for arg in frame.get_values(stmt.then_arguments))
4145
)
4246
frame.worklist.append(then_successor)
4347

4448
frame.entries[stmt.cond] = const.Value(False)
4549
else_successor = Successor(
46-
stmt.else_successor, *frame.get_values(stmt.else_arguments)
50+
stmt.else_successor, *(const.Predecessor(stmt.parent_block, arg)
51+
for arg in frame.get_values(stmt.else_arguments))
4752
)
4853
frame.worklist.append(else_successor)
4954

0 commit comments

Comments
 (0)