Skip to content

Commit dba233e

Browse files
committed
OWTreeLearner: report error instead of crashing when can't binarize
1 parent 8bc04e1 commit dba233e

File tree

2 files changed

+79
-0
lines changed

2 files changed

+79
-0
lines changed

Orange/widgets/classify/owclassificationtree.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
from Orange.data import Table
44
from Orange.modelling.tree import TreeLearner
5+
from Orange.classification.tree import TreeLearner as ClassificationTreeLearner
56
from Orange.widgets.model.owtree import OWTreeLearner
67
from Orange.widgets.utils.owlearnerwidget import OWBaseLearner
8+
from Orange.widgets.widget import Msg
79

810

911
class OWTreeLearner(OWTreeLearner):
@@ -21,6 +23,29 @@ class OWTreeLearner(OWTreeLearner):
2123
"limit_majority", "sufficient_majority", 51, 100),) + \
2224
OWTreeLearner.spin_boxes[-1:]
2325

26+
class Error(OWTreeLearner.Error):
    # Shown when binary-tree induction is enabled and some discrete
    # attribute has more values than the learner is able to binarize.
    # Placeholders, in order: attribute name, its number of values,
    # and the maximal number of values that binarization supports.
    cannot_binarize = Msg(
        "Binarization cannot handle '{}'\n"
        "because it has {} values. "
        "Binarization can handle up to {}.\n"
        "Disable 'Induce binary tree' to proceed."
    )
31+
32+
def check_data(self):
    """Validate the input data and report attributes that can't be binarized.

    The `cannot_binarize` error is cleared first. The data is rejected if
    the inherited check fails. If binary trees are requested and a discrete
    attribute has more values than
    `ClassificationTreeLearner.MAX_BINARIZATION`, the `cannot_binarize`
    error is shown and the data is rejected.

    Returns:
        bool: True if the data can be used for learning, False otherwise.
    """
    self.Error.cannot_binarize.clear()
    if not super().check_data():
        return False
    if not self.binary_trees:
        return True

    # Compare attributes by their number of values only. Comparing
    # (count, attribute) tuples would fall back to comparing the
    # attribute objects themselves whenever two counts are equal.
    max_attr = max(
        (attr for attr in self.data.domain.attributes if attr.is_discrete),
        key=lambda attr: len(attr.values),
        default=None)
    if max_attr is None:
        # No discrete attributes, hence nothing to binarize
        return True

    max_values = len(max_attr.values)
    limit = ClassificationTreeLearner.MAX_BINARIZATION
    if max_values > limit:
        self.Error.cannot_binarize(max_attr.name, max_values, limit)
        return False
    return True
48+
2449
def learner_kwargs(self):
2550
opts = super().learner_kwargs()
2651
opts['sufficient_majority'] = \

Orange/widgets/classify/tests/test_owclassificationtree.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,19 @@
11
# Test methods with long descriptive names can omit docstrings
22
# pylint: disable=missing-docstring
3+
from Orange.data import Table, Domain, DiscreteVariable
4+
from Orange.classification.tree import TreeLearner as ClassificationTreeLearner
35
from Orange.base import Model
46
from Orange.widgets.classify.owclassificationtree import OWTreeLearner
57
from Orange.widgets.tests.base import (WidgetTest, DefaultParameterMapping,
68
ParameterMapping, WidgetLearnerTestMixin)
79

810

911
class TestOWClassificationTree(WidgetTest, WidgetLearnerTestMixin):
12+
@classmethod
def setUpClass(cls):
    # Load Iris once for the whole test case; tests only send it to
    # the widget as an input signal.
    super().setUpClass()
    iris_data = Table("iris")
    cls.iris = iris_data
16+
1017
def setUp(self):
1118
self.widget = self.create_widget(
1219
OWTreeLearner, stored_settings={"auto_apply": False})
@@ -34,3 +41,50 @@ def test_parameters_unchecked(self):
3441
for par, val in zip(self.parameters, (None, 2, 1))]
3542
self.test_parameters()
3643

44+
def test_cannot_binarize(self):
    """Check when the widget shows and hides the `cannot_binarize` error."""
    widget = self.widget
    error_shown = widget.Error.cannot_binarize.is_shown
    self.assertFalse(error_shown())
    self.send_signal("Data", self.iris)

    # The widget outputs ClassificationTreeLearner.
    # If not, below tests may not make sense
    learner = self.get_output("Learner")
    dlearner = learner.get_learner(learner.CLASSIFICATION)
    # assertTrue(x, cls) would treat the class as the failure message and
    # only check truthiness; the intent is a type check.
    self.assertIsInstance(dlearner, ClassificationTreeLearner)

    # No error on Iris
    max_binarization = dlearner.MAX_BINARIZATION
    self.assertFalse(error_shown())

    # Error when too many values
    domain = Domain([
        DiscreteVariable(
            values=[str(x) for x in range(max_binarization + 1)])],
        DiscreteVariable(values="01"))
    self.send_signal("Data", Table(domain, [[0, 0], [1, 1]]))
    self.assertTrue(error_shown())
    # No more error on Iris
    self.send_signal("Data", self.iris)
    self.assertFalse(error_shown())

    # Checking and unchecking binarization works
    widget.controls.binary_trees.click()
    self.assertFalse(widget.binary_trees)
    widget.unconditional_apply()
    self.send_signal("Data", Table(domain, [[0, 0], [1, 1]]))
    self.assertFalse(error_shown())
    widget.controls.binary_trees.click()
    widget.unconditional_apply()
    self.assertTrue(error_shown())
    widget.controls.binary_trees.click()
    widget.unconditional_apply()
    self.assertFalse(error_shown())

    # If something is wrong with the data, no error appears.
    # The same oversized domain is reused, this time without any rows.
    self.send_signal("Data", Table(domain))
    self.assertFalse(error_shown())

0 commit comments

Comments
 (0)