@@ -57,6 +57,8 @@ class Outputs:
5757 graph_name = 'scene'
5858
5959 # Settings
60+ settingsHandler = settings .DomainContextHandler ()
61+
6062 depth_limit = settings .ContextSetting (10 )
6163 target_class_index = settings .ContextSetting (0 )
6264 size_calc_idx = settings .Setting (0 )
@@ -73,8 +75,7 @@ def __init__(self):
7375 super ().__init__ ()
7476 # Instance variables
7577 self .model = None
76- self .instances = None
77- self .clf_dataset = None
78+ self .data = None
7879 # The tree adapter instance which is passed from the outside
7980 self .tree_adapter = None
8081 self .legend = None
@@ -147,18 +148,12 @@ def __init__(self):
147148 @Inputs .tree
148149 def set_tree (self , model = None ):
149150 """When a different tree is given."""
151+ self .closeContext ()
150152 self .clear ()
151153 self .model = model
152154
153155 if model is not None :
154- self .instances = model .instances
155- # this bit is important for the regression classifier
156- if self .instances is not None and \
157- self .instances .domain != model .domain :
158- self .clf_dataset = self .instances .transform (self .model .domain )
159- else :
160- self .clf_dataset = self .instances
161-
156+ self .data = model .instances
162157 self .tree_adapter = self ._get_tree_adapter (self .model )
163158 self .ptree .clear ()
164159
@@ -177,30 +172,30 @@ def set_tree(self, model=None):
177172
178173 self ._update_main_area ()
179174
180- # The target class can also be passed from the meta properties
181- # This must be set after `_update_target_class_combo`
182- if hasattr (model , 'meta_target_class_index' ):
183- self .target_class_index = model .meta_target_class_index
184- self .update_colors ()
175+ self .openContext (self .model )
185176
186- # Get meta variables describing what the settings should look like
187- # if the tree is passed from the Pythagorean forest widget.
188- if hasattr (model , 'meta_size_calc_idx' ):
189- self .size_calc_idx = model .meta_size_calc_idx
190- self .update_size_calc ()
177+ self .update_depth ()
191178
192- # TODO There is still something wrong with this
193- # if hasattr(model, 'meta_depth_limit'):
194- # self.depth_limit = model.meta_depth_limit
195- # self.update_depth()
179+ # The forest widget sets the following attributes on the tree,
180+ # describing the settings on the forest widget. To keep the tree
181+ # looking the same as on the forest widget, we prefer these settings to
182+ # context settings, if set.
183+ if hasattr (model , "meta_target_class_index" ):
184+ self .target_class_index = model .meta_target_class_index
185+ self .update_colors ()
186+ if hasattr (model , "meta_size_calc_idx" ):
187+ self .size_calc_idx = model .meta_size_calc_idx
188+ self .update_size_calc ()
189+ if hasattr (model , "meta_depth_limit" ):
190+ self .depth_limit = model .meta_depth_limit
191+ self .update_depth ()
196192
197- self .Outputs .annotated_data .send (create_annotated_table (self .instances , None ))
193+ self .Outputs .annotated_data .send (create_annotated_table (self .data , None ))
198194
199195 def clear (self ):
200196 """Clear all relevant data from the widget."""
201197 self .model = None
202- self .instances = None
203- self .clf_dataset = None
198+ self .data = None
204199 self .tree_adapter = None
205200
206201 if self .legend is not None :
@@ -228,6 +223,8 @@ def update_size_calc(self):
228223 self .invalidate_tree ()
229224
230225 def redraw (self ):
226+ if self .data is None :
227+ return
231228 self .tree_adapter .shuffle_children ()
232229 self .invalidate_tree ()
233230
@@ -307,16 +304,21 @@ def onDeleteWidget(self):
307304
308305 def commit (self ):
309306 """Commit the selected data to output."""
310- if self .instances is None :
307+ if self .data is None :
311308 self .Outputs .selected_data .send (None )
312309 self .Outputs .annotated_data .send (None )
313310 return
314- nodes = [i .tree_node .label for i in self .scene .selectedItems ()
315- if isinstance (i , SquareGraphicsItem )]
311+
312+ nodes = [
313+ i .tree_node .label for i in self .scene .selectedItems ()
314+ if isinstance (i , SquareGraphicsItem )
315+ ]
316316 data = self .tree_adapter .get_instances_in_nodes (nodes )
317317 self .Outputs .selected_data .send (data )
318318 selected_indices = self .tree_adapter .get_indices (nodes )
319- self .Outputs .annotated_data .send (create_annotated_table (self .instances , selected_indices ))
319+ self .Outputs .annotated_data .send (
320+ create_annotated_table (self .data , selected_indices )
321+ )
320322
321323 def send_report (self ):
322324 """Send report."""
@@ -327,9 +329,9 @@ def _update_target_class_combo(self):
327329 label = [x for x in self .target_class_combo .parent ().children ()
328330 if isinstance (x , QLabel )][0 ]
329331
330- if self .instances .domain .has_discrete_class :
332+ if self .data .domain .has_discrete_class :
331333 label_text = 'Target class'
332- values = [c .title () for c in self .instances .domain .class_vars [0 ].values ]
334+ values = [c .title () for c in self .data .domain .class_vars [0 ].values ]
333335 values .insert (0 , 'None' )
334336 else :
335337 label_text = 'Node color'
@@ -342,7 +344,7 @@ def _update_legend_colors(self):
342344 if self .legend is not None :
343345 self .scene .removeItem (self .legend )
344346
345- if self .instances .domain .has_discrete_class :
347+ if self .data .domain .has_discrete_class :
346348 self ._classification_update_legend_colors ()
347349 else :
348350 self ._regression_update_legend_colors ()
@@ -375,14 +377,14 @@ def _get_colors_domain(domain):
375377
376378 # The colors are the class mean
377379 if self .target_class_index == 1 :
378- values = (np .min (self .clf_dataset .Y ), np .max (self .clf_dataset .Y ))
380+ values = (np .min (self .data .Y ), np .max (self .data .Y ))
379381 colors = _get_colors_domain (self .model .domain )
380382 while len (values ) != len (colors ):
381383 values .insert (1 , - 1 )
382384 items = list (zip (values , colors ))
383385 # Colors are the stddev
384386 elif self .target_class_index == 2 :
385- values = (0 , np .std (self .clf_dataset .Y ))
387+ values = (0 , np .std (self .data .Y ))
386388 colors = _get_colors_domain (self .model .domain )
387389 while len (values ) != len (colors ):
388390 values .insert (1 , - 1 )
0 commit comments