Skip to content

Commit b26e8bb

Browse files
committed
OWNomogram: Normalize probabilities
1 parent 6c958dc commit b26e8bb

File tree

1 file changed

+62
-3
lines changed

1 file changed

+62
-3
lines changed

Orange/widgets/visualize/ownomogram.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ def move_to_sum(self):
103103
return
104104
total = sum(item.value for item in self.movable_dot_items)
105105
self.move_to_val(total)
106+
if self.get_probabilities:
107+
self.parentItem().rescale()
106108

107109
def get_tooltip_text(self):
108110
value = self.value
@@ -317,11 +319,13 @@ def collides(_item):
317319
return True
318320
return False
319321

322+
self.text_items_with_values = []
320323
shown_items = []
321324
w = QGraphicsSimpleTextItem(labels[0]).boundingRect().width()
322325
text_finish = values[0] * scale - w + offset - 10
323326
for i, (label, value) in enumerate(zip(labels, values)):
324327
text = QGraphicsSimpleTextItem(label)
328+
self.text_items_with_values.append((values[i], text))
325329
x_text = value * scale - text.boundingRect().width() / 2 + offset
326330
if text_finish > x_text - 10:
327331
y_text, y_tick = self.dot_r * 0.7, -self.tick_height
@@ -347,6 +351,11 @@ def collides(_item):
347351
half_tick.setParentItem(self)
348352
old_x_tick = x_tick
349353

354+
def rescale(self):
355+
func = self.dot.get_probabilities
356+
for value, item in self.text_items_with_values:
357+
item.setText(str(np.round(func(value), 2)))
358+
350359

351360
class DiscreteFeatureItem(RulerItem):
352361
tick_height = 6
@@ -532,6 +541,7 @@ class OWNomogram(OWWidget):
532541
ACCEPTABLE = (NaiveBayesModel, LogisticRegressionClassifier)
533542
settingsHandler = DomainContextHandler()
534543
target_class_index = ContextSetting(0)
544+
normalize_probabilities = Setting(False)
535545
align = Setting(1)
536546
scale = Setting(1)
537547
display_index = Setting(0)
@@ -567,11 +577,18 @@ def __init__(self):
567577
self.vertical_line = None
568578
self.hidden_vertical_line = None
569579
self.old_target_class_index = self.target_class_index
580+
self.markers_set = False
570581

571582
# GUI
583+
box = gui.vBox(self.controlArea, "Target class")
572584
self.class_combo = gui.comboBox(
573-
self.controlArea, self, "target_class_index", "Target class",
574-
callback=self._class_combo_changed, contentsLength=12)
585+
box, self, "target_class_index", callback=self._class_combo_changed,
586+
contentsLength=12)
587+
self.norm_check = gui.checkBox(
588+
box, self, "normalize_probabilities", "Normalize probabilities",
589+
callback=self._norm_check_changed,
590+
tooltip="For multiclass data 1 vs. all probabilities do not"
591+
" sum to 1 and therefore could be normalized.")
575592

576593
self.align_radio = gui.radioButtons(
577594
self.controlArea, self, "align",
@@ -632,6 +649,11 @@ def _class_combo_changed(self):
632649
self.update_scene()
633650
self.old_target_class_index = self.target_class_index
634651

652+
def _norm_check_changed(self):
653+
values = [item.dot.value for item in self.feature_items]
654+
self.feature_marker_values = self.scale_back(values)
655+
self.update_scene()
656+
635657
def _radio_button_changed(self):
636658
values = [item.dot.value for item in self.feature_items]
637659
self.feature_marker_values = self.scale_back(values)
@@ -675,11 +697,14 @@ def eventFilter(self, obj, event):
675697

676698
def update_controls(self):
677699
self.class_combo.clear()
700+
self.norm_check.setHidden(True)
678701
self.cont_feature_dim_combo.setEnabled(True)
679702
if self.domain:
680703
self.class_combo.addItems(self.domain.class_vars[0].values)
681704
if len(self.domain.attributes) > self.MAX_N_ATTRS:
682705
self.display_index = 1
706+
if len(self.domain.class_vars[0].values) > 2:
707+
self.norm_check.setHidden(False)
683708
if not self.domain.has_continuous_attributes():
684709
self.cont_feature_dim_combo.setEnabled(False)
685710
self.cont_feature_dim_index = 0
@@ -906,13 +931,47 @@ def create_footer_nomogram(self, total_text, probs_text, d, minimums,
906931
nomogram_footer = NomogramItem()
907932
total_item = RulerItem(total_text, values, scale_x, name_offset,
908933
- scale_x * min_sum, title="Total")
934+
935+
def get_normalized_probabilities(val):
936+
if not self.normalize_probabilities:
937+
return 1 / (1 + np.exp(k[cls_index] - val / d_))
938+
totals = self.__get_totals_for_class_values(minimums)
939+
p_sum = np.sum(1 / (1 + np.exp(k - totals / d_)))
940+
return 1 / (1 + np.exp(k[cls_index] - val / d_)) / p_sum
941+
942+
self.markers_set = False
909943
probs_item = RulerItem(
910944
probs_text, values, scale_x, name_offset, - scale_x * min_sum,
911945
title="P({}='{}')".format(cls_var.name, cls_var.values[cls_index]),
912-
get_probabilities=lambda f: 1 / (1 + np.exp(k[cls_index] - f / d_)))
946+
get_probabilities=get_normalized_probabilities)
947+
self.markers_set = True
913948
nomogram_footer.add_items([total_item, probs_item])
914949
return total_item, probs_item, nomogram_footer
915950

951+
def __get_totals_for_class_values(self, minimums):
952+
cls_index = self.target_class_index
953+
marker_values = [item.dot.value for item in self.feature_items]
954+
if not self.markers_set:
955+
marker_values = self.scale_forth(marker_values)
956+
totals = np.empty(len(self.domain.class_var.values))
957+
totals[cls_index] = sum(marker_values)
958+
marker_values = self.scale_back(marker_values)
959+
for i in range(len(self.domain.class_var.values)):
960+
if i == cls_index:
961+
continue
962+
coeffs = [np.nan_to_num(p[i] / p[cls_index]) for p in self.points]
963+
points = [p[cls_index] for p in self.points]
964+
total = sum([self.get_points_from_coeffs(v, c, p) for (v, c, p)
965+
in zip(marker_values, coeffs, points)])
966+
if self.align == OWNomogram.ALIGN_LEFT:
967+
points = [p - m for m, p in zip(minimums, points)]
968+
total -= sum([min(p) for p in [p[i] for p in self.points]])
969+
d = 100 / max(max(abs(p)) for p in points)
970+
if self.scale == OWNomogram.POINT_SCALE:
971+
total *= d
972+
totals[i] = total
973+
return totals
974+
916975
def set_feature_marker_values(self):
917976
if not (len(self.points) and len(self.feature_items)):
918977
return

0 commit comments

Comments
 (0)