Skip to content

Commit 6065e3a

Browse files
committed
OWTreeLearner: report error instead of crashing when can't binarize
1 parent a2d4385 commit 6065e3a

File tree

2 files changed

+49
-0
lines changed

2 files changed

+49
-0
lines changed

Orange/widgets/classify/owclassificationtree.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from Orange.widgets import gui
1010
from Orange.widgets.settings import Setting
1111
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
12+
from Orange.widgets.widget import Msg
1213

1314

1415
class OWTreeLearner(OWBaseLearner):
@@ -29,6 +30,12 @@ class OWTreeLearner(OWBaseLearner):
2930
("Limit the maximal tree depth to: ",
3031
"limit_depth", "max_depth", 1, 1000))
3132

33+
class Error(OWBaseLearner.Error):
34+
cannot_binarize = Msg("Binarization cannot handle '{}'\n"
35+
"because it has {} values. "
36+
"Binarization can handle up to {}.\n"
37+
"Disable 'Induce binary tree' to proceed.")
38+
3239
def add_main_layout(self):
3340
box = gui.vBox(self.controlArea, True)
3441
# the checkbox is put into vBox for alignemnt with other checkboxes
@@ -49,6 +56,20 @@ def learner_kwargs(self):
4956
binarize=self.binary_trees,
5057
preprocessors=self.preprocessors)
5158

59+
def check_data(self):
60+
self.Error.cannot_binarize.clear()
61+
if not super().check_data():
62+
return False
63+
max_values, max_attr = max(
64+
((len(attr.values), attr)
65+
for attr in self.data.domain.attributes if attr.is_discrete),
66+
default=(0, None))
67+
if max_values > self.LEARNER.MAX_BINARIZATION:
68+
self.Error.cannot_binarize(
69+
max_attr.name, max_values, self.LEARNER.MAX_BINARIZATION)
70+
return False
71+
return True
72+
5273
def create_learner(self):
5374
# pylint: disable=not-callable
5475
return self.LEARNER(**self.learner_kwargs())

Orange/widgets/classify/tests/test_owclassificationtree.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
3+
from Orange.data import Table, Domain, DiscreteVariable
34
from Orange.base import Model
45
from Orange.widgets.classify.owclassificationtree import OWClassificationTree
56
from Orange.widgets.tests.base import (WidgetTest, DefaultParameterMapping,
67
ParameterMapping, WidgetLearnerTestMixin)
78

89

910
class TestOWClassificationTree(WidgetTest, WidgetLearnerTestMixin):
11+
@classmethod
12+
def setUpClass(cls):
13+
super().setUpClass()
14+
cls.iris = Table("iris")
15+
1016
def setUp(self):
1117
self.widget = self.create_widget(OWClassificationTree,
1218
stored_settings={"auto_apply": False})
@@ -34,3 +40,25 @@ def test_parameters_unchecked(self):
3440
for par, val in zip(self.parameters, (None, 2, 1))]
3541
self.test_parameters()
3642

43+
def test_cannot_binarize(self):
44+
widget = self.widget
45+
error_shown = self.widget.Error.cannot_binarize.is_shown
46+
self.assertFalse(error_shown())
47+
self.send_signal("Data", self.iris)
48+
self.assertFalse(error_shown())
49+
domain = Domain([
50+
DiscreteVariable(
51+
values=[str(x)
52+
for x in range(widget.LEARNER.MAX_BINARIZATION + 1)])],
53+
DiscreteVariable(values="01"))
54+
self.send_signal("Data", Table(domain, [[0, 0], [1, 1]]))
55+
self.assertTrue(error_shown())
56+
self.send_signal("Data", self.iris)
57+
self.assertFalse(error_shown())
58+
domain = Domain([
59+
DiscreteVariable(
60+
values=[str(x)
61+
for x in range(widget.LEARNER.MAX_BINARIZATION + 1)])],
62+
DiscreteVariable(values="01"))
63+
self.send_signal("Data", Table(domain))
64+
self.assertFalse(error_shown())

0 commit comments

Comments
 (0)