Skip to content

Commit aa9a09c

Browse files
authored
Merge pull request #1870 from pavlin-policar/treeviewer-adapter
[FIX] Treeviewer sklearn tree compatibility
2 parents 35e6188 + 19357a7 commit aa9a09c

File tree

5 files changed

+107
-48
lines changed

5 files changed

+107
-48
lines changed

Orange/widgets/visualize/owtreeviewer.py

Lines changed: 61 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from AnyQt.QtGui import QColor, QBrush, QPen, QFontMetrics
99
from AnyQt.QtCore import Qt, QPointF, QSizeF, QRectF
1010

11-
from Orange.tree import TreeModel
11+
from Orange.base import TreeModel, SklModel
1212
from Orange.widgets.visualize.owtreeviewer2d import \
1313
GraphicsNode, GraphicsEdge, OWTreeViewer2D
1414
from Orange.widgets.utils import to_html
@@ -20,6 +20,8 @@
2020
from Orange.widgets.utils.colorpalette import ContinuousPaletteGenerator
2121
from Orange.widgets.utils.annotated_data import (create_annotated_table,
2222
ANNOTATED_DATA_SIGNAL_NAME)
23+
from Orange.widgets.visualize.utils.tree.skltreeadapter import SklTreeAdapter
24+
from Orange.widgets.visualize.utils.tree.treeadapter import TreeAdapter
2325

2426

2527
class PieChart(QGraphicsRectItem):
@@ -66,20 +68,21 @@ class TreeNode(GraphicsNode):
6668
# Methods are documented in PyQt documentation
6769
# pylint: disable=missing-docstring
6870

69-
def __init__(self, model, node_inst, parent=None):
71+
def __init__(self, tree_adapter, node_inst, parent=None):
7072
super().__init__(parent)
71-
self.model = model
73+
self.tree_adapter = tree_adapter
74+
self.model = self.tree_adapter.model
7275
self.node_inst = node_inst
7376

7477
fm = QFontMetrics(self.document().defaultFont())
75-
attr = node_inst.attr
78+
attr = self.tree_adapter.attribute(node_inst)
7679
self.attr_text_w = fm.width(attr.name if attr else "")
7780
self.attr_text_h = fm.lineSpacing()
7881
self.line_descent = fm.descent()
7982
self._rect = None
8083

81-
if model.domain.class_var.is_discrete:
82-
self.pie = PieChart(node_inst.value, 8, self)
84+
if self.model.domain.class_var.is_discrete:
85+
self.pie = PieChart(self.tree_adapter.get_distribution(node_inst)[0], 8, self)
8386
else:
8487
self.pie = None
8588

@@ -90,7 +93,7 @@ def update_contents(self):
9093
self.droplet.setPos(self.rect().center().x(), self.rect().height())
9194
self.droplet.setVisible(bool(self.branches))
9295
fm = QFontMetrics(self.document().defaultFont())
93-
attr = self.node_inst.attr
96+
attr = self.tree_adapter.attribute(self.node_inst)
9497
self.attr_text_w = fm.width(attr.name if attr else "")
9598
self.attr_text_h = fm.lineSpacing()
9699
self.line_descent = fm.descent()
@@ -129,7 +132,7 @@ def paint(self, painter, option, widget=None):
129132
font = self.document().defaultFont()
130133
painter.setFont(font)
131134
if self.parent:
132-
draw_text = self.node_inst.description
135+
draw_text = str(self.tree_adapter.short_rule(self.node_inst))
133136
if self.parent.x() > self.x(): # node is to the left
134137
fm = QFontMetrics(font)
135138
x = rect.width() / 2 - fm.width(draw_text) - 4
@@ -140,7 +143,7 @@ def paint(self, painter, option, widget=None):
140143
painter.setBrush(self.backgroundBrush)
141144
painter.setPen(QPen(Qt.black, 3 if self.isSelected() else 0))
142145
adjrect = rect.adjusted(-3, 0, 0, 0)
143-
if not self.node_inst.children:
146+
if not self.tree_adapter.has_children(self.node_inst):
144147
painter.drawRoundedRect(adjrect, 4, 4)
145148
else:
146149
painter.drawRect(adjrect)
@@ -188,6 +191,7 @@ def __init__(self):
188191
self.domain = None
189192
self.dataset = None
190193
self.clf_dataset = None
194+
self.tree_adapter = None
191195

192196
self.color_label = QLabel("Target class: ")
193197
combo = self.color_combo = gui.OrangeComboBox()
@@ -211,9 +215,8 @@ def set_node_info(self):
211215
node.set_rect(QRectF(rect.x(), rect.y(), w, rect.height()))
212216
self.scene.fix_pos(self.root_node, 10, 10)
213217

214-
@staticmethod
215-
def _update_node_info_attr_name(node, text):
216-
attr = node.node_inst.attr
218+
def _update_node_info_attr_name(self, node, text):
219+
attr = self.tree_adapter.attribute(node.node_inst)
217220
if attr is not None:
218221
text += "<hr/>{}".format(attr.name)
219222
return text
@@ -263,7 +266,9 @@ def ctree(self, model=None):
263266
self.info.setText('No tree.')
264267
self.root_node = None
265268
self.dataset = None
269+
self.tree_adapter = None
266270
else:
271+
self.tree_adapter = self._get_tree_adapter(model)
267272
self.domain = model.domain
268273
self.dataset = model.instances
269274
if self.dataset is not None and self.dataset.domain != self.domain:
@@ -284,41 +289,44 @@ def ctree(self, model=None):
284289
self.color_combo.addItems(self.COL_OPTIONS)
285290
self.color_combo.setCurrentIndex(self.regression_colors)
286291
self.openContext(self.domain.class_var)
287-
self.root_node = self.walkcreate(model.root, None)
288-
self.info.setText('{} nodes, {} leaves'.
289-
format(model.node_count(), model.leaf_count()))
292+
# self.root_node = self.walkcreate(model.root, None)
293+
self.root_node = self.walkcreate(self.tree_adapter.root)
294+
self.info.setText('{} nodes, {} leaves'.format(
295+
self.tree_adapter.num_nodes,
296+
len(self.tree_adapter.leaves(self.tree_adapter.root))))
290297
self.setup_scene()
291298
self.send("Selected Data", None)
292299
self.send(ANNOTATED_DATA_SIGNAL_NAME,
293300
create_annotated_table(self.dataset, []))
294301

295-
def walkcreate(self, node_inst, parent=None):
302+
def walkcreate(self, node, parent=None):
296303
"""Create a structure of tree nodes from the given model"""
297-
node = TreeNode(self.model, node_inst, parent)
298-
self.scene.addItem(node)
304+
node_obj = TreeNode(self.tree_adapter, node, parent)
305+
self.scene.addItem(node_obj)
299306
if parent:
300-
edge = GraphicsEdge(node1=parent, node2=node)
307+
edge = GraphicsEdge(node1=parent, node2=node_obj)
301308
self.scene.addItem(edge)
302309
parent.graph_add_edge(edge)
303-
for child_inst in node_inst.children:
310+
for child_inst in self.tree_adapter.children(node):
304311
if child_inst is not None:
305-
self.walkcreate(child_inst, node)
306-
return node
312+
self.walkcreate(child_inst, node_obj)
313+
return node_obj
307314

308315
def node_tooltip(self, node):
309-
return "<br>".join(to_html(rule)
310-
for rule in self.model.rule(node.node_inst))
316+
return "<br>".join(to_html(str(rule))
317+
for rule in self.tree_adapter.rules(node.node_inst))
311318

312319
def update_selection(self):
313320
if self.model is None:
314321
return
315322
nodes = [item.node_inst for item in self.scene.selectedItems()
316323
if isinstance(item, TreeNode)]
317-
data = self.model.get_instances(nodes)
324+
data = self.tree_adapter.get_instances_in_nodes(
325+
self.clf_dataset, nodes)
318326
self.send("Selected Data", data)
319-
self.send(ANNOTATED_DATA_SIGNAL_NAME,
320-
create_annotated_table(self.dataset,
321-
self.model.get_indices(nodes)))
327+
self.send(ANNOTATED_DATA_SIGNAL_NAME, create_annotated_table(
328+
self.dataset,
329+
self.tree_adapter.get_indices(nodes)))
322330

323331
def send_report(self):
324332
if not self.model:
@@ -344,8 +352,8 @@ def update_node_info(self, node):
344352
def update_node_info_cls(self, node):
345353
"""Update the printed contents of the node for classification trees"""
346354
node_inst = node.node_inst
347-
distr = node_inst.value
348-
total = len(node_inst.subset)
355+
distr = self.tree_adapter.get_distribution(node_inst)[0]
356+
total = self.tree_adapter.num_samples(node_inst)
349357
distr = distr / np.sum(distr)
350358
if self.target_class_index:
351359
tabs = distr[self.target_class_index - 1]
@@ -368,8 +376,8 @@ def update_node_info_cls(self, node):
368376
def update_node_info_reg(self, node):
369377
"""Update the printed contents of the node for regression trees"""
370378
node_inst = node.node_inst
371-
mean, var = node_inst.value
372-
insts = len(node_inst.subset)
379+
mean, var = self.tree_adapter.get_distribution(node_inst)[0]
380+
insts = self.tree_adapter.num_samples(node_inst)
373381
text = "{:.1f} ± {:.1f}<br/>".format(mean, var)
374382
text += "{} instances".format(insts)
375383
text = self._update_node_info_attr_name(node, text)
@@ -380,7 +388,7 @@ def toggle_node_color_cls(self):
380388
"""Update the node color for classification trees"""
381389
colors = self.scene.colors
382390
for node in self.scene.nodes():
383-
distr = node.node_inst.value
391+
distr = node.tree_adapter.get_distribution(node.node_inst)[0]
384392
total = sum(distr)
385393
if self.target_class_index:
386394
p = distr[self.target_class_index - 1] / total
@@ -401,39 +409,49 @@ def toggle_node_color_reg(self):
401409
for node in self.scene.nodes():
402410
node.backgroundBrush = brush
403411
elif self.regression_colors == self.COL_INSTANCE:
404-
max_insts = len(self.model.instances)
412+
max_insts = len(self.tree_adapter.get_instances_in_nodes(
413+
self.dataset, [self.tree_adapter.root]))
405414
for node in self.scene.nodes():
415+
node_insts = len(self.tree_adapter.get_instances_in_nodes(
416+
self.dataset, [node.node_inst]))
406417
node.backgroundBrush = QBrush(def_color.lighter(
407-
120 - 20 * len(node.node_inst.subset) / max_insts))
418+
120 - 20 * node_insts / max_insts))
408419
elif self.regression_colors == self.COL_MEAN:
409420
minv = np.nanmin(self.dataset.Y)
410421
maxv = np.nanmax(self.dataset.Y)
411422
fact = 1 / (maxv - minv) if minv != maxv else 1
412423
colors = self.scene.colors
413424
for node in self.scene.nodes():
414-
node.backgroundBrush = QBrush(
415-
colors[fact * (node.node_inst.value[0] - minv)])
425+
node_mean = self.tree_adapter.get_distribution(node.node_inst)[0][0]
426+
node.backgroundBrush = QBrush(colors[fact * (node_mean - minv)])
416427
else:
417428
nodes = list(self.scene.nodes())
418-
variances = [node.node_inst.value[1] for node in nodes]
429+
variances = [self.tree_adapter.get_distribution(node.node_inst)[0][1]
430+
for node in nodes]
419431
max_var = max(variances)
420432
for node, var in zip(nodes, variances):
421433
node.backgroundBrush = QBrush(def_color.lighter(
422434
120 - 20 * var / max_var))
423435
self.scene.update()
424436

437+
def _get_tree_adapter(self, model):
438+
if isinstance(model, SklModel):
439+
return SklTreeAdapter(model)
440+
return TreeAdapter(model)
441+
425442

426443
def test():
427444
"""Standalone test"""
428445
import sys
429446
from AnyQt.QtWidgets import QApplication
430-
# from Orange.classification.tree import TreeLearner
431-
from Orange.regression.tree import TreeLearner
447+
from Orange.classification.tree import TreeLearner, SklTreeLearner
448+
from Orange.regression.tree import TreeLearner, SklTreeRegressionLearner
432449
a = QApplication(sys.argv)
433450
ow = OWTreeGraph()
434-
# data = Table("iris")
435-
data = Table("housing")
436-
clf = TreeLearner()(data)
451+
data = Table("titanic")
452+
# data = Table("housing")[:30]
453+
clf = SklTreeLearner()(data)
454+
# clf = TreeLearner()(data)
437455
clf.instances = data
438456

439457
ow.ctree(clf)

Orange/widgets/visualize/owtreeviewer2d.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -454,15 +454,16 @@ def toggle_line_width(self):
454454
if self.root_node is None:
455455
return
456456

457-
model = self.model
458-
root_instances = len(model.instances)
457+
tree_adapter = self.root_node.tree_adapter
458+
root_instances = tree_adapter.num_samples(self.root_node.node_inst)
459459
width = 3
460460
for edge in self.scene.edges():
461-
num_inst = len(edge.node2.node_inst.subset)
461+
num_inst = tree_adapter.num_samples(edge.node2.node_inst)
462462
if self.line_width_method == 1:
463463
width = 8 * num_inst / root_instances
464464
elif self.line_width_method == 2:
465-
width = 8 * num_inst / len(edge.node1.node_inst.subset)
465+
width = 8 * num_inst / tree_adapter.num_samples(
466+
edge.node1.node_inst)
466467
edge.setPen(QPen(Qt.gray, width, Qt.SolidLine, Qt.RoundCap))
467468
self.scene.update()
468469

Orange/widgets/visualize/utils/tree/rules.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ def merge_with(self, rule):
3131
"""
3232
raise NotImplementedError()
3333

34+
@property
35+
def description(self):
36+
return str(self)
37+
3438

3539
class DiscreteRule(Rule):
3640
"""Discrete rule class for handling Indicator rules.
@@ -68,6 +72,10 @@ def merge_with(self, rule):
6872
warnings.warn('Merged two discrete rules `%s` and `%s`' % (self, rule))
6973
return rule
7074

75+
@property
76+
def description(self):
77+
return '{} {}'.format('=' if self.equals else '≠', self.value)
78+
7179
def __str__(self):
7280
return '{} {} {}'.format(
7381
self.attr_name, '=' if self.equals else '≠', self.value)
@@ -130,6 +138,10 @@ def merge_with(self, rule):
130138
lt_rule, gt_rule = (rule, self) if self.greater else (self, rule)
131139
return IntervalRule(self.attr_name, gt_rule, lt_rule)
132140

141+
@property
142+
def description(self):
143+
return '%s %.3f' % ('>' if self.greater else '≤', self.value)
144+
133145
def __str__(self):
134146
return '%s %s %.3f' % (
135147
self.attr_name, '>' if self.greater else '≤', self.value)
@@ -197,6 +209,15 @@ def merge_with(self, rule):
197209
self.left_rule.merge_with(rule.left_rule),
198210
self.right_rule.merge_with(rule.right_rule))
199211

212+
@property
213+
def description(self):
214+
return '∈ %s%.3f, %.3f%s' % (
215+
'[' if self.left_rule.inclusive else '(',
216+
self.left_rule.value,
217+
self.right_rule.value,
218+
']' if self.right_rule.inclusive else ')'
219+
)
220+
200221
def __str__(self):
201222
return '%s ∈ %s%.3f, %.3f%s' % (
202223
self.attr_name,

Orange/widgets/visualize/utils/tree/skltreeadapter.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class SklTreeAdapter(BaseTreeAdapter):
2828
FEATURE_UNDEFINED = -2
2929

3030
def __init__(self, model):
31+
self.model = model
3132
self._tree = model.skl_model.tree_
3233
self._domain = model.domain
3334

@@ -65,7 +66,15 @@ def __right_child(self, node):
6566
return self._tree.children_right[node]
6667

6768
def get_distribution(self, node):
68-
return self._tree.value[node]
69+
value = self._tree.value[node]
70+
# If regression tree, we have to compute variance by hand, we can
71+
# detect this because you can't have classification trees when there's
72+
# only one class
73+
if value.shape[1] == 1:
74+
var = np.var(self.get_instances_in_nodes(self.model.instances, node).Y)
75+
variances = np.array([(var * np.ones(value.shape[0]))]).T
76+
value = np.hstack((value, variances))
77+
return value
6978

7079
def get_impurity(self, node):
7180
return self._tree.impurity[node]
@@ -132,6 +141,9 @@ def rules(self, node):
132141
else:
133142
return []
134143

144+
def short_rule(self, node):
145+
return self.rules(node)[0].description
146+
135147
def attribute(self, node):
136148
feature_idx = self.splitting_attribute(node)
137149
if feature_idx != self.FEATURE_UNDEFINED:

Orange/widgets/visualize/utils/tree/treeadapter.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,10 @@ def rules(self, node):
164164
"""
165165
pass
166166

167+
@abstractmethod
168+
def short_rule(self, node):
169+
pass
170+
167171
@abstractmethod
168172
def attribute(self, node):
169173
"""Get the attribute that splits the given tree.
@@ -312,6 +316,9 @@ def get_impurity(self, node):
312316
def rules(self, node):
313317
return self.model.rule(node)
314318

319+
def short_rule(self, node):
320+
return node.description
321+
315322
def attribute(self, node):
316323
return node.attr
317324

0 commit comments

Comments
 (0)