Skip to content

Commit a5677e0

Browse files
OWTreeViewer: Use tree adapter for regression (missing variance)
1 parent 5792821 commit a5677e0

File tree

1 file changed

+15
-10
lines changed

1 file changed

+15
-10
lines changed

Orange/widgets/visualize/owtreeviewer.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -359,9 +359,9 @@ def update_node_info_cls(self, node):
359359
def update_node_info_reg(self, node):
360360
"""Update the printed contents of the node for regression trees"""
361361
node_inst = node.node_inst
362-
print(self.tree_adapter.get_distribution(node_inst)[0])
363-
mean, var = node_inst.value
364-
insts = len(node_inst.subset)
362+
# TODO calculate variance in tree adapter
363+
mean, var = self.tree_adapter.get_distribution(node_inst)[0][0], 0.
364+
insts = len(self.tree_adapter.get_instances_in_nodes(self.dataset, [node_inst]))
365365
text = "{:.1f} ± {:.1f}<br/>".format(mean, var)
366366
text += "{} instances".format(insts)
367367
text = self._update_node_info_attr_name(node, text)
@@ -393,20 +393,24 @@ def toggle_node_color_reg(self):
393393
for node in self.scene.nodes():
394394
node.backgroundBrush = brush
395395
elif self.regression_colors == self.COL_INSTANCE:
396-
max_insts = len(self.model.instances)
396+
max_insts = len(self.tree_adapter.get_instances_in_nodes(
397+
self.dataset, [self.tree_adapter.root]))
397398
for node in self.scene.nodes():
399+
node_insts = len(self.tree_adapter.get_instances_in_nodes(
400+
self.dataset, [node.node_inst]))
398401
node.backgroundBrush = QBrush(def_color.lighter(
399-
120 - 20 * len(node.node_inst.subset) / max_insts))
402+
120 - 20 * node_insts / max_insts))
400403
elif self.regression_colors == self.COL_MEAN:
401404
minv = np.nanmin(self.dataset.Y)
402405
maxv = np.nanmax(self.dataset.Y)
403406
fact = 1 / (maxv - minv) if minv != maxv else 1
404407
colors = self.scene.colors
405408
for node in self.scene.nodes():
406-
node.backgroundBrush = QBrush(
407-
colors[fact * (node.node_inst.value[0] - minv)])
409+
node_mean = self.tree_adapter.get_distribution(node.node_inst)[0][0]
410+
node.backgroundBrush = QBrush(colors[fact * (node_mean - minv)])
408411
else:
409412
nodes = list(self.scene.nodes())
413+
# TODO Get variance from tree adapter
410414
variances = [node.node_inst.value[1] for node in nodes]
411415
max_var = max(variances)
412416
for node, var in zip(nodes, variances):
@@ -428,9 +432,10 @@ def test():
428432
from Orange.regression.tree import TreeLearner, SklTreeRegressionLearner
429433
a = QApplication(sys.argv)
430434
ow = OWTreeGraph()
431-
data = Table("iris")
432-
# data = Table("housing")
433-
clf = SklTreeLearner()(data)
435+
# data = Table("iris")
436+
data = Table("housing")
437+
clf = SklTreeRegressionLearner()(data)
438+
# clf = TreeLearner()(data)
434439
clf.instances = data
435440

436441
ow.ctree(clf)

0 commit comments

Comments
 (0)