Skip to content

Commit ccd8a40

Browse files
Merge pull request #3775 from thocevar/pythagorean
[FIX] OWPythagorasTree: Enable node selection from forests with categorical variables.
2 parents 087031b + e50d114 commit ccd8a40

File tree

3 files changed

+29
-19
lines changed

3 files changed

+29
-19
lines changed

Orange/widgets/visualize/tests/test_owpythagorastree.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,14 @@
44

55
from os import path
66

7+
from Orange.classification.random_forest import RandomForestLearner
78
from Orange.data import Table
89
from Orange.modelling import TreeLearner
10+
from Orange.regression.random_forest import RandomForestRegressionLearner
911
from Orange.widgets.tests.base import WidgetTest, WidgetOutputsTestMixin
1012
from Orange.widgets.tests.utils import simulate
1113
from Orange.widgets.visualize.owpythagorastree import OWPythagorasTree
14+
from Orange.widgets.visualize.owpythagoreanforest import OWPythagoreanForest
1215
from Orange.widgets.visualize.pythagorastreeviewer import (
1316
PythagorasTree,
1417
Point,
@@ -359,3 +362,24 @@ def _callback():
359362
# Check that individual squares all have the same color
360363
colors_same = [self._check_all_same(x) for x in zip(*colors)]
361364
self.assertTrue(all(colors_same))
365+
366+
def test_forest_tree_table(self):
367+
titanic_data = Table('titanic')[::50]
368+
titanic = RandomForestLearner(n_estimators=3)(titanic_data)
369+
titanic.instances = titanic_data
370+
371+
housing_data = Table('housing')[:10]
372+
housing = RandomForestRegressionLearner(n_estimators=3)(housing_data)
373+
housing.instances = housing_data
374+
375+
forest_w = self.create_widget(OWPythagoreanForest)
376+
for data in (housing, titanic):
377+
self.send_signal(forest_w.Inputs.random_forest, data, widget=forest_w)
378+
tree = forest_w.forest_model[0].model
379+
380+
tree_w = self.widget
381+
self.send_signal(tree_w.Inputs.tree, tree, widget=tree_w)
382+
square = [i for i in tree_w.scene.items() if isinstance(i, SquareGraphicsItem)][-1]
383+
square.setSelected(True)
384+
tab = self.get_output(tree_w.Outputs.selected_data, widget=tree_w)
385+
self.assertGreater(len(tab), 0)

Orange/widgets/visualize/utils/tree/skltreeadapter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,8 +151,9 @@ def short_rule(self, node):
151151

152152
def attribute(self, node):
153153
feature_idx = self.splitting_attribute(node)
154-
if feature_idx != self.FEATURE_UNDEFINED:
155-
return self.domain.attributes[self.splitting_attribute(node)]
154+
if feature_idx == self.FEATURE_UNDEFINED:
155+
return None
156+
return self.domain.attributes[self.splitting_attribute(node)]
156157

157158
def splitting_attribute(self, node):
158159
return self._tree.feature[node]
@@ -236,7 +237,7 @@ def assign(node_id, indices):
236237
feature_idx = self._tree.feature[node_id]
237238
thresh = self._tree.threshold[node_id]
238239

239-
column = self.instances.X[indices, feature_idx]
240+
column = self.instances_transformed.X[indices, feature_idx]
240241
leftmask = column <= thresh
241242
leftind = assign(self._tree.children_left[node_id],
242243
indices[leftmask])

Orange/widgets/visualize/utils/tree/treeadapter.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(self, model):
2222
self.model = model
2323
self.domain = model.domain
2424
self.instances = model.instances
25+
self.instances_transformed = self.instances.transform(self.domain)
2526

2627
@abstractmethod
2728
def weight(self, node):
@@ -40,7 +41,6 @@ def weight(self, node):
4041
The weight of the node relative to its siblings.
4142
4243
"""
43-
pass
4444

4545
@abstractmethod
4646
def num_samples(self, node):
@@ -56,7 +56,6 @@ def num_samples(self, node):
5656
int
5757
5858
"""
59-
pass
6059

6160
@abstractmethod
6261
def parent(self, node):
@@ -71,7 +70,6 @@ def parent(self, node):
7170
object
7271
7372
"""
74-
pass
7573

7674
@abstractmethod
7775
def has_children(self, node):
@@ -86,7 +84,6 @@ def has_children(self, node):
8684
bool
8785
8886
"""
89-
pass
9087

9188
def is_leaf(self, node):
9289
"""Check if the given node is a leaf node.
@@ -116,7 +113,6 @@ def children(self, node):
116113
A iterable object containing the labels of the child nodes.
117114
118115
"""
119-
pass
120116

121117
def reverse_children(self, node):
122118
"""Reverse children of a given node.
@@ -125,12 +121,10 @@ def reverse_children(self, node):
125121
----------
126122
node : object
127123
"""
128-
pass
129124

130125
def shuffle_children(self):
131126
"""Randomly shuffle node's children in the entire tree.
132127
"""
133-
pass
134128

135129
@abstractmethod
136130
def get_distribution(self, node):
@@ -151,7 +145,6 @@ def get_distribution(self, node):
151145
the number of nodes that belong to a given class inside the node.
152146
153147
"""
154-
pass
155148

156149
@abstractmethod
157150
def get_impurity(self, node):
@@ -166,7 +159,6 @@ def get_impurity(self, node):
166159
object
167160
168161
"""
169-
pass
170162

171163
@abstractmethod
172164
def rules(self, node):
@@ -182,7 +174,6 @@ def rules(self, node):
182174
A list of Rule objects, can be of any type.
183175
184176
"""
185-
pass
186177

187178
@abstractmethod
188179
def short_rule(self, node):
@@ -200,7 +191,6 @@ def attribute(self, node):
200191
-------
201192
202193
"""
203-
pass
204194

205195
def is_root(self, node):
206196
"""Check if a given node is the root node.
@@ -227,7 +217,6 @@ def leaves(self, node):
227217
-------
228218
229219
"""
230-
pass
231220

232221
@abstractmethod
233222
def get_instances_in_nodes(self, dataset, nodes):
@@ -245,7 +234,6 @@ def get_instances_in_nodes(self, dataset, nodes):
245234
-------
246235
247236
"""
248-
pass
249237

250238
@abstractmethod
251239
def get_indices(self, nodes):
@@ -261,7 +249,6 @@ def max_depth(self):
261249
int
262250
263251
"""
264-
pass
265252

266253
@property
267254
@abstractmethod
@@ -276,7 +263,6 @@ def num_nodes(self):
276263
int
277264
278265
"""
279-
pass
280266

281267
@property
282268
@abstractmethod
@@ -288,7 +274,6 @@ def root(self):
288274
object
289275
290276
"""
291-
pass
292277

293278

294279
class TreeAdapter(BaseTreeAdapter):

0 commit comments

Comments
 (0)