Skip to content

Commit 5c98904

Browse files
committed
Curve Fit: Minor fixes
1 parent 7d81784 commit 5c98904

File tree

5 files changed

+69
-14
lines changed

5 files changed

+69
-14
lines changed

Orange/data/tests/test_util.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from Orange.data import Domain, ContinuousVariable
55
from Orange.data.util import get_unique_names, get_unique_names_duplicates, \
6-
get_unique_names_domain, one_hot
6+
get_unique_names_domain, one_hot, sanitized_name
77

88

99
class TestGetUniqueNames(unittest.TestCase):
@@ -260,5 +260,13 @@ def test_dim_too_low(self):
260260
one_hot(self.values, dim=2)
261261

262262

263+
class TestSanitizedName(unittest.TestCase):
264+
def test_sanitized_name(self):
265+
self.assertEqual(sanitized_name("Foo"), "Foo")
266+
self.assertEqual(sanitized_name("Foo Bar"), "Foo_Bar")
267+
self.assertEqual(sanitized_name("0Foo"), "_0Foo")
268+
self.assertEqual(sanitized_name("1 Foo Bar"), "_1_Foo_Bar")
269+
270+
263271
if __name__ == "__main__":
264272
unittest.main()

Orange/regression/curvefit.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ def coefficients(self) -> Table:
4545

4646
def predict(self, X: np.ndarray) -> np.ndarray:
4747
predicted = self.__function(X, *self.__parameters)
48-
if isinstance(predicted, float):
48+
if not isinstance(predicted, np.ndarray):
4949
# handle constant function; i.e. len(self.domain.attributes) == 0
50-
return np.full(len(X), predicted)
50+
return np.full(len(X), predicted, dtype=float)
5151
return predicted.flatten()
5252

5353
def __getstate__(self) -> Dict:
@@ -261,6 +261,8 @@ def _create_lambda(
261261
262262
Examples
263263
--------
264+
>>> from Orange.data import Table
265+
>>> data = Table("housing")
264266
>>> sfun = "a * exp(-b * CRIM * LSTAT) + c"
265267
>>> names = [a.name for a in data.domain.attributes]
266268
>>> func, par, var = _create_lambda(sfun, available_feature_names=names,

Orange/regression/tests/test_curvefit.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def test_init_str(self):
104104

105105
self.assertRaises(TypeError, CurveFitLearner, "a + b")
106106

107+
kw = dict(available_feature_names=[])
108+
self.assertRaises(TypeError, CurveFitLearner, "a + b", **kw)
109+
107110
def test_init_ast(self):
108111
kw = dict(available_feature_names=[], functions=[])
109112
exp = ast.parse("a + b", mode="eval")
@@ -119,6 +122,9 @@ def test_init_callable(self):
119122

120123
self.assertRaises(TypeError, CurveFitLearner, lambda x, a: a)
121124

125+
kw = dict(parameters_names=[])
126+
self.assertRaises(TypeError, CurveFitLearner, lambda x, a: a, **kw)
127+
122128
def test_fit(self):
123129
learner = CurveFitLearner(func, [], ["CRIM"])
124130
model = learner(self.data)

Orange/widgets/model/owcurvefit.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"fmod", "gcd", "hypot", "isfinite", "isinf", "isnan", "ldexp",
3030
"log", "log10", "log1p", "log2", "pi", "power", "radians",
3131
"remainder", "sin", "sinh", "sqrt", "tan", "tanh", "trunc",
32-
"round", "abs")}
32+
"round", "abs", "any", "all")}
3333

3434

3535
class Parameter:
@@ -260,6 +260,7 @@ class Warning(OWBaseLearner.Warning):
260260
duplicate_parameter = Msg("Duplicated parameter name.")
261261
unused_parameter = Msg("Unused parameter '{}' in "
262262
"'Parameters' declaration.")
263+
data_missing = Msg("Provide data on the input.")
263264

264265
class Error(OWBaseLearner.Error):
265266
invalid_exp = Msg("Invalid expression.")
@@ -305,6 +306,8 @@ def __init__(self, *args, **kwargs):
305306

306307
super().__init__(*args, **kwargs)
307308

309+
self.Warning.data_missing()
310+
308311
def add_main_layout(self):
309312
box = gui.vBox(self.controlArea, "Parameters")
310313
self.__param_widget = ParametersWidget(self)
@@ -391,12 +394,9 @@ def __insert_into_expression(self, what: str, offset=0):
391394
self.__expression_edit.setFocus()
392395

393396
def set_data(self, data: Optional[Table]):
397+
self.Warning.data_missing(shown=not bool(data))
394398
super().set_data(data)
395399
self.__clear()
396-
# self.__init_models()
397-
# self.__enable_controls()
398-
# self.__set_pending()
399-
# self.unconditional_apply()
400400

401401
def __clear(self):
402402
self.expression = ""
@@ -452,10 +452,11 @@ def create_learner(self) -> Optional[CurveFitLearner]:
452452
self.Error.no_parameter.clear()
453453
self.Error.unknown_parameter.clear()
454454
self.Warning.unused_parameter.clear()
455-
if not self.__pp_data or not self.expression:
455+
expression = self.expression.strip()
456+
if not self.__pp_data or not expression:
456457
return None
457458

458-
if not self.__validate_expression(self.expression):
459+
if not self.__validate_expression(expression):
459460
self.Error.invalid_exp()
460461
return None
461462

@@ -467,7 +468,7 @@ def create_learner(self) -> Optional[CurveFitLearner]:
467468
param.upper if param.use_upper else np.inf)
468469

469470
learner = self.LEARNER(
470-
self.expression,
471+
expression,
471472
available_feature_names=[a.name for a in self.__feature_model[1:]],
472473
functions=FUNCTIONS,
473474
sanitizer=sanitized_name,

Orange/widgets/model/tests/test_owcurvefit.py

Lines changed: 41 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,29 @@ def test_functions(self):
2222
func = getattr(np, f)
2323
if isinstance(func, float):
2424
pass
25+
elif f in ["any", "all"]:
26+
self.assertTrue(func(a))
2527
elif f in ["arctan2", "copysign", "fmod", "gcd", "hypot",
2628
"isclose", "ldexp", "power", "remainder"]:
2729
self.assertIsInstance(func(a, 2), np.ndarray)
30+
self.assertEqual(func(a, 2).shape, (5,))
2831
else:
2932
self.assertIsInstance(func(a), np.ndarray)
33+
self.assertEqual(func(a).shape, (5,))
34+
35+
36+
class TestParameter(unittest.TestCase):
37+
def test_to_tuple(self):
38+
args = ("foo", 2, True, 10, False, 50)
39+
par = Parameter(*args)
40+
self.assertEqual(par.to_tuple(), args)
41+
42+
def test_repr(self):
43+
args = ("foo", 2, True, 10, False, 50)
44+
par = Parameter(*args)
45+
str_par = "Parameter(name=foo, initial=2, use_lower=True, " \
46+
"lower=10, use_upper=False, upper=50)"
47+
self.assertEqual(str(par), str_par)
3048

3149

3250
class TestParametersWidget(WidgetTest):
@@ -66,6 +84,19 @@ def test_add_row(self):
6684
self.assertFalse(data.use_upper)
6785
self.assertEqual(data.upper, 100)
6886

87+
def test_remove(self):
88+
n = 5
89+
for _ in range(n):
90+
self._widget._add_row()
91+
self.assertEqual(len(self._widget._ParametersWidget__data), n)
92+
93+
k = 2
94+
for _ in range(k):
95+
button = self._widget._ParametersWidget__controls[0][0]
96+
button.click()
97+
98+
self.assertEqual(len(self._widget._ParametersWidget__data), n - k)
99+
69100
def test_add_row_with_data(self):
70101
param = Parameter("a", 3, True, 2, False, 4)
71102
self._widget._add_row(param)
@@ -165,6 +196,13 @@ def test_input_data_learner_adequacy(self): # overwritten
165196
self.wait_until_stop_blocking()
166197
self.assertFalse(self.widget.Error.data_error.is_shown())
167198

199+
def test_input_data_missing(self):
200+
self.assertTrue(self.widget.Warning.data_missing.is_shown())
201+
self.send_signal(self.widget.Inputs.data, self.housing)
202+
self.assertFalse(self.widget.Warning.data_missing.is_shown())
203+
self.send_signal(self.widget.Inputs.data, None)
204+
self.assertTrue(self.widget.Warning.data_missing.is_shown())
205+
168206
def test_input_preprocessor(self):
169207
self.__init_widget()
170208
super().test_input_preprocessor()
@@ -282,18 +320,18 @@ def test_parameters_combo(self):
282320
def test_function_combo(self):
283321
combo = self.widget.controls._function
284322
model = combo.model()
285-
self.assertEqual(model.rowCount(), 44)
323+
self.assertEqual(model.rowCount(), 46)
286324
self.assertEqual(combo.currentText(), "Select Function")
287325

288326
self.send_signal(self.widget.Inputs.data, self.housing)
289-
self.assertEqual(model.rowCount(), 44)
327+
self.assertEqual(model.rowCount(), 46)
290328
self.assertEqual(combo.currentText(), "Select Function")
291329
simulate.combobox_activate_index(combo, 1)
292330
self.assertEqual(self.widget._OWCurveFit__expression_edit.text(),
293331
"abs()")
294332

295333
self.send_signal(self.widget.Inputs.data, None)
296-
self.assertEqual(model.rowCount(), 44)
334+
self.assertEqual(model.rowCount(), 46)
297335
self.assertEqual(combo.currentText(), "Select Function")
298336

299337
def test_expression(self):

0 commit comments

Comments
 (0)