11from abc import ABC
2+ from typing import Iterable
3+ from dataclasses import field , dataclass
24
35from kirin import ir
6+ from kirin .interp import AbstractFrame
47from kirin .analysis import Forward , ForwardFrame
58
69from .lattice import ErrorType
710
811ValidationFrame = ForwardFrame [ErrorType ]
912
1013
14+ @dataclass
1115class ValidationAnalysis (Forward [ErrorType ], ABC ):
1216 """Analysis pass that indicates errors in the IR according to the respective method tables.
1317
@@ -17,9 +21,43 @@ class ValidationAnalysis(Forward[ErrorType], ABC):
1721
1822 lattice = ErrorType
1923
24+ additional_errors : list [ErrorType ] = field (default_factory = list )
25+ """List to store return values that are not associated with an SSA Value (e.g. when the statement has no ResultValue)"""
26+
2027 def run_method (self , method : ir .Method , args : tuple [ErrorType , ...]):
21- return self .run_callable (method .code , (self .lattice .bottom (),) + args )
28+ return self .run_callable (method .code , (self .lattice .top (),) + args )
2229
2330 def eval_stmt_fallback (self , frame : ValidationFrame , stmt : ir .Statement ):
2431 # NOTE: default to no errors
2532 return (self .lattice .top (),)
33+
34+ def set_values (
35+ self ,
36+ frame : AbstractFrame [ErrorType ],
37+ ssa : Iterable [ir .SSAValue ],
38+ results : Iterable [ErrorType ],
39+ ):
40+ """Set the abstract values for the given SSA values in the frame.
41+
42+ This method is overridden to account for additional errors we may
43+ encounter when they are not associated to an SSA Value.
44+ """
45+
46+ number_of_ssa_values = 0
47+ for ssa_value , result in zip (ssa , results ):
48+ number_of_ssa_values += 1
49+ if ssa_value in frame .entries :
50+ frame .entries [ssa_value ] = frame .entries [ssa_value ].join (result )
51+ else :
52+ frame .entries [ssa_value ] = result
53+
54+ if isinstance (results , tuple ):
55+ # NOTE: usually what we have
56+ self .additional_errors .extend (results [number_of_ssa_values :])
57+
58+ for i , result in enumerate (results ):
59+ # NOTE: only sure-fire way I found to get remaining values from an Iterable
60+ if i < number_of_ssa_values :
61+ continue
62+
63+ self .additional_errors .append (result )
0 commit comments