@@ -57,13 +57,52 @@ def plot(model, featnames=None, num_trees=None, plottype='horizontal', figsize=(
5757 elif ('tree' in modelname ) or ('forest' in modelname ) or ('gradientboosting' in modelname ):
5858 if verbose >= 4 : print ('tree plotting pipeline.' )
5959 ax = randomforest (model , featnames = featnames , num_trees = num_trees , figsize = figsize , verbose = verbose )
60+ if ('lgb' in modelname ):
61+ ax = plot_lgb (model , featnames = featnames , num_trees = num_trees , figsize = figsize , verbose = verbose )
6062 else :
61- print ('[treeplot] Model %s not recognized. ' % (modelname ))
63+ print ('[treeplot] > Model not recognized: %s ' % (modelname ))
6264 ax = None
6365
6466 return ax
6567
6668
69+ # %% Plot tree
70+ def plot_lgb (model , featnames = None , num_trees = None , figsize = (25 ,25 ), verbose = 3 ):
71+ try :
72+ from lightgbm import plot_tree , plot_importance
73+ except :
74+ if verbose >= 1 : raise ImportError ('lightgbm must be installed. Try to: <pip install lightgbm>' )
75+ return None
76+
77+ # Check model
78+ _check_model (model , 'lgb' )
79+ # Set env
80+ _set_graphviz_path ()
81+
82+ if (num_trees is None ) and hasattr (model , 'best_iteration_' ):
83+ num_trees = model .best_iteration_
84+ if verbose >= 3 : print ('[treeplot] >Best detected tree: %.0d' % (num_trees ))
85+ elif num_trees is None :
86+ num_trees = 0
87+
88+ ax1 = None
89+ try :
90+ fig , ax1 = plt .subplots (1 , 1 , figsize = figsize )
91+ plot_tree (model , tree_index = num_trees , dpi = 200 , ax = ax1 )
92+ except :
93+ if _get_platform () != "windows" :
94+ print ('[treeplot] >Install graphviz first: <sudo apt install python-pydot python-pydot-ng graphviz>' )
95+
96+ # Plot importance
97+ ax2 = None
98+ try :
99+ fig , ax2 = plt .subplots (1 , 1 , figsize = figsize )
100+ plot_importance (model , max_num_features = 50 , ax = ax2 )
101+ except :
102+ print ('[treeplot] >Error: importance can not be plotted. Booster.get_score() results in empty. This maybe caused by having all trees as decision dumps.' )
103+
104+ return (ax1 , ax2 )
105+
67106# %% Plot tree
68107def xgboost (model , featnames = None , num_trees = None , plottype = 'horizontal' , figsize = (25 ,25 ), verbose = 3 ):
69108 """Plot tree based on a xgboost.
0 commit comments