@@ -179,7 +179,7 @@ def node_split(self, sample_indice):
179179 if sortted_feature [i + 1 ] <= sortted_feature [i ] + self .EPSILON :
180180 continue
181181
182- if self .min_samples_leaf < n_samples / ( self .n_split_grid - 1 ):
182+ if self .min_samples_leaf < n_samples / max (( self .n_split_grid - 1 ), 2 ):
183183 if (i + 1 ) / n_samples < (split_point + 1 ) / (self .n_split_grid + 1 ):
184184 continue
185185 elif n_samples > 2 * self .min_samples_leaf :
@@ -312,54 +312,62 @@ def fit(self, x, y):
312312 "is_left" : False })
313313 return self
314314
315- def plot_tree (self , folder = "./results/" , name = "demo" , save_png = False , save_eps = False ):
315+ def plot_tree (self , draw_depth = np . inf , start_node_id = 1 , folder = "./results/" , name = "demo" , save_png = False , save_eps = False ):
316316
317+ idx = 0
318+ draw_subtree = {}
317319 draw_tree = copy .deepcopy (self .tree )
318- pending_node_list = [draw_tree [1 ]]
319- max_depth = 1 + np .max ([item ["depth" ] for key , item in self .tree .items ()])
320+ pending_node_list = [draw_tree [start_node_id ]]
321+ start_depth = draw_tree [start_node_id ]["depth" ]
322+ total_depth = 1 + min (np .max ([item ["depth" ] for key , item in self .tree .items ()]) - start_depth , draw_depth )
323+ max_depth = min (np .max ([item ["depth" ] for key , item in self .tree .items ()]), start_depth + draw_depth )
320324 while len (pending_node_list ) > 0 :
321325
322326 item = pending_node_list .pop ()
323- if item ["parent_id" ] is None :
327+ if item ["depth" ] > max_depth :
328+ continue
329+ if item ["parent_id" ] is None or idx == 0 :
324330 xy = (0.5 , 0 )
325331 parent_xy = None
326332 else :
327- parent_xy = draw_tree [item ["parent_id" ]]["xy" ]
333+ parent_xy = draw_subtree [item ["parent_id" ]]["xy" ]
328334 if item ["is_left" ]:
329- xy = (parent_xy [0 ] - 1 / 2 ** (item ["depth" ] + 1 ), 3 * item ["depth" ] / (3 * max_depth - 2 ))
335+ xy = (parent_xy [0 ] - 1 / 2 ** (item ["depth" ] - start_depth + 1 ), 3 * ( item ["depth" ] - start_depth ) / (3 * total_depth - 2 ))
330336 else :
331- xy = (parent_xy [0 ] + 1 / 2 ** (item ["depth" ] + 1 ), 3 * item ["depth" ] / (3 * max_depth - 2 ))
337+ xy = (parent_xy [0 ] + 1 / 2 ** (item ["depth" ] - start_depth + 1 ), 3 * (item ["depth" ] - start_depth ) / (3 * total_depth - 2 ))
338+ idx += 1
332339
340+ draw_subtree [item ["node_id" ]] = item
333341 if item ["is_leaf" ]:
334342 if is_regressor (self ):
335- draw_tree [item ["node_id" ]].update ({"xy" : xy ,
343+ draw_subtree [item ["node_id" ]].update ({"xy" : xy ,
336344 "parent_xy" : parent_xy ,
337345 "estimator" : item ["estimator" ],
338346 "label" : "____Node " + str (item ["node_id" ]) + "____" +
339347 "\n MSE: " + str (np .round (item ["impurity" ], 3 ))
340348 + "\n Size: " + str (int (item ["n_samples" ]))
341349 + "\n Mean: " + str (np .round (item ["value" ], 3 ))})
342350 elif is_classifier (self ):
343- draw_tree [item ["node_id" ]].update ({"xy" : xy ,
351+ draw_subtree [item ["node_id" ]].update ({"xy" : xy ,
344352 "parent_xy" : parent_xy ,
345353 "estimator" : item ["estimator" ],
346354 "label" : "____Node " + str (item ["node_id" ]) + "____" +
347355 "\n CEntropy: " + str (np .round (item ["impurity" ], 3 ))
348356 + "\n Size: " + str (int (item ["n_samples" ]))
349- + "\n Mean: " + str (np .round (item ["value" ], 3 ))})
357+ + "\n Mean: " + str (np .round (item ["value" ], 3 ))})
350358 else :
351359 fill_width = len (self .feature_names [item ["feature" ]] + " <=" + str (np .round (item ["threshold" ], 3 )))
352360 fill_width = int (round ((fill_width - 2 ) / 2 ))
353361 if is_regressor (self ):
354- draw_tree [item ["node_id" ]].update ({"xy" : xy ,
362+ draw_subtree [item ["node_id" ]].update ({"xy" : xy ,
355363 "parent_xy" : parent_xy ,
356364 "label" : "_" * fill_width + "Node " + str (item ["node_id" ]) + "_" * fill_width
357365 + "\n " + self .feature_names [item ["feature" ]] + " <=" + str (np .round (item ["threshold" ], 3 ))
358366 + "\n MSE: " + str (np .round (item ["impurity" ], 3 ))
359367 + "\n Size: " + str (int (item ["n_samples" ]))
360368 + "\n Mean: " + str (np .round (item ["value" ], 3 ))})
361369 elif is_classifier (self ):
362- draw_tree [item ["node_id" ]].update ({"xy" : xy ,
370+ draw_subtree [item ["node_id" ]].update ({"xy" : xy ,
363371 "parent_xy" : parent_xy ,
364372 "label" : "_" * fill_width + "Node " + str (item ["node_id" ]) + "_" * fill_width
365373 + "\n " + self .feature_names [item ["feature" ]] + " <=" + str (np .round (item ["threshold" ], 3 ))
@@ -370,7 +378,7 @@ def plot_tree(self, folder="./results/", name="demo", save_png=False, save_eps=F
370378 pending_node_list .append (self .tree [item ["left_child_id" ]])
371379 pending_node_list .append (self .tree [item ["right_child_id" ]])
372380
373- fig = plt .figure (figsize = (2 ** max_depth , (max_depth - 0.8 ) * 2 ))
381+ fig = plt .figure (figsize = (2 ** total_depth , (total_depth - 0.8 ) * 2 ))
374382 tree = fig .add_axes ([0.0 , 0.0 , 1 , 1 ])
375383 ax_width = tree .get_window_extent ().width
376384 ax_height = tree .get_window_extent ().height
@@ -380,7 +388,8 @@ def plot_tree(self, folder="./results/", name="demo", save_png=False, save_eps=F
380388 values = np .array ([item ["value" ] for key , item in self .tree .items ()])
381389 min_value , max_value = values .min (), values .max ()
382390
383- for key , item in draw_tree .items ():
391+ idx = 0
392+ for key , item in draw_subtree .items ():
384393
385394 if max_value == min_value :
386395 if item ["is_leaf" ]:
@@ -396,22 +405,24 @@ def plot_tree(self, folder="./results/", name="demo", save_png=False, save_eps=F
396405 color = [int (round (alpha * c + (1 - alpha ) * 255 , 0 )) for c in color_list ]
397406
398407 kwargs = dict (bbox = {"fc" : '#%2x%2x%2x' % tuple (color ), "boxstyle" : "round" }, arrowprops = {"arrowstyle" : "<-" },
399- ha = 'center' , va = 'center' , zorder = 100 - 10 * item ["depth" ], xycoords = 'axes pixels' , fontsize = 14 )
408+ ha = 'center' , va = 'center' , zorder = 100 - 10 * ( item ["depth" ] - start_depth ) , xycoords = 'axes pixels' , fontsize = 14 )
400409
401- if item ["parent_id" ] is None :
410+ if item ["parent_id" ] is None or idx == 0 :
402411 tree .annotate (item ["label" ], (item ["xy" ][0 ] * ax_width , (1 - item ["xy" ][1 ]) * ax_height ), ** kwargs )
403412 else :
404413 if item ["is_left" ]:
405- tree .annotate (item ["label" ], ((item ["parent_xy" ][0 ] - 0.01 / 2 ** (item ["depth" ] + 1 )) * ax_width ,
406- (1 - item ["parent_xy" ][1 ] - 0.1 / max_depth ) * ax_height ),
414+ tree .annotate (item ["label" ], ((item ["parent_xy" ][0 ] - 0.01 / 2 ** (item ["depth" ] - start_depth + 1 )) * ax_width ,
415+ (1 - item ["parent_xy" ][1 ] - 0.1 / total_depth ) * ax_height ),
407416 (item ["xy" ][0 ] * ax_width , (1 - item ["xy" ][1 ]) * ax_height ), ** kwargs )
408417 else :
409- tree .annotate (item ["label" ], ((item ["parent_xy" ][0 ] + 0.01 / 2 ** (item ["depth" ] + 1 )) * ax_width ,
410- (1 - item ["parent_xy" ][1 ] - 0.1 / max_depth ) * ax_height ),
418+ tree .annotate (item ["label" ], ((item ["parent_xy" ][0 ] + 0.01 / 2 ** (item ["depth" ] - start_depth + 1 )) * ax_width ,
419+ (1 - item ["parent_xy" ][1 ] - 0.1 / total_depth ) * ax_height ),
411420 (item ["xy" ][0 ] * ax_width , (1 - item ["xy" ][1 ]) * ax_height ), ** kwargs )
421+ idx += 1
412422
413423 tree .set_axis_off ()
414424 plt .show ()
425+
415426 if max_depth > 0 :
416427 save_path = folder + name
417428 if save_eps :
0 commit comments