2323
2424
2525# %% Plot tree
26- def plot (model , featnames = None , num_trees = 0 , plottype = 'horizontal' , figsize = (25 ,25 ), verbose = 3 ):
26+ def plot (model , featnames = None , num_trees = None , plottype = 'horizontal' , figsize = (25 ,25 ), verbose = 3 ):
2727 """Make tree plot for the input model.
2828
2929 Parameters
@@ -32,8 +32,8 @@ def plot(model, featnames=None, num_trees=0, plottype='horizontal', figsize=(25,
3232 xgboost or randomforest model.
3333 featnames : list, optional
3434 list of feature names. The default is None.
35- num_trees : int, default 0
36- Specify the ordinal number of target tree
35+ num_trees : int, default None
36+ The best performing tree is choosen. Specify any other ordinal number for another target tree
3737 plottype : str, (default : 'horizontal')
3838 Works only in case of xgb model.
3939 * 'horizontal'
@@ -65,7 +65,7 @@ def plot(model, featnames=None, num_trees=0, plottype='horizontal', figsize=(25,
6565
6666
6767# %% Plot tree
68- def xgboost (model , featnames = None , num_trees = 0 , plottype = 'horizontal' , figsize = (25 ,25 ), verbose = 3 ):
68+ def xgboost (model , featnames = None , num_trees = None , plottype = 'horizontal' , figsize = (25 ,25 ), verbose = 3 ):
6969 """Plot tree based on a xgboost.
7070
7171 Parameters
@@ -74,8 +74,8 @@ def xgboost(model, featnames=None, num_trees=0, plottype='horizontal', figsize=(
7474 xgboost model.
7575 featnames : list, optional
7676 list of feature names. The default is None.
77- num_trees : int, default 0
78- Specify the ordinal number of target tree
77+ num_trees : int, default None
78+ The best performing tree is choosen. Specify any other ordinal number for another target tree
7979 plottype : str, optional
8080 Make 'horizontal' or 'vertical' plot. The default is 'horizontal'.
8181 figsize: tuple, default (25,25)
@@ -101,6 +101,7 @@ def xgboost(model, featnames=None, num_trees=0, plottype='horizontal', figsize=(
101101
102102 if plottype == 'horizontal' : plottype = 'UD'
103103 if plottype == 'vertical' : plottype = 'LR'
104+ if num_trees is None : num_trees = model .best_iteration
104105
105106 try :
106107 fig , ax = plt .subplots (1 , 1 , figsize = figsize )
@@ -116,7 +117,7 @@ def xgboost(model, featnames=None, num_trees=0, plottype='horizontal', figsize=(
116117
117118
118119# %% Plot tree
119- def randomforest (model , featnames = None , num_trees = 0 , filepath = 'tree' , export = 'png' , resolution = 100 , figsize = (25 ,25 ), verbose = 3 ):
120+ def randomforest (model , featnames = None , num_trees = None , filepath = 'tree' , export = 'png' , resolution = 100 , figsize = (25 ,25 ), verbose = 3 ):
120121 """Plot tree based on a randomforest.
121122
122123 Parameters
@@ -149,11 +150,13 @@ def randomforest(model, featnames=None, num_trees=0, filepath='tree', export='pn
149150 ax = None
150151 dotfile = None
151152 pngfile = None
153+ if num_trees is None : num_trees = 0
154+
152155 # Check model
153156 _check_model (model , 'randomforest' )
154157 # Set env
155158 _set_graphviz_path ()
156-
159+
157160 if export is not None :
158161 dotfile = filepath + '.dot'
159162 pngfile = filepath + '.png'
0 commit comments