1- from typing import Iterable , Callable , List , Optional
1+ import ast
2+ from typing import Callable , List , Optional , Union , Dict , Tuple , Any
23
34import numpy as np
45from scipy .optimize import curve_fit
56
67from Orange .data import Table , Domain , ContinuousVariable , StringVariable
78from Orange .data .filter import HasClass
9+ from Orange .data .util import sanitized_name , get_unique_names
810from Orange .preprocess import RemoveNaNColumns , Impute
911from Orange .regression import Learner , Model
1012
1113__all__ = ["CurveFitLearner" ]
1214
1315
1416class CurveFitModel (Model ):
15- def __init__ (self , domain : Domain , original_domain : Domain ,
16- parameters_names : List [str ],
17- parameters : np .ndarray , function : Callable ):
17+ def __init__ (
18+ self ,
19+ domain : Domain ,
20+ original_domain : Domain ,
21+ parameters_names : List [str ],
22+ parameters : np .ndarray ,
23+ function : Optional [Callable ],
24+ create_lambda_args : Optional [Tuple ]
25+ ):
1826 super ().__init__ (domain , original_domain )
1927 self .__parameters_names = parameters_names
2028 self .__parameters = parameters
29+
30+ if function is None and create_lambda_args is not None :
31+ function , names , _ = _create_lambda (** create_lambda_args )
32+ assert parameters_names == names
33+
34+ assert function
35+
2136 self .__function = function
37+ self .__create_lambda_args = create_lambda_args
2238
2339 @property
2440 def coefficients (self ) -> Table :
@@ -34,41 +50,219 @@ def predict(self, X: np.ndarray) -> np.ndarray:
3450 return np .full (len (X ), predicted )
3551 return predicted .flatten ()
3652
53+ def __getstate__ (self ) -> Dict :
54+ if not self .__create_lambda_args :
55+ raise AttributeError (
56+ "Can't pickle/copy callable. Use str expression instead."
57+ )
58+ return {
59+ "domain" : self .domain ,
60+ "original_domain" : self .original_domain ,
61+ "parameters_names" : self .__parameters_names ,
62+ "parameters" : self .__parameters ,
63+ "function" : None ,
64+ "args" : self .__create_lambda_args ,
65+ }
66+
67+ def __setstate__ (self , state : Dict ):
68+ self .__init__ (* state .values ())
69+
3770
class CurveFitLearner(Learner):
    """
    Fit a user-supplied model expression to data via
    scipy.optimize.curve_fit.

    The expression is given either as a ready-made callable (together with
    the names of its parameters and of the features it consumes) or as a
    str/ast.Expression, which is compiled into a lambda by _create_lambda.
    """
    preprocessors = [HasClass(), RemoveNaNColumns(), Impute()]
    __returns__ = CurveFitModel
    name = "Curve Fit"

    def __init__(
            self,
            expression: Union[Callable, ast.Expression, str],
            parameters_names: Optional[List[str]] = None,
            features_names: Optional[List[str]] = None,
            available_feature_names: Optional[List[str]] = None,
            functions: Optional[List[str]] = None,
            globals_: Optional[Dict[str, Any]] = None,
            p0: Union[List, Dict, None] = None,
            bounds: Union[Tuple, Dict] = (-np.inf, np.inf),
            preprocessors=None
    ):
        super().__init__(preprocessors)

        if callable(expression):
            # A raw callable is used as-is; the caller must describe it.
            if parameters_names is None:
                raise TypeError("Provide 'parameters_names' parameter.")
            if features_names is None:
                raise TypeError("Provide 'features_names' parameter.")
            function, args = expression, None
        else:
            # Compile the expression into a lambda and remember the
            # compilation arguments so the learner stays picklable.
            if available_feature_names is None:
                raise TypeError("Provide 'available_feature_names' parameter.")
            if functions is None:
                raise TypeError("Provide 'functions' parameter.")
            args = dict(expression=expression,
                        available_feature_names=available_feature_names,
                        functions=functions, globals_=globals_)
            function, parameters_names, features_names = _create_lambda(**args)

        if isinstance(p0, dict):
            # Parameters without an explicit initial guess default to 1.
            p0 = [p0.get(name, 1) for name in parameters_names]
        if isinstance(bounds, dict):
            # Parameters without explicit bounds are unbounded.
            unbounded = [-np.inf, np.inf]
            pairs = [bounds.get(name, unbounded)
                     for name in parameters_names]
            bounds = [pair[0] for pair in pairs], [pair[1] for pair in pairs]

        self.__function = function
        self.__parameters_names = parameters_names
        self.__features_names = features_names
        self.__p0 = p0
        self.__bounds = bounds

        # Kept for pickling; None when a raw callable was supplied
        # (in which case the learner cannot be pickled).
        self.__create_lambda_args = args

    @property
    def parameters_names(self) -> List[str]:
        """Names of the parameters being fitted, in fit order."""
        return self.__parameters_names

    def fit_storage(self, data: Table) -> CurveFitModel:
        """Fit the expression's parameters on `data`; return the model."""
        domain: Domain = data.domain
        used_attrs = []
        for attr in domain.attributes:
            if attr.name in self.__features_names:
                if not attr.is_continuous:
                    raise ValueError("Numeric feature expected.")
                used_attrs.append(attr)

        # Restrict the data to the features the function actually uses.
        fit_domain = Domain(used_attrs, domain.class_vars, domain.metas)
        fit_data = data.transform(fit_domain)
        params = curve_fit(self.__function, fit_data.X, fit_data.Y,
                           p0=self.__p0, bounds=self.__bounds)[0]
        return CurveFitModel(fit_domain, domain, self.__parameters_names,
                             params, self.__function,
                             self.__create_lambda_args)

    def __getstate__(self) -> Dict:
        """Return picklable state (the _create_lambda args plus settings)."""
        if not self.__create_lambda_args:
            raise AttributeError(
                "Can't pickle/copy callable. Use str expression instead."
            )
        state = dict(self.__create_lambda_args)
        state["parameters_names"] = None
        state["features_names"] = None
        state["p0"] = self.__p0
        state["bounds"] = self.__bounds
        state["preprocessors"] = self.preprocessors
        return state

    def __setstate__(self, state: Dict):
        """Rebuild the learner by re-running __init__ on the saved state."""
        init_kwargs = state
        expression = init_kwargs.pop("expression")
        self.__init__(expression, **init_kwargs)
164+
165+
def _create_lambda(
        expression: Union[str, ast.Expression] = "",
        available_feature_names: Optional[List[str]] = None,
        functions: Optional[List[str]] = None,
        globals_: Optional[Dict[str, Any]] = None
) -> Tuple[Callable, List[str], List[str]]:
    """
    Compile `expression` into ``lambda X, p1, ..., pn: ...``.

    Names appearing in the expression are classified as features (if they
    match a sanitized name from `available_feature_names`), functions (if
    listed in `functions`) or fit parameters (everything else). Feature
    names are replaced with column lookups ``X[:, i]``.

    Parameters
    ----------
    expression : Union[str, ast.Expression]
        The model expression to compile.
    available_feature_names : Optional[List[str]]
        Names of all features that may appear in the expression.
    functions : Optional[List[str]]
        Names to treat as functions; used to build default `globals_`
        (looked up on numpy) and excluded from the parameter list.
    globals_ : Optional[Dict[str, Any]]
        Globals for evaluating the lambda; defaults to the numpy
        attributes named in `functions`.

    Returns
    -------
    Tuple[Callable, List[str], List[str]]
        The compiled function, the parameter names (expression order) and
        the used feature names (available_feature_names order).
    """
    if globals_ is None:
        globals_ = {name: getattr(np, name) for name in functions}

    sanitized_vars_names = [sanitized_name(name) for name
                            in available_feature_names]

    exp = ast.parse(expression, mode="eval")
    search = _ParametersSearch(sanitized_vars_names, functions)
    search.visit(exp)
    params = search.parameters
    used_sanitized_vars_names = search.variables

    # Used features in available_feature_names order — this is the column
    # order fit_storage produces (assuming available_feature_names follows
    # the domain's attribute order — TODO confirm against callers).
    vars_ = [name for name in available_feature_names
             if sanitized_name(name) in used_sanitized_vars_names]

    name = sanitized_name(get_unique_names(params, "x"))
    # BUG FIX: index columns by the vars_ (domain) order, not by order of
    # first appearance in the expression; otherwise an expression that
    # mentions features out of domain order (e.g. "b + a") would read the
    # wrong columns of X at predict/fit time.
    vars_mapper = {sanitized_name(var): i for i, var in enumerate(vars_)}
    exp = _ReplaceVars(name, vars_mapper, functions).visit(exp)

    lambda_ = ast.Lambda(
        args=ast.arguments(
            posonlyargs=[],
            args=[ast.arg(arg=arg) for arg in [name] + params],
            varargs=None,
            kwonlyargs=[],
            kw_defaults=[],
            defaults=[],
        ),
        body=exp.body
    )
    exp = ast.Expression(body=lambda_)
    ast.fix_missing_locations(exp)

    # pylint: disable=eval-used
    return eval(compile(exp, "<lambda>", mode="eval"), globals_), params, vars_
206+
207+
208+ class _ParametersSearch (ast .NodeVisitor ):
209+ """
210+ Find features and parameters:
211+ - feature: if node is instance of ast.Name and is included in vars_names
212+ - parameters: if node is instance of ast.Name and is not included
213+ in functions
214+ """
215+
216+ def __init__ (self , vars_names : List [str ], functions : List [str ]):
217+ super ().__init__ ()
218+ self .__vars_names = vars_names
219+ self .__functions = functions
220+ self .__parameters : List [str ] = []
221+ self .__variables : List [str ] = []
222+
223+ @property
224+ def parameters (self ) -> List [str ]:
225+ return self .__parameters
226+
227+ @property
228+ def variables (self ) -> List [str ]:
229+ return self .__variables
230+
231+ def visit_Name (self , node : ast .Name ) -> ast .Name :
232+ if node .id in self .__vars_names :
233+ # don't use Set in order to preserve parameters order
234+ if node .id not in self .__variables :
235+ self .__variables .append (node .id )
236+ elif node .id not in self .__functions :
237+ # don't use Set in order to preserve parameters order
238+ if node .id not in self .__parameters :
239+ self .__parameters .append (node .id )
240+ return node
241+
242+
243+ class _ReplaceVars (ast .NodeTransformer ):
244+ """
245+ Replace feature names with X[:, i], where i is index of feature.
246+ """
247+
248+ def __init__ (self , name : str , vars_mapper : Dict , functions : List ):
249+ super ().__init__ ()
250+ self .__name = name
251+ self .__vars_mapper = vars_mapper
252+ self .__functions = functions
253+
254+ def visit_Name (self , node : ast .Name ) -> Union [ast .Name , ast .Subscript ]:
255+ if node .id not in self .__vars_mapper or node .id in self .__functions :
256+ return node
257+ else :
258+ n = self .__vars_mapper [node .id ]
259+ return ast .Subscript (
260+ value = ast .Name (id = self .__name , ctx = ast .Load ()),
261+ slice = ast .ExtSlice (
262+ dims = [ast .Slice (lower = None , upper = None , step = None ),
263+ ast .Index (value = ast .Num (n = n ))]),
264+ ctx = node .ctx
265+ )
72266
73267
74268if __name__ == "__main__" :
0 commit comments