Skip to content

Commit b1f27c6

Browse files
committed
Predictions: allow predicting probabilities for classless data
Before, combo box for selection of class probabilities was only shown if the data on the input had a dicrete class. Now, the combo is shown as soon as there is one predictor which works with discrete class vars. If class is missing from the file, only model probabilities options are shown there.
1 parent 70452b2 commit b1f27c6

File tree

2 files changed

+55
-1
lines changed

2 files changed

+55
-1
lines changed

Orange/widgets/evaluate/owpredictions.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,10 +266,16 @@ def _set_target_combos(self):
266266
self.target_class = self.TARGET_AVERAGE
267267
else:
268268
self.shown_probs = self.NO_PROBS
269+
model = prob_combo.model()
270+
for v in (self.DATA_PROBS, self.BOTH_PROBS):
271+
item = model.item(v)
272+
item.setFlags(item.flags() & ~Qt.ItemIsEnabled)
269273

270274
def _update_control_visibility(self):
271275
for widget in self._prob_controls:
272-
widget.setVisible(self.is_discrete_class)
276+
widget.setVisible(any((slot.predictor.domain.class_var is not None and
277+
slot.predictor.domain.class_var.is_discrete)
278+
for slot in self.predictors))
273279

274280
for widget in self._target_controls:
275281
widget.setVisible(self.is_discrete_class and self.show_scores)

Orange/widgets/evaluate/tests/test_owpredictions.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ class TestOWPredictions(WidgetTest):
3737
def setUp(self):
3838
self.widget = self.create_widget(OWPredictions) # type: OWPredictions
3939
self.iris = Table("iris")
40+
self.iris_classless = self.iris.transform(Domain(self.iris.domain.attributes, []))
4041
self.housing = Table("housing")
4142

4243
def test_rowCount_from_model(self):
@@ -718,16 +719,32 @@ def test_update_delegates_continuous(self):
718719

719720
widget.data = Table.from_list(Domain([], ContinuousVariable("c")), [])
720721

722+
# only regression
723+
all_predictors = widget.predictors
724+
widget.predictors = [widget.predictors[-1]]
721725
widget._update_control_visibility()
722726
self.assertTrue(widget.controls.shown_probs.isHidden())
723727
self.assertTrue(widget.controls.target_class.isHidden())
724728

729+
# regression and classification
730+
widget.predictors = all_predictors
731+
widget._update_control_visibility()
732+
self.assertFalse(widget.controls.shown_probs.isHidden())
733+
self.assertTrue(widget.controls.target_class.isHidden())
734+
725735
widget._set_class_values()
726736
self.assertEqual(widget.class_values, list("abcde"))
727737

728738
widget._set_target_combos()
729739
self.assertEqual(widget.shown_probs, widget.NO_PROBS)
730740

741+
def is_enabled(prob_item):
742+
return widget.controls.shown_probs.model().item(prob_item).flags() & Qt.ItemIsEnabled
743+
self.assertTrue(is_enabled(widget.NO_PROBS))
744+
self.assertTrue(is_enabled(widget.MODEL_PROBS))
745+
self.assertFalse(is_enabled(widget.DATA_PROBS))
746+
self.assertFalse(is_enabled(widget.BOTH_PROBS))
747+
731748
widget._update_prediction_delegate()
732749
for delegate in widget._delegates:
733750
self.assertEqual(list(delegate.shown_probabilities), [])
@@ -857,6 +874,37 @@ def test_output_regression(self):
857874
out.metas,
858875
np.hstack([pred.results.predicted.T for pred in widget.predictors]))
859876

877+
def test_classless(self):
878+
widget = self.widget
879+
iris012 = self.iris
880+
purge = Remove(class_flags=Remove.RemoveUnusedValues)
881+
iris01 = purge(iris012[:100])
882+
iris12 = purge(iris012[50:])
883+
884+
bayes01 = NaiveBayesLearner()(iris01)
885+
bayes12 = NaiveBayesLearner()(iris12)
886+
bayes012 = NaiveBayesLearner()(iris012)
887+
888+
self.send_signal(widget.Inputs.data, self.iris_classless)
889+
self.send_signal(widget.Inputs.predictors, bayes01, 0)
890+
self.send_signal(widget.Inputs.predictors, bayes12, 1)
891+
self.send_signal(widget.Inputs.predictors, bayes012, 2)
892+
893+
for i, pred in enumerate(widget.predictors):
894+
p = pred.results.unmapped_probabilities
895+
p[0] = 10 + 100 * i + np.arange(p.shape[1])
896+
pred.results.unmapped_predicted[:] = i
897+
898+
widget.shown_probs = widget.NO_PROBS
899+
widget._commit_predictions()
900+
out = self.get_output(widget.Outputs.predictions)
901+
self.assertEqual(list(out.metas[0]), [0, 1, 2])
902+
903+
widget.shown_probs = widget.MODEL_PROBS
904+
widget._commit_predictions()
905+
out = self.get_output(widget.Outputs.predictions)
906+
self.assertEqual(list(out.metas[0]), [0, 10, 11, 1, 110, 111, 2, 210, 211, 212])
907+
860908
@patch("Orange.widgets.evaluate.owpredictions.usable_scorers",
861909
Mock(return_value=[_Scorer]))
862910
def test_change_target(self):

0 commit comments

Comments
 (0)