Skip to content

Commit 70cf3d1

Browse files
committed
Curve Fit: Picklable model
1 parent 264ac56 commit 70cf3d1

File tree

5 files changed

+430
-228
lines changed

5 files changed

+430
-228
lines changed

Orange/data/util.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,3 +268,19 @@ def get_unique_names_domain(attributes, class_vars=(), metas=()):
268268
for old, new in zip(all_names, unique_names)
269269
if new != old))
270270
return (attributes, class_vars, metas), renamed
271+
272+
273+
def sanitized_name(name: str) -> str:
    """
    Replace non-alphanumeric characters with `_` and prepend `_` when the
    name starts with a digit, so the result is usable as an identifier.

    Args:
        name (str): proposed name

    Returns:
        str: sanitized name (empty input yields an empty string)
    """
    sanitized = re.sub(r"\W", "_", name)
    # guard the subscript: an empty name previously raised IndexError
    if sanitized and sanitized[0].isdigit():
        sanitized = "_" + sanitized
    return sanitized

Orange/regression/curvefit.py

Lines changed: 209 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,40 @@
1-
from typing import Iterable, Callable, List, Optional
1+
import ast
2+
from typing import Callable, List, Optional, Union, Dict, Tuple, Any
23

34
import numpy as np
45
from scipy.optimize import curve_fit
56

67
from Orange.data import Table, Domain, ContinuousVariable, StringVariable
78
from Orange.data.filter import HasClass
9+
from Orange.data.util import sanitized_name, get_unique_names
810
from Orange.preprocess import RemoveNaNColumns, Impute
911
from Orange.regression import Learner, Model
1012

1113
__all__ = ["CurveFitLearner"]
1214

1315

1416
class CurveFitModel(Model):
15-
def __init__(self, domain: Domain, original_domain: Domain,
16-
parameters_names: List[str],
17-
parameters: np.ndarray, function: Callable):
17+
def __init__(
18+
self,
19+
domain: Domain,
20+
original_domain: Domain,
21+
parameters_names: List[str],
22+
parameters: np.ndarray,
23+
function: Optional[Callable],
24+
create_lambda_args: Optional[Tuple]
25+
):
1826
super().__init__(domain, original_domain)
1927
self.__parameters_names = parameters_names
2028
self.__parameters = parameters
29+
30+
if function is None and create_lambda_args is not None:
31+
function, names, _ = _create_lambda(**create_lambda_args)
32+
assert parameters_names == names
33+
34+
assert function
35+
2136
self.__function = function
37+
self.__create_lambda_args = create_lambda_args
2238

2339
@property
2440
def coefficients(self) -> Table:
@@ -34,41 +50,219 @@ def predict(self, X: np.ndarray) -> np.ndarray:
3450
return np.full(len(X), predicted)
3551
return predicted.flatten()
3652

53+
def __getstate__(self) -> Dict:
54+
if not self.__create_lambda_args:
55+
raise AttributeError(
56+
"Can't pickle/copy callable. Use str expression instead."
57+
)
58+
return {
59+
"domain": self.domain,
60+
"original_domain": self.original_domain,
61+
"parameters_names": self.__parameters_names,
62+
"parameters": self.__parameters,
63+
"function": None,
64+
"args": self.__create_lambda_args,
65+
}
66+
67+
def __setstate__(self, state: Dict):
68+
self.__init__(*state.values())
69+
3770

3871
class CurveFitLearner(Learner):
    """
    Fit an expression with free parameters to data using
    `scipy.optimize.curve_fit`.

    The model function is given either as a ready-made callable — then
    `parameters_names` and `features_names` are required — or as a
    str/ast expression, then `available_feature_names` and `functions`
    are required and the callable is built by `_create_lambda`.
    Only learners created from an expression can be pickled or copied
    (see `__getstate__`).
    """
    preprocessors = [HasClass(), RemoveNaNColumns(), Impute()]
    __returns__ = CurveFitModel
    name = "Curve Fit"

    def __init__(
            self,
            expression: Union[Callable, ast.Expression, str],
            parameters_names: Optional[List[str]] = None,
            features_names: Optional[List[str]] = None,
            available_feature_names: Optional[List[str]] = None,
            functions: Optional[List[str]] = None,
            globals_: Optional[Dict[str, Any]] = None,
            p0: Union[List, Dict, None] = None,
            bounds: Union[Tuple, Dict] = (-np.inf, np.inf),
            preprocessors=None
    ):
        """
        Args:
            expression: the model function, as a callable or a str/ast
                expression
            parameters_names: parameter names (required for a callable)
            features_names: feature names (required for a callable)
            available_feature_names: names treated as features when
                compiling an expression (required for an expression)
            functions: names treated as functions when compiling an
                expression (required for an expression)
            globals_: globals for the compiled expression
            p0: initial guesses; a dict maps parameter name -> guess,
                missing names default to 1
            bounds: (lower, upper) for all parameters, or a dict mapping
                parameter name -> (lower, upper), missing names unbounded
            preprocessors: passed to `Learner`
        """
        super().__init__(preprocessors)

        if callable(expression):
            # a ready-made function: the caller must describe its signature
            if parameters_names is None:
                raise TypeError("Provide 'parameters_names' parameter.")
            if features_names is None:
                raise TypeError("Provide 'features_names' parameter.")
            function, args = expression, None
        else:
            # a str/ast expression: compile it into a lambda
            if available_feature_names is None:
                raise TypeError("Provide 'available_feature_names' parameter.")
            if functions is None:
                raise TypeError("Provide 'functions' parameter.")
            args = dict(expression=expression,
                        available_feature_names=available_feature_names,
                        functions=functions, globals_=globals_)
            function, parameters_names, features_names = _create_lambda(**args)

        # normalize the dict forms into the sequences curve_fit expects
        if isinstance(p0, dict):
            p0 = [p0.get(name, 1) for name in parameters_names]
        if isinstance(bounds, dict):
            unbounded = [-np.inf, np.inf]
            bounds = (
                [bounds.get(name, unbounded)[0] for name in parameters_names],
                [bounds.get(name, unbounded)[1] for name in parameters_names],
            )

        self.__function = function
        self.__parameters_names = parameters_names
        self.__features_names = features_names
        self.__p0 = p0
        self.__bounds = bounds
        # kept so the learner can be pickled; None when the expression was
        # a callable (lambdas are not picklable)
        self.__create_lambda_args = args

    @property
    def parameters_names(self) -> List[str]:
        """Names of the function's free parameters."""
        return self.__parameters_names

    def fit_storage(self, data: Table) -> CurveFitModel:
        """
        Fit the free parameters on `data` and return a CurveFitModel.

        Raises:
            ValueError: if a used feature is not continuous.
        """
        domain: Domain = data.domain
        used_attrs = []
        for attr in domain.attributes:
            if attr.name not in self.__features_names:
                continue
            if not attr.is_continuous:
                raise ValueError("Numeric feature expected.")
            used_attrs.append(attr)

        fit_domain = Domain(used_attrs, domain.class_vars, domain.metas)
        table = data.transform(fit_domain)
        optimal, _ = curve_fit(self.__function, table.X, table.Y,
                               p0=self.__p0, bounds=self.__bounds)
        return CurveFitModel(fit_domain, domain, self.__parameters_names,
                             optimal, self.__function,
                             self.__create_lambda_args)

    def __getstate__(self) -> Dict:
        """Picklable state; refuse when built from a raw callable."""
        if not self.__create_lambda_args:
            raise AttributeError(
                "Can't pickle/copy callable. Use str expression instead."
            )
        # parameters/features names are recomputed from the expression in
        # __init__, so they need not be stored
        return dict(self.__create_lambda_args,
                    parameters_names=None,
                    features_names=None,
                    p0=self.__p0,
                    bounds=self.__bounds,
                    preprocessors=self.preprocessors)

    def __setstate__(self, state: Dict):
        expr = state.pop("expression")
        self.__init__(expr, **state)
self.__init__(expression, **state)
164+
165+
166+
def _create_lambda(
        expression: Union[str, ast.Expression] = "",
        available_feature_names: Optional[List[str]] = None,
        functions: Optional[List[str]] = None,
        globals_: Optional[Dict[str, Any]] = None
) -> Tuple[Callable, List[str], List[str]]:
    """
    Compile an expression into ``lambda x, p1, p2, ...: <expression>``,
    where ``x`` stands for the feature matrix and ``p*`` are the free
    parameters found in the expression.

    Args:
        expression: the expression, as a string or a parsed ast.Expression
        available_feature_names: names that count as features
        functions: names that count as functions
        globals_: globals for the compiled lambda; when None, the numpy
            attributes named in `functions`

    Returns:
        tuple: (compiled lambda, parameter names, used feature names)
    """
    if globals_ is None:
        globals_ = {name: getattr(np, name) for name in functions}

    sanitized_vars_names = [sanitized_name(name) for name
                            in available_feature_names]

    # the declared type admits an already-parsed ast.Expression;
    # ast.parse would raise TypeError on one
    if isinstance(expression, ast.Expression):
        exp = expression
    else:
        exp = ast.parse(expression, mode="eval")
    search = _ParametersSearch(sanitized_vars_names, functions)
    search.visit(exp)
    params = search.parameters
    used_sanitized_vars_names = search.variables

    # pick an argument name that cannot clash with a parameter name
    name = sanitized_name(get_unique_names(params, "x"))
    vars_mapper = {var: i for i, var in enumerate(used_sanitized_vars_names)}
    exp = _ReplaceVars(name, vars_mapper, functions).visit(exp)

    lambda_ = ast.Lambda(
        args=ast.arguments(
            posonlyargs=[],
            args=[ast.arg(arg=arg) for arg in [name] + params],
            varargs=None,
            kwonlyargs=[],
            kw_defaults=[],
            defaults=[],
        ),
        body=exp.body
    )
    exp = ast.Expression(body=lambda_)
    ast.fix_missing_locations(exp)
    vars_ = [name for name in available_feature_names
             if sanitized_name(name) in used_sanitized_vars_names]

    # pylint: disable=eval-used
    return eval(compile(exp, "<lambda>", mode="eval"), globals_), params, vars_
206+
207+
208+
class _ParametersSearch(ast.NodeVisitor):
209+
"""
210+
Find features and parameters:
211+
- feature: if node is instance of ast.Name and is included in vars_names
212+
- parameters: if node is instance of ast.Name and is not included
213+
in functions
214+
"""
215+
216+
def __init__(self, vars_names: List[str], functions: List[str]):
217+
super().__init__()
218+
self.__vars_names = vars_names
219+
self.__functions = functions
220+
self.__parameters: List[str] = []
221+
self.__variables: List[str] = []
222+
223+
@property
224+
def parameters(self) -> List[str]:
225+
return self.__parameters
226+
227+
@property
228+
def variables(self) -> List[str]:
229+
return self.__variables
230+
231+
def visit_Name(self, node: ast.Name) -> ast.Name:
232+
if node.id in self.__vars_names:
233+
# don't use Set in order to preserve parameters order
234+
if node.id not in self.__variables:
235+
self.__variables.append(node.id)
236+
elif node.id not in self.__functions:
237+
# don't use Set in order to preserve parameters order
238+
if node.id not in self.__parameters:
239+
self.__parameters.append(node.id)
240+
return node
241+
242+
243+
class _ReplaceVars(ast.NodeTransformer):
244+
"""
245+
Replace feature names with X[:, i], where i is index of feature.
246+
"""
247+
248+
def __init__(self, name: str, vars_mapper: Dict, functions: List):
249+
super().__init__()
250+
self.__name = name
251+
self.__vars_mapper = vars_mapper
252+
self.__functions = functions
253+
254+
def visit_Name(self, node: ast.Name) -> Union[ast.Name, ast.Subscript]:
255+
if node.id not in self.__vars_mapper or node.id in self.__functions:
256+
return node
257+
else:
258+
n = self.__vars_mapper[node.id]
259+
return ast.Subscript(
260+
value=ast.Name(id=self.__name, ctx=ast.Load()),
261+
slice=ast.ExtSlice(
262+
dims=[ast.Slice(lower=None, upper=None, step=None),
263+
ast.Index(value=ast.Num(n=n))]),
264+
ctx=node.ctx
265+
)
72266

73267

74268
if __name__ == "__main__":

0 commit comments

Comments (0)