Skip to content

Commit f46b3e3

Browse files
committed
Take best tree for xgboost when num_trees=None choosen
1 parent 86ab34b commit f46b3e3

File tree

1 file changed

+11
-8
lines changed

1 file changed

+11
-8
lines changed

treeplot/treeplot.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424

2525
# %% Plot tree
26-
def plot(model, featnames=None, num_trees=0, plottype='horizontal', figsize=(25,25), verbose=3):
26+
def plot(model, featnames=None, num_trees=None, plottype='horizontal', figsize=(25,25), verbose=3):
2727
"""Make tree plot for the input model.
2828
2929
Parameters
@@ -32,8 +32,8 @@ def plot(model, featnames=None, num_trees=0, plottype='horizontal', figsize=(25,
3232
xgboost or randomforest model.
3333
featnames : list, optional
3434
list of feature names. The default is None.
35-
num_trees : int, default 0
36-
Specify the ordinal number of target tree
35+
num_trees : int, default None
36+
The best performing tree is choosen. Specify any other ordinal number for another target tree
3737
plottype : str, (default : 'horizontal')
3838
Works only in case of xgb model.
3939
* 'horizontal'
@@ -65,7 +65,7 @@ def plot(model, featnames=None, num_trees=0, plottype='horizontal', figsize=(25,
6565

6666

6767
# %% Plot tree
68-
def xgboost(model, featnames=None, num_trees=0, plottype='horizontal', figsize=(25,25), verbose=3):
68+
def xgboost(model, featnames=None, num_trees=None, plottype='horizontal', figsize=(25,25), verbose=3):
6969
"""Plot tree based on a xgboost.
7070
7171
Parameters
@@ -74,8 +74,8 @@ def xgboost(model, featnames=None, num_trees=0, plottype='horizontal', figsize=(
7474
xgboost model.
7575
featnames : list, optional
7676
list of feature names. The default is None.
77-
num_trees : int, default 0
78-
Specify the ordinal number of target tree
77+
num_trees : int, default None
78+
The best performing tree is choosen. Specify any other ordinal number for another target tree
7979
plottype : str, optional
8080
Make 'horizontal' or 'vertical' plot. The default is 'horizontal'.
8181
figsize: tuple, default (25,25)
@@ -101,6 +101,7 @@ def xgboost(model, featnames=None, num_trees=0, plottype='horizontal', figsize=(
101101

102102
if plottype=='horizontal': plottype='UD'
103103
if plottype=='vertical': plottype='LR'
104+
if num_trees is None: num_trees = model.best_iteration
104105

105106
try:
106107
fig, ax = plt.subplots(1, 1, figsize=figsize)
@@ -116,7 +117,7 @@ def xgboost(model, featnames=None, num_trees=0, plottype='horizontal', figsize=(
116117

117118

118119
# %% Plot tree
119-
def randomforest(model, featnames=None, num_trees=0, filepath='tree', export='png', resolution=100, figsize=(25,25), verbose=3):
120+
def randomforest(model, featnames=None, num_trees=None, filepath='tree', export='png', resolution=100, figsize=(25,25), verbose=3):
120121
"""Plot tree based on a randomforest.
121122
122123
Parameters
@@ -149,11 +150,13 @@ def randomforest(model, featnames=None, num_trees=0, filepath='tree', export='pn
149150
ax=None
150151
dotfile = None
151152
pngfile = None
153+
if num_trees is None: num_trees = 0
154+
152155
# Check model
153156
_check_model(model, 'randomforest')
154157
# Set env
155158
_set_graphviz_path()
156-
159+
157160
if export is not None:
158161
dotfile = filepath + '.dot'
159162
pngfile = filepath + '.png'

0 commit comments

Comments
 (0)