Skip to content

Commit c90d9ce

Browse files
authored
Merge pull request #3777 from pavlin-policar/pythagoras-enh
[FIX] Minor improvements to pythagorean trees
2 parents d7b49f6 + 57d4ecc commit c90d9ce

File tree

4 files changed

+90
-50
lines changed

4 files changed

+90
-50
lines changed

Orange/widgets/visualize/owpythagorastree.py

Lines changed: 38 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ class Outputs:
5757
graph_name = 'scene'
5858

5959
# Settings
60+
settingsHandler = settings.DomainContextHandler()
61+
6062
depth_limit = settings.ContextSetting(10)
6163
target_class_index = settings.ContextSetting(0)
6264
size_calc_idx = settings.Setting(0)
@@ -73,8 +75,7 @@ def __init__(self):
7375
super().__init__()
7476
# Instance variables
7577
self.model = None
76-
self.instances = None
77-
self.clf_dataset = None
78+
self.data = None
7879
# The tree adapter instance which is passed from the outside
7980
self.tree_adapter = None
8081
self.legend = None
@@ -147,18 +148,12 @@ def __init__(self):
147148
@Inputs.tree
148149
def set_tree(self, model=None):
149150
"""When a different tree is given."""
151+
self.closeContext()
150152
self.clear()
151153
self.model = model
152154

153155
if model is not None:
154-
self.instances = model.instances
155-
# this bit is important for the regression classifier
156-
if self.instances is not None and \
157-
self.instances.domain != model.domain:
158-
self.clf_dataset = self.instances.transform(self.model.domain)
159-
else:
160-
self.clf_dataset = self.instances
161-
156+
self.data = model.instances
162157
self.tree_adapter = self._get_tree_adapter(self.model)
163158
self.ptree.clear()
164159

@@ -177,30 +172,30 @@ def set_tree(self, model=None):
177172

178173
self._update_main_area()
179174

180-
# The target class can also be passed from the meta properties
181-
# This must be set after `_update_target_class_combo`
182-
if hasattr(model, 'meta_target_class_index'):
183-
self.target_class_index = model.meta_target_class_index
184-
self.update_colors()
175+
self.openContext(self.model)
185176

186-
# Get meta variables describing what the settings should look like
187-
# if the tree is passed from the Pythagorean forest widget.
188-
if hasattr(model, 'meta_size_calc_idx'):
189-
self.size_calc_idx = model.meta_size_calc_idx
190-
self.update_size_calc()
177+
self.update_depth()
191178

192-
# TODO There is still something wrong with this
193-
# if hasattr(model, 'meta_depth_limit'):
194-
# self.depth_limit = model.meta_depth_limit
195-
# self.update_depth()
179+
# The forest widget sets the following attributes on the tree,
180+
# describing the settings on the forest widget. To keep the tree
181+
# looking the same as on the forest widget, we prefer these settings to
182+
# context settings, if set.
183+
if hasattr(model, "meta_target_class_index"):
184+
self.target_class_index = model.meta_target_class_index
185+
self.update_colors()
186+
if hasattr(model, "meta_size_calc_idx"):
187+
self.size_calc_idx = model.meta_size_calc_idx
188+
self.update_size_calc()
189+
if hasattr(model, "meta_depth_limit"):
190+
self.depth_limit = model.meta_depth_limit
191+
self.update_depth()
196192

197-
self.Outputs.annotated_data.send(create_annotated_table(self.instances, None))
193+
self.Outputs.annotated_data.send(create_annotated_table(self.data, None))
198194

199195
def clear(self):
200196
"""Clear all relevant data from the widget."""
201197
self.model = None
202-
self.instances = None
203-
self.clf_dataset = None
198+
self.data = None
204199
self.tree_adapter = None
205200

206201
if self.legend is not None:
@@ -228,6 +223,8 @@ def update_size_calc(self):
228223
self.invalidate_tree()
229224

230225
def redraw(self):
226+
if self.data is None:
227+
return
231228
self.tree_adapter.shuffle_children()
232229
self.invalidate_tree()
233230

@@ -307,16 +304,21 @@ def onDeleteWidget(self):
307304

308305
def commit(self):
309306
"""Commit the selected data to output."""
310-
if self.instances is None:
307+
if self.data is None:
311308
self.Outputs.selected_data.send(None)
312309
self.Outputs.annotated_data.send(None)
313310
return
314-
nodes = [i.tree_node.label for i in self.scene.selectedItems()
315-
if isinstance(i, SquareGraphicsItem)]
311+
312+
nodes = [
313+
i.tree_node.label for i in self.scene.selectedItems()
314+
if isinstance(i, SquareGraphicsItem)
315+
]
316316
data = self.tree_adapter.get_instances_in_nodes(nodes)
317317
self.Outputs.selected_data.send(data)
318318
selected_indices = self.tree_adapter.get_indices(nodes)
319-
self.Outputs.annotated_data.send(create_annotated_table(self.instances, selected_indices))
319+
self.Outputs.annotated_data.send(
320+
create_annotated_table(self.data, selected_indices)
321+
)
320322

321323
def send_report(self):
322324
"""Send report."""
@@ -327,9 +329,9 @@ def _update_target_class_combo(self):
327329
label = [x for x in self.target_class_combo.parent().children()
328330
if isinstance(x, QLabel)][0]
329331

330-
if self.instances.domain.has_discrete_class:
332+
if self.data.domain.has_discrete_class:
331333
label_text = 'Target class'
332-
values = [c.title() for c in self.instances.domain.class_vars[0].values]
334+
values = [c.title() for c in self.data.domain.class_vars[0].values]
333335
values.insert(0, 'None')
334336
else:
335337
label_text = 'Node color'
@@ -342,7 +344,7 @@ def _update_legend_colors(self):
342344
if self.legend is not None:
343345
self.scene.removeItem(self.legend)
344346

345-
if self.instances.domain.has_discrete_class:
347+
if self.data.domain.has_discrete_class:
346348
self._classification_update_legend_colors()
347349
else:
348350
self._regression_update_legend_colors()
@@ -375,14 +377,14 @@ def _get_colors_domain(domain):
375377

376378
# The colors are the class mean
377379
if self.target_class_index == 1:
378-
values = (np.min(self.clf_dataset.Y), np.max(self.clf_dataset.Y))
380+
values = (np.min(self.data.Y), np.max(self.data.Y))
379381
colors = _get_colors_domain(self.model.domain)
380382
while len(values) != len(colors):
381383
values.insert(1, -1)
382384
items = list(zip(values, colors))
383385
# Colors are the stddev
384386
elif self.target_class_index == 2:
385-
values = (0, np.std(self.clf_dataset.Y))
387+
values = (0, np.std(self.data.Y))
386388
colors = _get_colors_domain(self.model.domain)
387389
while len(values) != len(colors):
388390
values.insert(1, -1)

Orange/widgets/visualize/owpythagoreanforest.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Any, Callable, Optional
44

55
from AnyQt.QtCore import Qt, QRectF, QSize, QPointF, QSizeF, QModelIndex, \
6-
QItemSelection, QT_VERSION
6+
QItemSelection, QItemSelectionModel, QT_VERSION
77
from AnyQt.QtGui import QPainter, QPen, QColor, QBrush, QMouseEvent
88
from AnyQt.QtWidgets import QSizePolicy, QGraphicsScene, QLabel, QSlider, \
99
QListView, QStyledItemDelegate, QStyleOptionViewItem, QStyle
@@ -174,11 +174,15 @@ class Outputs:
174174
graph_name = 'scene'
175175

176176
# Settings
177+
settingsHandler = settings.DomainContextHandler()
178+
177179
depth_limit = settings.ContextSetting(10)
178180
target_class_index = settings.ContextSetting(0)
179181
size_calc_idx = settings.Setting(0)
180182
zoom = settings.Setting(200)
181183

184+
selected_index = settings.ContextSetting(None)
185+
182186
SIZE_CALCULATION = [
183187
('Normal', lambda x: x),
184188
('Square root', lambda x: sqrt(x)),
@@ -199,7 +203,6 @@ def __init__(self):
199203
self.rf_model = None
200204
self.forest = None
201205
self.instances = None
202-
self.clf_dataset = None
203206

204207
self.color_palette = None
205208

@@ -265,29 +268,32 @@ def __init__(self):
265268
@Inputs.random_forest
266269
def set_rf(self, model=None):
267270
"""When a different forest is given."""
271+
self.closeContext()
268272
self.clear()
269273
self.rf_model = model
270274

271275
if model is not None:
272276
self.forest = self._get_forest_adapter(self.rf_model)
273277
self.forest_model[:] = self.forest.trees
274-
275278
self.instances = model.instances
276-
# This bit is important for the regression classifier
277-
if self.instances is not None and self.instances.domain != model.domain:
278-
self.clf_dataset = self.instances.transform(self.rf_model.domain)
279-
else:
280-
self.clf_dataset = self.instances
281279

282280
self._update_info_box()
283281
self._update_target_class_combo()
284282
self._update_depth_slider()
285283

284+
self.openContext(model)
285+
# Restore item selection
286+
if self.selected_index is not None:
287+
index = self.list_view.model().index(self.selected_index)
288+
selection = QItemSelection(index, index)
289+
self.list_view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect)
290+
286291
def clear(self):
287292
"""Clear all relevant data from the widget."""
288293
self.rf_model = None
289294
self.forest = None
290295
self.forest_model.clear()
296+
self.selected_index = None
291297

292298
self._clear_info_box()
293299
self._clear_target_class_combo()
@@ -342,19 +348,19 @@ def onDeleteWidget(self):
342348
super().onDeleteWidget()
343349
self.clear()
344350

345-
def commit(self, selection):
346-
# type: (QItemSelection) -> None
351+
def commit(self, selection: QItemSelection) -> None:
347352
"""Commit the selected tree to output."""
348353
selected_indices = selection.indexes()
349354

350355
if not len(selected_indices):
356+
self.selected_index = None
351357
self.Outputs.tree.send(None)
352358
return
353359

354-
selected_index, = selection.indexes()
360+
# We only allow selecting a single tree so there will always be one index
361+
self.selected_index = selected_indices[0].row()
355362

356-
idx = selected_index.row()
357-
tree = self.rf_model.trees[idx]
363+
tree = self.rf_model.trees[self.selected_index]
358364
tree.instances = self.instances
359365
tree.meta_target_class_index = self.target_class_index
360366
tree.meta_size_calc_idx = self.size_calc_idx

Orange/widgets/visualize/tests/test_owpythagorastree.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -383,3 +383,15 @@ def test_forest_tree_table(self):
383383
square.setSelected(True)
384384
tab = self.get_output(tree_w.Outputs.selected_data, widget=tree_w)
385385
self.assertGreater(len(tab), 0)
386+
387+
def test_changing_data_restores_depth_from_previous_settings(self):
388+
titanic_data = Table("titanic")[::50]
389+
forest = RandomForestLearner(n_estimators=3)(titanic_data)
390+
forest.instances = titanic_data
391+
392+
self.send_signal(self.widget.Inputs.tree, forest.trees[0])
393+
self.widget.controls.depth_limit.setValue(1)
394+
395+
# The domain is still the same, so restore the depth limit from before
396+
self.send_signal(self.widget.Inputs.tree, forest.trees[1])
397+
self.assertEqual(self.widget.ptree._depth_limit, 1)

Orange/widgets/visualize/tests/test_owpythagoreanforest.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from unittest.mock import Mock
44

5-
from AnyQt.QtCore import Qt
5+
from AnyQt.QtCore import Qt, QItemSelection, QItemSelectionModel
66

77
from Orange.classification.random_forest import RandomForestLearner
88
from Orange.data import Table
@@ -201,3 +201,23 @@ def _callback():
201201
# Check that individual squares all have the same color
202202
colors_same = [self._check_all_same(x) for x in zip(*colors)]
203203
self.assertTrue(all(colors_same))
204+
205+
def select_tree(self, idx: int) -> None:
206+
list_view = self.widget.list_view
207+
index = list_view.model().index(idx)
208+
selection = QItemSelection(index, index)
209+
list_view.selectionModel().select(selection, QItemSelectionModel.ClearAndSelect)
210+
211+
def test_storing_selection(self):
212+
# Select one of the trees
213+
idx = 1
214+
self.send_signal(self.widget.Inputs.random_forest, self.titanic)
215+
self.select_tree(idx)
216+
# Clear input
217+
self.send_signal(self.widget.Inputs.random_forest, None)
218+
# Restore previous data; context settings should be restored
219+
self.send_signal(self.widget.Inputs.random_forest, self.titanic)
220+
221+
output = self.get_output(self.widget.Outputs.tree)
222+
self.assertIsNotNone(output)
223+
self.assertIs(output.skl_model, self.titanic.trees[idx].skl_model)

0 commit comments

Comments
 (0)