Skip to content

Commit 948d8a0

Browse files
authored
refactor analysis with SSAValue hints (#234)
## Analysis now returns frame instead of `dict` Previously we record the results into a global field `analysis.results` by merging the `frame.entries` into this dictionary. This works for simple cases, but in the case of constant prop, the purity flag is per frame (e.g when you go through one of the successor branch/region). As a result, when a function call has no return value, or a scf.For does not have yield (thus no results). We will not mark it pure. More generally, when there are extra results per function, we won't be able to return that results. So now we return the frame we use to run the analysis within, which contains everything one defines to collect for the analysis. ## Fixing DCE on SCF with the above change, we are able to fix DCE on SCF and DCE now only checks `Pure` or `MaybePure`. ## Replacing `Hint` with `SSAValue.hints` A new field `hints` is added to `SSAValue`. This is because I realize there could be different kind of hints from analysis to carry over within the IR. This also simplifies the type system making it less complicated. --- well this turns out to be quite breaking... But it simplifies things a lot. This also fixes a few bugs related to purity analysis - the purity should be per frame/statement instead of per SSAValue (well... I always felt this was strange, now it is correct). Constant prop is no longer convoluted with purity analysis (still slightly but a lot easier to decouple in the future with joint analysis). As a result loops on pure functions can fold correctly now.
1 parent e51a875 commit 948d8a0

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

108 files changed

+1054
-1131
lines changed

src/kirin/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# re-exports the public API of the kirin package
22
from kirin import ir
3-
from kirin.ir import types as types
43
from kirin.decl import info, statement
54

6-
__all__ = ["ir", "statement", "info"]
5+
from . import types as types
6+
7+
__all__ = ["ir", "types", "statement", "info"]

src/kirin/analysis/const/__init__.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,17 +8,13 @@
88
the IR.
99
"""
1010

11-
from .prop import Propagate as Propagate, ExtraFrameInfo as ExtraFrameInfo
11+
from .prop import Frame as Frame, Propagate as Propagate
1212
from .lattice import (
13-
Pure as Pure,
1413
Value as Value,
1514
Bottom as Bottom,
1615
Result as Result,
17-
NotPure as NotPure,
1816
Unknown as Unknown,
19-
JointResult as JointResult,
2017
PartialConst as PartialConst,
2118
PartialTuple as PartialTuple,
22-
PurityBottom as PurityBottom,
2319
PartialLambda as PartialLambda,
2420
)

src/kirin/analysis/const/lattice.py

Lines changed: 30 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,30 @@
22
"""
33

44
from typing import Any, final
5-
from dataclasses import field, dataclass
5+
from dataclasses import dataclass
66

77
from kirin import ir
88
from kirin.lattice import (
9-
LatticeMeta,
10-
SingletonMeta,
119
BoundedLattice,
1210
IsSubsetEqMixin,
1311
SimpleJoinMixin,
1412
SimpleMeetMixin,
1513
)
14+
from kirin.ir.attrs.abc import LatticeAttributeMeta, SingletonLatticeAttributeMeta
15+
from kirin.print.printer import Printer
1616

1717
from ._visitor import _ElemVisitor
1818

1919

2020
@dataclass
2121
class Result(
22+
ir.Attribute,
2223
IsSubsetEqMixin["Result"],
2324
SimpleJoinMixin["Result"],
2425
SimpleMeetMixin["Result"],
2526
BoundedLattice["Result"],
2627
_ElemVisitor,
28+
metaclass=LatticeAttributeMeta,
2729
):
2830
"""Base class for constant analysis results."""
2931

@@ -35,29 +37,38 @@ def top(cls) -> "Result":
3537
def bottom(cls) -> "Result":
3638
return Bottom()
3739

40+
def print_impl(self, printer: Printer) -> None:
41+
printer.plain_print(repr(self))
42+
3843

3944
@final
4045
@dataclass
41-
class Unknown(Result, metaclass=SingletonMeta):
46+
class Unknown(Result, metaclass=SingletonLatticeAttributeMeta):
4247
"""Unknown constant value. This is the top element of the lattice."""
4348

4449
def is_subseteq(self, other: Result) -> bool:
4550
return isinstance(other, Unknown)
4651

52+
def __hash__(self) -> int:
53+
return id(self)
54+
4755

4856
@final
4957
@dataclass
50-
class Bottom(Result, metaclass=SingletonMeta):
58+
class Bottom(Result, metaclass=SingletonLatticeAttributeMeta):
5159
"""Bottom element of the lattice."""
5260

5361
def is_subseteq(self, other: Result) -> bool:
5462
return True
5563

64+
def __hash__(self) -> int:
65+
return id(self)
66+
5667

5768
@final
5869
@dataclass
5970
class Value(Result):
60-
"""Constant value. Wraps any Python value."""
71+
"""Constant value. Wraps any hashable Python value."""
6172

6273
data: Any
6374

@@ -69,6 +80,12 @@ def is_equal(self, other: Result) -> bool:
6980
return self.data == other.data
7081
return False
7182

83+
def __hash__(self) -> int:
84+
# NOTE: we use id here because the data
85+
# may not be hashable. This is fine because
86+
# the data is guaranteed to be unique.
87+
return id(self)
88+
7289

7390
@dataclass
7491
class PartialConst(Result):
@@ -78,7 +95,7 @@ class PartialConst(Result):
7895

7996

8097
@final
81-
class PartialTupleMeta(LatticeMeta):
98+
class PartialTupleMeta(LatticeAttributeMeta):
8299
"""Metaclass for PartialTuple.
83100
84101
This metaclass canonicalizes PartialTuple instances with all Value elements
@@ -139,6 +156,9 @@ def is_subseteq_Value(self, other: Value) -> bool:
139156
return all(x.is_subseteq(Value(y)) for x, y in zip(self.data, other.data))
140157
return False
141158

159+
def __hash__(self) -> int:
160+
return hash(self.data)
161+
142162

143163
@final
144164
@dataclass
@@ -152,6 +172,9 @@ class PartialLambda(PartialConst):
152172
code: ir.Statement
153173
captured: tuple[Result, ...]
154174

175+
def __hash__(self) -> int:
176+
return hash((self.argnames, self.code, self.captured))
177+
155178
def is_subseteq_PartialLambda(self, other: "PartialLambda") -> bool:
156179
if self.code is not other.code:
157180
return False
@@ -194,82 +217,3 @@ def meet(self, other: Result) -> Result:
194217
self.code,
195218
tuple(x.meet(y) for x, y in zip(self.captured, other.captured)),
196219
)
197-
198-
199-
@dataclass(frozen=True)
200-
class Purity(
201-
SimpleJoinMixin["Purity"], SimpleMeetMixin["Purity"], BoundedLattice["Purity"]
202-
):
203-
"""Base class for purity lattice."""
204-
205-
@classmethod
206-
def bottom(cls) -> "Purity":
207-
return PurityBottom()
208-
209-
@classmethod
210-
def top(cls) -> "Purity":
211-
return NotPure()
212-
213-
214-
@dataclass(frozen=True)
215-
class Pure(Purity, metaclass=SingletonMeta):
216-
"""The result is from a pure function."""
217-
218-
def is_subseteq(self, other: Purity) -> bool:
219-
return isinstance(other, (NotPure, Pure))
220-
221-
222-
@dataclass(frozen=True)
223-
class NotPure(Purity, metaclass=SingletonMeta):
224-
"""The result is from an impure function."""
225-
226-
def is_subseteq(self, other: Purity) -> bool:
227-
return isinstance(other, NotPure)
228-
229-
230-
@dataclass(frozen=True)
231-
class PurityBottom(Purity, metaclass=SingletonMeta):
232-
"""The bottom element of the purity lattice."""
233-
234-
def is_subseteq(self, other: Purity) -> bool:
235-
return True
236-
237-
238-
@dataclass
239-
class JointResult(BoundedLattice["JointResult"]):
240-
"""Joint result of constant value and purity.
241-
242-
This lattice is used to join the constant value and purity of a function
243-
during constant propagation analysis. This allows the analysis to track
244-
both the constant value and the purity of the function, so that the analysis
245-
can propagate constant values through function calls even if the function
246-
is only partially pure.
247-
"""
248-
249-
const: Result
250-
"""The constant value of the result."""
251-
purity: Purity = field(default_factory=Purity.top)
252-
"""The purity of statement that produces the result."""
253-
254-
@classmethod
255-
def from_const(cls, value: Any) -> "JointResult":
256-
return cls(Value(value), Purity.top())
257-
258-
@classmethod
259-
def top(cls) -> "JointResult":
260-
return cls(Result.top(), Purity.top())
261-
262-
@classmethod
263-
def bottom(cls) -> "JointResult":
264-
return cls(Result.bottom(), Purity.bottom())
265-
266-
def is_subseteq(self, other: "JointResult") -> bool:
267-
return self.const.is_subseteq(other.const) and self.purity.is_subseteq(
268-
other.purity
269-
)
270-
271-
def join(self, other: "JointResult") -> "JointResult":
272-
return JointResult(self.const.join(other.const), self.purity.join(other.purity))
273-
274-
def meet(self, other: "JointResult") -> "JointResult":
275-
return JointResult(self.const.meet(other.const), self.purity.join(other.purity))

src/kirin/analysis/const/prop.py

Lines changed: 41 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,25 @@
11
from dataclasses import field, dataclass
22

3-
from kirin import ir, interp
3+
from kirin import ir, types, interp
44
from kirin.analysis.forward import ForwardExtra, ForwardFrame
55

6-
from .lattice import Pure, Value, NotPure, Unknown, JointResult
6+
from .lattice import Value, Result, Unknown
77

88

99
@dataclass
10-
class ExtraFrameInfo:
10+
class Frame(ForwardFrame[Result]):
11+
should_be_pure: set[ir.Statement] = field(default_factory=set)
12+
"""If any ir.MaybePure is actually pure."""
1113
frame_is_not_pure: bool = False
14+
"""If we hit any non-pure statement."""
1215

1316

1417
@dataclass
15-
class Propagate(ForwardExtra[JointResult, ExtraFrameInfo]):
18+
class Propagate(ForwardExtra[Frame, Result]):
1619
"""Forward dataflow analysis for constant propagation.
1720
1821
This analysis is a forward dataflow analysis that propagates constant values
19-
through the program. It uses the `JointResult` lattice to track the constant
22+
through the program. It uses the `Result` lattice to track the constant
2023
values and purity of the values.
2124
2225
The analysis is implemented as a forward dataflow analysis, where the
@@ -30,7 +33,7 @@ class Propagate(ForwardExtra[JointResult, ExtraFrameInfo]):
3033
"""
3134

3235
keys = ["constprop"]
33-
lattice = JointResult
36+
lattice = Result
3437

3538
_interp: interp.Interpreter = field(init=False)
3639

@@ -49,93 +52,65 @@ def initialize(self):
4952
self._interp.initialize()
5053
return self
5154

55+
def new_frame(self, code: ir.Statement) -> Frame:
56+
return Frame.from_func_like(code)
57+
5258
def _try_eval_const_pure(
5359
self,
54-
frame: ForwardFrame[JointResult, ExtraFrameInfo],
60+
frame: Frame,
5561
stmt: ir.Statement,
5662
values: tuple[Value, ...],
57-
) -> interp.StatementResult[JointResult]:
63+
) -> interp.StatementResult[Result]:
5864
try:
5965
_frame = self._interp.new_frame(frame.code)
6066
_frame.set_values(stmt.args, tuple(x.data for x in values))
6167
value = self._interp.eval_stmt(_frame, stmt)
6268
match value:
6369
case tuple():
64-
return tuple(JointResult(Value(each), Pure()) for each in value)
70+
return tuple(Value(each) for each in value)
6571
case interp.ReturnValue(ret):
66-
return interp.ReturnValue(JointResult(Value(ret), Pure()))
72+
return interp.ReturnValue(Value(ret))
6773
case interp.YieldValue(yields):
68-
return interp.YieldValue(
69-
tuple(JointResult(Value(each), Pure()) for each in yields)
70-
)
74+
return interp.YieldValue(tuple(Value(each) for each in yields))
7175
case interp.Successor(block, args):
7276
return interp.Successor(
7377
block,
74-
*tuple(JointResult(Value(each), Pure()) for each in args),
78+
*tuple(Value(each) for each in args),
7579
)
7680
except interp.InterpreterError:
7781
pass
7882
return (self.void,)
7983

8084
def eval_stmt(
81-
self, frame: ForwardFrame[JointResult, ExtraFrameInfo], stmt: ir.Statement
82-
) -> interp.StatementResult[JointResult]:
85+
self, frame: Frame, stmt: ir.Statement
86+
) -> interp.StatementResult[Result]:
8387
if stmt.has_trait(ir.ConstantLike):
8488
return self._try_eval_const_pure(frame, stmt, ())
8589
elif stmt.has_trait(ir.Pure):
86-
values = tuple(x.const for x in frame.get_values(stmt.args))
87-
if ir.types.is_tuple_of(values, Value):
90+
values = frame.get_values(stmt.args)
91+
if types.is_tuple_of(values, Value):
8892
return self._try_eval_const_pure(frame, stmt, values)
8993

9094
method = self.lookup_registry(frame, stmt)
91-
if method is not None:
92-
ret = method(self, frame, stmt)
93-
self._set_frame_not_pure(ret)
95+
if method is None:
96+
if stmt.has_trait(ir.Pure):
97+
return (Unknown(),) # no implementation but pure
98+
# not pure, and no implementation, let's say it's not pure
99+
frame.frame_is_not_pure = True
100+
return (Unknown(),)
101+
102+
ret = method(self, frame, stmt)
103+
if stmt.has_trait(ir.IsTerminator):
94104
return ret
95-
elif stmt.has_trait(ir.Pure):
96-
# fallback to top for other statements
97-
return (JointResult(Unknown(), Pure()),)
98-
else:
99-
if frame.extra is None:
100-
frame.extra = ExtraFrameInfo(True)
101-
return (JointResult(Unknown(), NotPure()),)
102-
103-
def _set_frame_not_pure(self, result: interp.StatementResult[JointResult]):
104-
frame = self.state.current_frame()
105-
if isinstance(result, tuple) and all(x.purity is Pure() for x in result):
106-
return
107-
108-
if isinstance(result, interp.ReturnValue) and isinstance(
109-
result.value.purity, Pure
110-
):
111-
return
112-
113-
if isinstance(result, interp.YieldValue) and all(
114-
isinstance(x.purity, Pure) for x in result
115-
):
116-
return
117-
118-
if isinstance(result, interp.Successor) and all(
119-
x.purity is Pure() for x in result.block_args
120-
):
121-
return
122-
123-
if frame.extra is None:
124-
frame.extra = ExtraFrameInfo(True)
105+
elif not stmt.has_trait(ir.MaybePure): # cannot be pure at all
106+
frame.frame_is_not_pure = True
107+
elif (
108+
stmt not in frame.should_be_pure
109+
): # implementation cannot decide if it's pure
110+
frame.frame_is_not_pure = True
111+
return ret
125112

126113
def run_method(
127-
self, method: ir.Method, args: tuple[JointResult, ...]
128-
) -> JointResult:
129-
return self.run_callable(
130-
method.code, (JointResult(Value(method), NotPure()),) + args
131-
)
132-
133-
def finalize(
134-
self,
135-
frame: ForwardFrame[JointResult, ExtraFrameInfo],
136-
results: JointResult,
137-
) -> JointResult:
138-
results = super().finalize(frame, results)
139-
if frame.extra is not None and frame.extra.frame_is_not_pure:
140-
return JointResult(results.const, NotPure())
141-
return results
114+
self, method: ir.Method, args: tuple[Result, ...]
115+
) -> tuple[Frame, Result]:
116+
return self.run_callable(method.code, (Value(method),) + args)

0 commit comments

Comments
 (0)