Skip to content

Commit 93cb1e1

Browse files
TreeAdapter: Return [mean, variance] for values
1 parent a5677e0 commit 93cb1e1

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

Orange/widgets/visualize/owtreeviewer.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -359,8 +359,7 @@ def update_node_info_cls(self, node):
359359
def update_node_info_reg(self, node):
360360
"""Update the printed contents of the node for regression trees"""
361361
node_inst = node.node_inst
362-
# TODO calculate variance in tree adapter
363-
mean, var = self.tree_adapter.get_distribution(node_inst)[0][0], 0.
362+
mean, var = self.tree_adapter.get_distribution(node_inst)[0]
364363
insts = len(self.tree_adapter.get_instances_in_nodes(self.dataset, [node_inst]))
365364
text = "{:.1f} ± {:.1f}<br/>".format(mean, var)
366365
text += "{} instances".format(insts)
@@ -410,8 +409,8 @@ def toggle_node_color_reg(self):
410409
node.backgroundBrush = QBrush(colors[fact * (node_mean - minv)])
411410
else:
412411
nodes = list(self.scene.nodes())
413-
# TODO Get variance from tree adapter
414-
variances = [node.node_inst.value[1] for node in nodes]
412+
variances = [self.tree_adapter.get_distribution(node.node_inst)[0][1]
413+
for node in nodes]
415414
max_var = max(variances)
416415
for node, var in zip(nodes, variances):
417416
node.backgroundBrush = QBrush(def_color.lighter(
@@ -433,7 +432,7 @@ def test():
433432
a = QApplication(sys.argv)
434433
ow = OWTreeGraph()
435434
# data = Table("iris")
436-
data = Table("housing")
435+
data = Table("housing")[:30]
437436
clf = SklTreeRegressionLearner()(data)
438437
# clf = TreeLearner()(data)
439438
clf.instances = data

Orange/widgets/visualize/owtreeviewer2d.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,6 @@ def toggle_line_width(self):
454454
if self.root_node is None:
455455
return
456456

457-
model = self.model
458457
tree_adapter = self.root_node.tree_adapter
459458
root_instances = len(tree_adapter.get_instances_in_nodes(
460459
self.dataset, [self.root_node.node_inst]))

Orange/widgets/visualize/utils/tree/skltreeadapter.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,15 @@ def __right_child(self, node):
6666
return self._tree.children_right[node]
6767

6868
def get_distribution(self, node):
69-
return self._tree.value[node]
69+
value = self._tree.value[node]
70+
# If regression tree, we have to compute variance by hand, we can
71+
# detect this because you can't have classification trees when there's
72+
# only one class
73+
if value.shape[1] == 1:
74+
var = np.var(self.get_instances_in_nodes(self.model.instances, node).Y)
75+
variances = np.array([(var * np.ones(value.shape[0]))]).T
76+
value = np.hstack((value, variances))
77+
return value
7078

7179
def get_impurity(self, node):
7280
return self._tree.impurity[node]

0 commit comments

Comments
 (0)