Skip to content

Commit d74e6a3

Browse files
committed
lightgbm added!
1 parent 430b0f7 commit d74e6a3

File tree

1 file changed

+40
-1
lines changed

1 file changed

+40
-1
lines changed

treeplot/treeplot.py

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,52 @@ def plot(model, featnames=None, num_trees=None, plottype='horizontal', figsize=(
5757
elif ('tree' in modelname) or ('forest' in modelname) or ('gradientboosting' in modelname):
5858
if verbose>=4: print('tree plotting pipeline.')
5959
ax = randomforest(model, featnames=featnames, num_trees=num_trees, figsize=figsize, verbose=verbose)
60+
if ('lgb' in modelname):
61+
ax = plot_lgb(model, featnames=featnames, num_trees=num_trees, figsize=figsize, verbose=verbose)
6062
else:
61-
print('[treeplot] Model %s not recognized.' %(modelname))
63+
print('[treeplot] >Model not recognized: %s' %(modelname))
6264
ax = None
6365

6466
return ax
6567

6668

69+
# %% Plot tree
70+
def plot_lgb(model, featnames=None, num_trees=None, figsize=(25,25), verbose=3):
71+
try:
72+
from lightgbm import plot_tree, plot_importance
73+
except:
74+
if verbose>=1: raise ImportError('lightgbm must be installed. Try to: <pip install lightgbm>')
75+
return None
76+
77+
# Check model
78+
_check_model(model, 'lgb')
79+
# Set env
80+
_set_graphviz_path()
81+
82+
if (num_trees is None) and hasattr(model, 'best_iteration_'):
83+
num_trees = model.best_iteration_
84+
if verbose>=3: print('[treeplot] >Best detected tree: %.0d' %(num_trees))
85+
elif num_trees is None:
86+
num_trees = 0
87+
88+
ax1 = None
89+
try:
90+
fig, ax1 = plt.subplots(1, 1, figsize=figsize)
91+
plot_tree(model, tree_index=num_trees, dpi=200, ax=ax1)
92+
except:
93+
if _get_platform() != "windows":
94+
print('[treeplot] >Install graphviz first: <sudo apt install python-pydot python-pydot-ng graphviz>')
95+
96+
# Plot importance
97+
ax2 = None
98+
try:
99+
fig, ax2 = plt.subplots(1, 1, figsize=figsize)
100+
plot_importance(model, max_num_features=50, ax=ax2)
101+
except:
102+
print('[treeplot] >Error: importance can not be plotted. Booster.get_score() results in empty. This maybe caused by having all trees as decision dumps.')
103+
104+
return(ax1, ax2)
105+
67106
# %% Plot tree
68107
def xgboost(model, featnames=None, num_trees=None, plottype='horizontal', figsize=(25,25), verbose=3):
69108
"""Plot tree based on a xgboost.

0 commit comments

Comments
 (0)