Skip to content

Commit e769ff3

Browse files
TreeAdapter: Return [mean, variance] for values
1 parent 530779b commit e769ff3

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

Orange/widgets/visualize/owtreeviewer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -376,8 +376,7 @@ def update_node_info_cls(self, node):
376376
def update_node_info_reg(self, node):
377377
"""Update the printed contents of the node for regression trees"""
378378
node_inst = node.node_inst
379-
# TODO calculate variance in tree adapter
380-
mean, var = self.tree_adapter.get_distribution(node_inst)[0][0], 0.
379+
mean, var = self.tree_adapter.get_distribution(node_inst)[0]
381380
insts = len(self.tree_adapter.get_instances_in_nodes(self.dataset, [node_inst]))
382381
text = "{:.1f} ± {:.1f}<br/>".format(mean, var)
383382
text += "{} instances".format(insts)
@@ -427,8 +426,8 @@ def toggle_node_color_reg(self):
427426
node.backgroundBrush = QBrush(colors[fact * (node_mean - minv)])
428427
else:
429428
nodes = list(self.scene.nodes())
430-
# TODO Get variance from tree adapter
431-
variances = [node.node_inst.value[1] for node in nodes]
429+
variances = [self.tree_adapter.get_distribution(node.node_inst)[0][1]
430+
for node in nodes]
432431
max_var = max(variances)
433432
for node, var in zip(nodes, variances):
434433
node.backgroundBrush = QBrush(def_color.lighter(
@@ -450,7 +449,7 @@ def test():
450449
a = QApplication(sys.argv)
451450
ow = OWTreeGraph()
452451
# data = Table("iris")
453-
data = Table("housing")
452+
data = Table("housing")[:30]
454453
clf = SklTreeRegressionLearner()(data)
455454
# clf = TreeLearner()(data)
456455
clf.instances = data

Orange/widgets/visualize/owtreeviewer2d.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,6 @@ def toggle_line_width(self):
454454
if self.root_node is None:
455455
return
456456

457-
model = self.model
458457
tree_adapter = self.root_node.tree_adapter
459458
root_instances = len(tree_adapter.get_instances_in_nodes(
460459
self.dataset, [self.root_node.node_inst]))

Orange/widgets/visualize/utils/tree/skltreeadapter.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,15 @@ def __right_child(self, node):
6666
return self._tree.children_right[node]
6767

6868
def get_distribution(self, node):
69-
return self._tree.value[node]
69+
value = self._tree.value[node]
70+
# If regression tree, we have to compute variance by hand, we can
71+
# detect this because you can't have classification trees when there's
72+
# only one class
73+
if value.shape[1] == 1:
74+
var = np.var(self.get_instances_in_nodes(self.model.instances, node).Y)
75+
variances = np.array([(var * np.ones(value.shape[0]))]).T
76+
value = np.hstack((value, variances))
77+
return value
7078

7179
def get_impurity(self, node):
7280
return self._tree.impurity[node]

0 commit comments

Comments
 (0)