88from AnyQt .QtGui import QColor , QBrush , QPen , QFontMetrics
99from AnyQt .QtCore import Qt , QPointF , QSizeF , QRectF
1010
11- from Orange .tree import TreeModel
11+ from Orange .base import TreeModel , SklModel
1212from Orange .widgets .visualize .owtreeviewer2d import \
1313 GraphicsNode , GraphicsEdge , OWTreeViewer2D
1414from Orange .widgets .utils import to_html
2020from Orange .widgets .utils .colorpalette import ContinuousPaletteGenerator
2121from Orange .widgets .utils .annotated_data import (create_annotated_table ,
2222 ANNOTATED_DATA_SIGNAL_NAME )
23+ from Orange .widgets .visualize .utils .tree .skltreeadapter import SklTreeAdapter
24+ from Orange .widgets .visualize .utils .tree .treeadapter import TreeAdapter
2325
2426
2527class PieChart (QGraphicsRectItem ):
@@ -66,20 +68,21 @@ class TreeNode(GraphicsNode):
6668 # Methods are documented in PyQt documentation
6769 # pylint: disable=missing-docstring
6870
69- def __init__ (self , model , node_inst , parent = None ):
71+ def __init__ (self , tree_adapter , node_inst , parent = None ):
7072 super ().__init__ (parent )
71- self .model = model
73+ self .tree_adapter = tree_adapter
74+ self .model = self .tree_adapter .model
7275 self .node_inst = node_inst
7376
7477 fm = QFontMetrics (self .document ().defaultFont ())
75- attr = node_inst . attr
78+ attr = self . tree_adapter . attribute ( node_inst )
7679 self .attr_text_w = fm .width (attr .name if attr else "" )
7780 self .attr_text_h = fm .lineSpacing ()
7881 self .line_descent = fm .descent ()
7982 self ._rect = None
8083
81- if model .domain .class_var .is_discrete :
82- self .pie = PieChart (node_inst . value , 8 , self )
84+ if self . model .domain .class_var .is_discrete :
85+ self .pie = PieChart (self . tree_adapter . get_distribution ( node_inst )[ 0 ] , 8 , self )
8386 else :
8487 self .pie = None
8588
@@ -90,7 +93,7 @@ def update_contents(self):
9093 self .droplet .setPos (self .rect ().center ().x (), self .rect ().height ())
9194 self .droplet .setVisible (bool (self .branches ))
9295 fm = QFontMetrics (self .document ().defaultFont ())
93- attr = self .node_inst . attr
96+ attr = self .tree_adapter . attribute ( self . node_inst )
9497 self .attr_text_w = fm .width (attr .name if attr else "" )
9598 self .attr_text_h = fm .lineSpacing ()
9699 self .line_descent = fm .descent ()
@@ -129,7 +132,7 @@ def paint(self, painter, option, widget=None):
129132 font = self .document ().defaultFont ()
130133 painter .setFont (font )
131134 if self .parent :
132- draw_text = self .node_inst . description
135+ draw_text = str ( self .tree_adapter . short_rule ( self . node_inst ))
133136 if self .parent .x () > self .x (): # node is to the left
134137 fm = QFontMetrics (font )
135138 x = rect .width () / 2 - fm .width (draw_text ) - 4
@@ -140,7 +143,7 @@ def paint(self, painter, option, widget=None):
140143 painter .setBrush (self .backgroundBrush )
141144 painter .setPen (QPen (Qt .black , 3 if self .isSelected () else 0 ))
142145 adjrect = rect .adjusted (- 3 , 0 , 0 , 0 )
143- if not self .node_inst . children :
146+ if not self .tree_adapter . has_children ( self . node_inst ) :
144147 painter .drawRoundedRect (adjrect , 4 , 4 )
145148 else :
146149 painter .drawRect (adjrect )
@@ -188,6 +191,7 @@ def __init__(self):
188191 self .domain = None
189192 self .dataset = None
190193 self .clf_dataset = None
194+ self .tree_adapter = None
191195
192196 self .color_label = QLabel ("Target class: " )
193197 combo = self .color_combo = gui .OrangeComboBox ()
@@ -211,9 +215,8 @@ def set_node_info(self):
211215 node .set_rect (QRectF (rect .x (), rect .y (), w , rect .height ()))
212216 self .scene .fix_pos (self .root_node , 10 , 10 )
213217
214- @staticmethod
215- def _update_node_info_attr_name (node , text ):
216- attr = node .node_inst .attr
218+ def _update_node_info_attr_name (self , node , text ):
219+ attr = self .tree_adapter .attribute (node .node_inst )
217220 if attr is not None :
218221 text += "<hr/>{}" .format (attr .name )
219222 return text
@@ -263,7 +266,9 @@ def ctree(self, model=None):
263266 self .info .setText ('No tree.' )
264267 self .root_node = None
265268 self .dataset = None
269+ self .tree_adapter = None
266270 else :
271+ self .tree_adapter = self ._get_tree_adapter (model )
267272 self .domain = model .domain
268273 self .dataset = model .instances
269274 if self .dataset is not None and self .dataset .domain != self .domain :
@@ -284,41 +289,44 @@ def ctree(self, model=None):
284289 self .color_combo .addItems (self .COL_OPTIONS )
285290 self .color_combo .setCurrentIndex (self .regression_colors )
286291 self .openContext (self .domain .class_var )
287- self .root_node = self .walkcreate (model .root , None )
288- self .info .setText ('{} nodes, {} leaves' .
289- format (model .node_count (), model .leaf_count ()))
292+ # self.root_node = self.walkcreate(model.root, None)
293+ self .root_node = self .walkcreate (self .tree_adapter .root )
294+ self .info .setText ('{} nodes, {} leaves' .format (
295+ self .tree_adapter .num_nodes ,
296+ len (self .tree_adapter .leaves (self .tree_adapter .root ))))
290297 self .setup_scene ()
291298 self .send ("Selected Data" , None )
292299 self .send (ANNOTATED_DATA_SIGNAL_NAME ,
293300 create_annotated_table (self .dataset , []))
294301
295- def walkcreate (self , node_inst , parent = None ):
302+ def walkcreate (self , node , parent = None ):
296303 """Create a structure of tree nodes from the given model"""
297- node = TreeNode (self .model , node_inst , parent )
298- self .scene .addItem (node )
304+ node_obj = TreeNode (self .tree_adapter , node , parent )
305+ self .scene .addItem (node_obj )
299306 if parent :
300- edge = GraphicsEdge (node1 = parent , node2 = node )
307+ edge = GraphicsEdge (node1 = parent , node2 = node_obj )
301308 self .scene .addItem (edge )
302309 parent .graph_add_edge (edge )
303- for child_inst in node_inst . children :
310+ for child_inst in self . tree_adapter . children ( node ) :
304311 if child_inst is not None :
305- self .walkcreate (child_inst , node )
306- return node
312+ self .walkcreate (child_inst , node_obj )
313+ return node_obj
307314
308315 def node_tooltip (self , node ):
309- return "<br>" .join (to_html (rule )
310- for rule in self .model . rule (node .node_inst ))
316+ return "<br>" .join (to_html (str ( rule ) )
317+ for rule in self .tree_adapter . rules (node .node_inst ))
311318
312319 def update_selection (self ):
313320 if self .model is None :
314321 return
315322 nodes = [item .node_inst for item in self .scene .selectedItems ()
316323 if isinstance (item , TreeNode )]
317- data = self .model .get_instances (nodes )
324+ data = self .tree_adapter .get_instances_in_nodes (
325+ self .clf_dataset , nodes )
318326 self .send ("Selected Data" , data )
319- self .send (ANNOTATED_DATA_SIGNAL_NAME ,
320- create_annotated_table ( self .dataset ,
321- self .model .get_indices (nodes )))
327+ self .send (ANNOTATED_DATA_SIGNAL_NAME , create_annotated_table (
328+ self .dataset ,
329+ self .tree_adapter .get_indices (nodes )))
322330
323331 def send_report (self ):
324332 if not self .model :
@@ -344,8 +352,8 @@ def update_node_info(self, node):
344352 def update_node_info_cls (self , node ):
345353 """Update the printed contents of the node for classification trees"""
346354 node_inst = node .node_inst
347- distr = node_inst . value
348- total = len (node_inst . subset )
355+ distr = self . tree_adapter . get_distribution ( node_inst )[ 0 ]
356+ total = self . tree_adapter . num_samples (node_inst )
349357 distr = distr / np .sum (distr )
350358 if self .target_class_index :
351359 tabs = distr [self .target_class_index - 1 ]
@@ -368,8 +376,8 @@ def update_node_info_cls(self, node):
368376 def update_node_info_reg (self , node ):
369377 """Update the printed contents of the node for regression trees"""
370378 node_inst = node .node_inst
371- mean , var = node_inst . value
372- insts = len (node_inst . subset )
379+ mean , var = self . tree_adapter . get_distribution ( node_inst )[ 0 ]
380+ insts = self . tree_adapter . num_samples (node_inst )
373381 text = "{:.1f} ± {:.1f}<br/>" .format (mean , var )
374382 text += "{} instances" .format (insts )
375383 text = self ._update_node_info_attr_name (node , text )
@@ -380,7 +388,7 @@ def toggle_node_color_cls(self):
380388 """Update the node color for classification trees"""
381389 colors = self .scene .colors
382390 for node in self .scene .nodes ():
383- distr = node .node_inst . value
391+ distr = node .tree_adapter . get_distribution ( node . node_inst )[ 0 ]
384392 total = sum (distr )
385393 if self .target_class_index :
386394 p = distr [self .target_class_index - 1 ] / total
@@ -401,39 +409,49 @@ def toggle_node_color_reg(self):
401409 for node in self .scene .nodes ():
402410 node .backgroundBrush = brush
403411 elif self .regression_colors == self .COL_INSTANCE :
404- max_insts = len (self .model .instances )
412+ max_insts = len (self .tree_adapter .get_instances_in_nodes (
413+ self .dataset , [self .tree_adapter .root ]))
405414 for node in self .scene .nodes ():
415+ node_insts = len (self .tree_adapter .get_instances_in_nodes (
416+ self .dataset , [node .node_inst ]))
406417 node .backgroundBrush = QBrush (def_color .lighter (
407- 120 - 20 * len ( node . node_inst . subset ) / max_insts ))
418+ 120 - 20 * node_insts / max_insts ))
408419 elif self .regression_colors == self .COL_MEAN :
409420 minv = np .nanmin (self .dataset .Y )
410421 maxv = np .nanmax (self .dataset .Y )
411422 fact = 1 / (maxv - minv ) if minv != maxv else 1
412423 colors = self .scene .colors
413424 for node in self .scene .nodes ():
414- node . backgroundBrush = QBrush (
415- colors [fact * (node . node_inst . value [ 0 ] - minv )])
425+ node_mean = self . tree_adapter . get_distribution ( node . node_inst )[ 0 ][ 0 ]
426+ node . backgroundBrush = QBrush ( colors [fact * (node_mean - minv )])
416427 else :
417428 nodes = list (self .scene .nodes ())
418- variances = [node .node_inst .value [1 ] for node in nodes ]
429+ variances = [self .tree_adapter .get_distribution (node .node_inst )[0 ][1 ]
430+ for node in nodes ]
419431 max_var = max (variances )
420432 for node , var in zip (nodes , variances ):
421433 node .backgroundBrush = QBrush (def_color .lighter (
422434 120 - 20 * var / max_var ))
423435 self .scene .update ()
424436
437+ def _get_tree_adapter (self , model ):
438+ if isinstance (model , SklModel ):
439+ return SklTreeAdapter (model )
440+ return TreeAdapter (model )
441+
425442
426443def test ():
427444 """Standalone test"""
428445 import sys
429446 from AnyQt .QtWidgets import QApplication
430- # from Orange.classification.tree import TreeLearner
431- from Orange .regression .tree import TreeLearner
447+ from Orange .classification .tree import TreeLearner , SklTreeLearner
448+ from Orange .regression .tree import TreeLearner , SklTreeRegressionLearner
432449 a = QApplication (sys .argv )
433450 ow = OWTreeGraph ()
434- # data = Table("iris")
435- data = Table ("housing" )
436- clf = TreeLearner ()(data )
451+ data = Table ("titanic" )
452+ # data = Table("housing")[:30]
453+ clf = SklTreeLearner ()(data )
454+ # clf = TreeLearner()(data)
437455 clf .instances = data
438456
439457 ow .ctree (clf )
0 commit comments