Skip to content

Commit 72c11f5

Browse files
committed
verbosity improvements
1 parent d74e6a3 commit 72c11f5

File tree

1 file changed

+23
-14
lines changed

1 file changed

+23
-14
lines changed

treeplot/treeplot.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def xgboost(model, featnames=None, num_trees=None, plottype='horizontal', figsiz
132132
try:
133133
from xgboost import plot_tree, plot_importance
134134
except:
135-
raise ImportError('xgboost must be installed. Try to: <pip install xgboost>')
135+
if verbose>=1: raise ImportError('xgboost must be installed. Try to: <pip install xgboost>')
136136

137137
_check_model(model, 'xgb')
138138
# Set env
@@ -142,20 +142,27 @@ def xgboost(model, featnames=None, num_trees=None, plottype='horizontal', figsiz
142142
if plottype=='vertical': plottype='LR'
143143
if (num_trees is None) and hasattr(model, 'best_iteration'):
144144
num_trees = model.best_iteration
145+
if verbose>=3: print('[treeplot] >Best detected tree: %.0d' %(num_trees))
145146
elif num_trees is None:
146147
num_trees = 0
147148

149+
ax1 = None
148150
try:
149-
fig, ax = plt.subplots(1, 1, figsize=figsize)
150-
plot_tree(model, num_trees=num_trees, rankdir=plottype, ax=ax)
151+
fig, ax1 = plt.subplots(1, 1, figsize=figsize)
152+
plot_tree(model, num_trees=num_trees, rankdir=plottype, ax=ax1)
151153
except:
152154
if _get_platform() != "windows":
153-
print('[TREEPLOT] Install graphviz first: <sudo apt install python-pydot python-pydot-ng graphviz>')
155+
print('[treeplot] >Install graphviz first: <sudo apt install python-pydot python-pydot-ng graphviz>')
154156

155157
# Plot importance
156-
plot_importance(model)
158+
ax2 = None
159+
try:
160+
fig, ax2 = plt.subplots(1, 1, figsize=figsize)
161+
plot_importance(model, max_num_features=50, ax=ax2)
162+
except:
163+
print('[treeplot] >Error: importance can not be plotted. Booster.get_score() results in empty. This maybe caused by having all trees as decision dumps.')
157164

158-
return(ax)
165+
return(ax1, ax2)
159166

160167

161168
# %% Plot tree
@@ -241,7 +248,7 @@ def randomforest(model, featnames=None, num_trees=None, filepath='tree', export=
241248
plt.show()
242249
except:
243250
if _get_platform() != "windows":
244-
print('[TREEPLOT] Install graphviz first: <sudo apt install python-pydot python-pydot-ng graphviz>')
251+
print('[treeplot] >Install graphviz first: <sudo apt install python-pydot python-pydot-ng graphviz>')
245252
else:
246253
graph = Source(dot_data)
247254
plt.show()
@@ -304,7 +311,7 @@ def _set_graphviz_path(verbose=3):
304311
getZip = os.path.abspath(os.path.join(curpath, gfile))
305312
# Unzip if path does not exists
306313
if not os.path.isdir(getPath):
307-
if verbose>=3: print('[TREEPLOT] Extracting graphviz files..')
314+
if verbose>=3: print('[treeplot] >Extracting graphviz files..')
308315
[pathname, _] = os.path.split(getZip)
309316
# Unzip
310317
zip_ref = zipfile.ZipFile(getZip, 'r')
@@ -323,7 +330,7 @@ def _set_graphviz_path(verbose=3):
323330

324331
# Add to system
325332
if finPath not in os.environ["PATH"]:
326-
if verbose>=3: print('[TREEPLOT] Set path in environment.')
333+
if verbose>=3: print('[treeplot] >Set path in environment.')
327334
os.environ["PATH"] += os.pathsep + finPath
328335

329336
return(finPath)
@@ -349,11 +356,13 @@ def _check_model(model, expected):
349356
if ('forest' in modelname) or ('tree' in modelname) or ('gradientboosting' in modelname):
350357
pass
351358
else:
352-
print('WARNING: The input model seems not to be a tree-based model?')
359+
print('[treeplot] >>Warning: The input model seems not to be a tree-based model?')
353360
if (expected=='xgb'):
354361
if ('xgb' not in modelname):
355-
print('WARNING: The input model seems not to be a xgboost model?')
356-
362+
print('[treeplot] >Warning: The input model seems not to be a xgboost model?')
363+
if (expected=='lgb'):
364+
if ('lgb' not in modelname):
365+
print('[treeplot] >Warning: The input model seems not to be a lightgbm model?')
357366

358367
# %% Import example dataset from github.
359368
def _download_graphviz(url, verbose=3):
@@ -377,13 +386,13 @@ def _download_graphviz(url, verbose=3):
377386
gfile = wget.filename_from_url(url)
378387
PATH_TO_DATA = os.path.join(curpath, gfile)
379388
if not os.path.isdir(curpath):
380-
if verbose>=3: print('[treeplot] Downloading graphviz..')
389+
if verbose>=3: print('[treeplot] >Downloading graphviz..')
381390
os.makedirs(curpath, exist_ok=True)
382391

383392
# Check file exists.
384393
if not os.path.isfile(PATH_TO_DATA):
385394
# Download data from URL
386-
if verbose>=3: print('[treeplot] Downloading graphviz..')
395+
if verbose>=3: print('[treeplot] >Downloading graphviz..')
387396
wget.download(url, curpath)
388397

389398
return(gfile, curpath)

0 commit comments

Comments
 (0)