@@ -132,7 +132,7 @@ def xgboost(model, featnames=None, num_trees=None, plottype='horizontal', figsiz
132132 try :
133133 from xgboost import plot_tree , plot_importance
134134 except :
135- raise ImportError ('xgboost must be installed. Try to: <pip install xgboost>' )
135+ if verbose >= 1 : raise ImportError ('xgboost must be installed. Try to: <pip install xgboost>' )
136136
137137 _check_model (model , 'xgb' )
138138 # Set env
@@ -142,20 +142,27 @@ def xgboost(model, featnames=None, num_trees=None, plottype='horizontal', figsiz
142142 if plottype == 'vertical' : plottype = 'LR'
143143 if (num_trees is None ) and hasattr (model , 'best_iteration' ):
144144 num_trees = model .best_iteration
145+ if verbose >= 3 : print ('[treeplot] >Best detected tree: %.0d' % (num_trees ))
145146 elif num_trees is None :
146147 num_trees = 0
147148
149+ ax1 = None
148150 try :
149- fig , ax = plt .subplots (1 , 1 , figsize = figsize )
150- plot_tree (model , num_trees = num_trees , rankdir = plottype , ax = ax )
151+ fig , ax1 = plt .subplots (1 , 1 , figsize = figsize )
152+ plot_tree (model , num_trees = num_trees , rankdir = plottype , ax = ax1 )
151153 except :
152154 if _get_platform () != "windows" :
153- print ('[TREEPLOT] Install graphviz first: <sudo apt install python-pydot python-pydot-ng graphviz>' )
155+ print ('[treeplot] > Install graphviz first: <sudo apt install python-pydot python-pydot-ng graphviz>' )
154156
155157 # Plot importance
156- plot_importance (model )
158+ ax2 = None
159+ try :
160+ fig , ax2 = plt .subplots (1 , 1 , figsize = figsize )
161+ plot_importance (model , max_num_features = 50 , ax = ax2 )
162+ except :
163+ print ('[treeplot] >Error: importance can not be plotted. Booster.get_score() results in empty. This maybe caused by having all trees as decision dumps.' )
157164
158- return (ax )
165+ return (ax1 , ax2 )
159166
160167
161168# %% Plot tree
@@ -241,7 +248,7 @@ def randomforest(model, featnames=None, num_trees=None, filepath='tree', export=
241248 plt .show ()
242249 except :
243250 if _get_platform () != "windows" :
244- print ('[TREEPLOT] Install graphviz first: <sudo apt install python-pydot python-pydot-ng graphviz>' )
251+ print ('[treeplot] > Install graphviz first: <sudo apt install python-pydot python-pydot-ng graphviz>' )
245252 else :
246253 graph = Source (dot_data )
247254 plt .show ()
@@ -304,7 +311,7 @@ def _set_graphviz_path(verbose=3):
304311 getZip = os .path .abspath (os .path .join (curpath , gfile ))
305312 # Unzip if path does not exists
306313 if not os .path .isdir (getPath ):
307- if verbose >= 3 : print ('[TREEPLOT] Extracting graphviz files..' )
314+ if verbose >= 3 : print ('[treeplot] > Extracting graphviz files..' )
308315 [pathname , _ ] = os .path .split (getZip )
309316 # Unzip
310317 zip_ref = zipfile .ZipFile (getZip , 'r' )
@@ -323,7 +330,7 @@ def _set_graphviz_path(verbose=3):
323330
324331 # Add to system
325332 if finPath not in os .environ ["PATH" ]:
326- if verbose >= 3 : print ('[TREEPLOT] Set path in environment.' )
333+ if verbose >= 3 : print ('[treeplot] > Set path in environment.' )
327334 os .environ ["PATH" ] += os .pathsep + finPath
328335
329336 return (finPath )
@@ -349,11 +356,13 @@ def _check_model(model, expected):
349356 if ('forest' in modelname ) or ('tree' in modelname ) or ('gradientboosting' in modelname ):
350357 pass
351358 else :
352- print ('WARNING : The input model seems not to be a tree-based model?' )
359+ print ('[treeplot] >>Warning : The input model seems not to be a tree-based model?' )
353360 if (expected == 'xgb' ):
354361 if ('xgb' not in modelname ):
355- print ('WARNING: The input model seems not to be a xgboost model?' )
356-
362+ print ('[treeplot] >Warning: The input model seems not to be a xgboost model?' )
363+ if (expected == 'lgb' ):
364+ if ('lgb' not in modelname ):
365+ print ('[treeplot] >Warning: The input model seems not to be a lightgbm model?' )
357366
358367# %% Import example dataset from github.
359368def _download_graphviz (url , verbose = 3 ):
@@ -377,13 +386,13 @@ def _download_graphviz(url, verbose=3):
377386 gfile = wget .filename_from_url (url )
378387 PATH_TO_DATA = os .path .join (curpath , gfile )
379388 if not os .path .isdir (curpath ):
380- if verbose >= 3 : print ('[treeplot] Downloading graphviz..' )
389+ if verbose >= 3 : print ('[treeplot] > Downloading graphviz..' )
381390 os .makedirs (curpath , exist_ok = True )
382391
383392 # Check file exists.
384393 if not os .path .isfile (PATH_TO_DATA ):
385394 # Download data from URL
386- if verbose >= 3 : print ('[treeplot] Downloading graphviz..' )
395+ if verbose >= 3 : print ('[treeplot] > Downloading graphviz..' )
387396 wget .download (url , curpath )
388397
389398 return (gfile , curpath )
0 commit comments