Skip to content

Commit 44424ad

Browse files
committed
Curve Fit: Minor fixes
1 parent ea0d068 commit 44424ad

File tree

5 files changed

+58
-14
lines changed

5 files changed

+58
-14
lines changed

Orange/data/tests/test_util.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from Orange.data import Domain, ContinuousVariable
55
from Orange.data.util import get_unique_names, get_unique_names_duplicates, \
6-
get_unique_names_domain, one_hot
6+
get_unique_names_domain, one_hot, sanitized_name
77

88

99
class TestGetUniqueNames(unittest.TestCase):
@@ -260,5 +260,13 @@ def test_dim_too_low(self):
260260
one_hot(self.values, dim=2)
261261

262262

263+
class TestSanitizedName(unittest.TestCase):
264+
def test_sanitized_name(self):
265+
self.assertEqual(sanitized_name("Foo"), "Foo")
266+
self.assertEqual(sanitized_name("Foo Bar"), "Foo_Bar")
267+
self.assertEqual(sanitized_name("0Foo"), "_0Foo")
268+
self.assertEqual(sanitized_name("1 Foo Bar"), "_1_Foo_Bar")
269+
270+
263271
if __name__ == "__main__":
264272
unittest.main()

Orange/regression/curvefit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ def coefficients(self) -> Table:
4545

4646
def predict(self, X: np.ndarray) -> np.ndarray:
4747
predicted = self.__function(X, *self.__parameters)
48-
if isinstance(predicted, float):
48+
if not isinstance(predicted, np.ndarray):
4949
# handle constant function; i.e. len(self.domain.attributes) == 0
50-
return np.full(len(X), predicted)
50+
return np.full(len(X), predicted, dtype=float)
5151
return predicted.flatten()
5252

5353
def __getstate__(self) -> Dict:

Orange/regression/tests/test_curvefit.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,9 @@ def test_init_str(self):
104104

105105
self.assertRaises(TypeError, CurveFitLearner, "a + b")
106106

107+
kw = dict(available_feature_names=[])
108+
self.assertRaises(TypeError, CurveFitLearner, "a + b", **kw)
109+
107110
def test_init_ast(self):
108111
kw = dict(available_feature_names=[], functions=[])
109112
exp = ast.parse("a + b", mode="eval")
@@ -119,6 +122,9 @@ def test_init_callable(self):
119122

120123
self.assertRaises(TypeError, CurveFitLearner, lambda x, a: a)
121124

125+
kw = dict(parameters_names=[])
126+
self.assertRaises(TypeError, CurveFitLearner, lambda x, a: a, **kw)
127+
122128
def test_fit(self):
123129
learner = CurveFitLearner(func, [], ["CRIM"])
124130
model = learner(self.data)

Orange/widgets/model/owcurvefit.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
"fmod", "gcd", "hypot", "isfinite", "isinf", "isnan", "ldexp",
3030
"log", "log10", "log1p", "log2", "pi", "power", "radians",
3131
"remainder", "sin", "sinh", "sqrt", "tan", "tanh", "trunc",
32-
"round", "abs")}
32+
"round", "abs", "any", "all")}
3333

3434

3535
class Parameter:
@@ -393,10 +393,6 @@ def __insert_into_expression(self, what: str, offset=0):
393393
def set_data(self, data: Optional[Table]):
394394
super().set_data(data)
395395
self.__clear()
396-
# self.__init_models()
397-
# self.__enable_controls()
398-
# self.__set_pending()
399-
# self.unconditional_apply()
400396

401397
def __clear(self):
402398
self.expression = ""
@@ -452,10 +448,11 @@ def create_learner(self) -> Optional[CurveFitLearner]:
452448
self.Error.no_parameter.clear()
453449
self.Error.unknown_parameter.clear()
454450
self.Warning.unused_parameter.clear()
455-
if not self.__pp_data or not self.expression:
451+
expression = self.expression.strip()
452+
if not self.__pp_data or not expression:
456453
return None
457454

458-
if not self.__validate_expression(self.expression):
455+
if not self.__validate_expression(expression):
459456
self.Error.invalid_exp()
460457
return None
461458

@@ -467,7 +464,7 @@ def create_learner(self) -> Optional[CurveFitLearner]:
467464
param.upper if param.use_upper else np.inf)
468465

469466
learner = self.LEARNER(
470-
self.expression,
467+
expression,
471468
available_feature_names=[a.name for a in self.__feature_model[1:]],
472469
functions=FUNCTIONS,
473470
sanitizer=sanitized_name,

Orange/widgets/model/tests/test_owcurvefit.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,29 @@ def test_functions(self):
2222
func = getattr(np, f)
2323
if isinstance(func, float):
2424
pass
25+
elif f in ["any", "all"]:
26+
self.assertTrue(func(a))
2527
elif f in ["arctan2", "copysign", "fmod", "gcd", "hypot",
2628
"isclose", "ldexp", "power", "remainder"]:
2729
self.assertIsInstance(func(a, 2), np.ndarray)
30+
self.assertEqual(func(a, 2).shape, (5,))
2831
else:
2932
self.assertIsInstance(func(a), np.ndarray)
33+
self.assertEqual(func(a).shape, (5,))
34+
35+
36+
class TestParameter(unittest.TestCase):
37+
def test_to_tuple(self):
38+
args = ("foo", 2, True, 10, False, 50)
39+
par = Parameter(*args)
40+
self.assertEqual(par.to_tuple(), args)
41+
42+
def test_repr(self):
43+
args = ("foo", 2, True, 10, False, 50)
44+
par = Parameter(*args)
45+
str_par = "Parameter(name=foo, initial=2, use_lower=True, " \
46+
"lower=10, use_upper=False, upper=50)"
47+
self.assertEqual(str(par), str_par)
3048

3149

3250
class TestParametersWidget(WidgetTest):
@@ -66,6 +84,21 @@ def test_add_row(self):
6684
self.assertFalse(data.use_upper)
6785
self.assertEqual(data.upper, 100)
6886

87+
def test_remove(self):
88+
n = 5
89+
for _ in range(n):
90+
self._widget._add_row()
91+
self.assertEqual(len(self._widget._ParametersWidget__data), n)
92+
93+
k = 2
94+
for _ in range(k):
95+
button = self._widget._ParametersWidget__controls[0][0]
96+
button.click()
97+
98+
self.assertEqual(len(self._widget._ParametersWidget__data), n - k)
99+
100+
101+
69102
def test_add_row_with_data(self):
70103
param = Parameter("a", 3, True, 2, False, 4)
71104
self._widget._add_row(param)
@@ -282,18 +315,18 @@ def test_parameters_combo(self):
282315
def test_function_combo(self):
283316
combo = self.widget.controls._function
284317
model = combo.model()
285-
self.assertEqual(model.rowCount(), 44)
318+
self.assertEqual(model.rowCount(), 46)
286319
self.assertEqual(combo.currentText(), "Select Function")
287320

288321
self.send_signal(self.widget.Inputs.data, self.housing)
289-
self.assertEqual(model.rowCount(), 44)
322+
self.assertEqual(model.rowCount(), 46)
290323
self.assertEqual(combo.currentText(), "Select Function")
291324
simulate.combobox_activate_index(combo, 1)
292325
self.assertEqual(self.widget._OWCurveFit__expression_edit.text(),
293326
"abs()")
294327

295328
self.send_signal(self.widget.Inputs.data, None)
296-
self.assertEqual(model.rowCount(), 44)
329+
self.assertEqual(model.rowCount(), 46)
297330
self.assertEqual(combo.currentText(), "Select Function")
298331

299332
def test_expression(self):

0 commit comments

Comments
 (0)