Skip to content

Commit ad21c85

Browse files
committed
* Take into account calibrated classifiers
* Rename decorators * Fix list of explainable models * remove unnecessary code
1 parent 09dc3d0 commit ad21c85

File tree

1 file changed

+66
-42
lines changed

1 file changed

+66
-42
lines changed

model.py

Lines changed: 66 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
SHAP_AVAILABLE = True
1919

2020

21-
def _check_if_model_is_ready(func):
21+
def _check_readiness(func):
2222
@wraps(func)
2323
def wrapper(*args, **kwargs):
2424
self = args[0]
@@ -29,6 +29,22 @@ def wrapper(*args, **kwargs):
2929
return wrapper
3030

3131

32+
def _check_task(task):
33+
def actual_decorator(func):
34+
@wraps(func)
35+
def wrapper(*args, **kwargs):
36+
self_task = args[0].task_type()
37+
strict = task.upper() != 'CLASSIFICATION'
38+
target_task = Task(task)
39+
if (strict and (self_task == target_task)) or \
40+
(not strict and (self_task >= target_task)):
41+
return func(*args, **kwargs)
42+
else:
43+
raise RuntimeError('This method is not available for {} tasks'.format(self_task.name.lower()))
44+
return wrapper
45+
return actual_decorator
46+
47+
3248
class Task(int):
3349
_REGRESSION, _CLASSIFICATION = 0, 1
3450
_BINARY_CLASSIFICATION, _MULTILABEL_CLASSIFICATION = 2, 3
@@ -67,21 +83,24 @@ class Model(object):
6783
File path of the serialized model. It must be a file that can be
6884
loaded using :mod:`joblib`
6985
"""
86+
# Explainable models
87+
_explainable_models = (
88+
# Sklearn
89+
'DecisionTreeClassifier', 'DecisionTreeRegressor',
90+
'RandomForestClassifier', 'RandomForestRegressor',
91+
# XGBoost
92+
'XGBClassifier', 'XGBRegressor', 'Booster',
93+
# CatBoost
94+
'CatBoostClassifier', 'CatBoostRegressor',
95+
# LightGBM
96+
'LGBMClassifier', 'LGBMRegressor')
7097

7198
def __init__(self, file_name):
72-
def get_last_column(X):
73-
return X[:, -1].reshape(-1, 1)
74-
75-
setattr(sys.modules['__main__'], 'get_last_column', get_last_column)
7699
self._file_name = file_name
77100
self._is_ready = False
78101
self._model = None
79102
self._metadata = None
80103
self._task_type = None
81-
# Explainability
82-
self._shap_models = ['Booster', 'Booster']
83-
for m in ('RandomForest', 'XGB', 'CatBoost', 'LGBM', 'DecisionTree'):
84-
self._shap_models.extend([m + 'Classifier', m + 'Regressor'])
85104

86105
# Private
87106
def _load(self):
@@ -91,40 +110,37 @@ def _load(self):
91110
self._metadata = loaded['metadata']
92111
self._is_ready = True
93112
# Hydrate class
94-
cls = self._get_predictor()
113+
clf = self._get_predictor()
95114
# SHAP
96-
is_shap_model = type(cls).__name__ in self._shap_models
97-
self._is_explainable = SHAP_AVAILABLE and is_shap_model
115+
model_name = type(clf).__name__
116+
self._is_explainable = SHAP_AVAILABLE and (model_name in self._explainable_models)
98117
# Feature importances
99-
if hasattr(cls, 'feature_importances_'):
100-
importance = cls.feature_importances_
118+
if hasattr(clf, 'feature_importances_'):
119+
importance = clf.feature_importances_
101120
for imp, feat in zip(importance, loaded['metadata']['features']):
102121
feat['importance'] = imp
103122
# Set model task type
104-
if not hasattr(cls, 'classes_'):
123+
if not hasattr(clf, 'classes_'):
105124
self._task_type = Task('REGRESSION')
106-
elif len(cls.classes_) <= 2:
125+
elif len(clf.classes_) <= 2:
107126
self._task_type = Task('BINARY_CLASSIFICATION')
108-
elif len(cls.classes_) > 2:
127+
elif len(clf.classes_) > 2:
109128
self._task_type = Task('MULTILABEL_CLASSIFICATION')
110129

111-
@_check_if_model_is_ready
130+
@_check_readiness
112131
def _get_predictor(self):
113-
model_name = type(self._model).__name__
114-
if model_name == 'Pipeline':
115-
return self._model.steps[-1][1]
116-
else:
117-
return self._model
132+
return Model._extract_base_predictor(self._model)
118133

119-
@_check_if_model_is_ready
134+
@_check_readiness
135+
@_check_task('classification')
120136
def _get_class_names(self):
121137
return np.array(self._get_predictor().classes_, str)
122138

123-
@_check_if_model_is_ready
139+
@_check_readiness
124140
def _feature_names(self):
125141
return [variable['name'] for variable in self.features()]
126142

127-
@_check_if_model_is_ready
143+
@_check_readiness
128144
def _validate(self, input):
129145
if self.metadata.get('features') is None:
130146
raise AttributeError("Missing key 'features' in model's metadata")
@@ -164,22 +180,22 @@ def _validate(self, input):
164180
return df
165181

166182
@property
167-
@_check_if_model_is_ready
183+
@_check_readiness
168184
def _is_classification(self):
169185
return self._task_type >= Task('CLASSIFICATION')
170186

171187
@property
172-
@_check_if_model_is_ready
188+
@_check_readiness
173189
def _is_binary_classification(self):
174190
return self._task_type == Task('BINARY_CLASSIFICATION')
175191

176192
@property
177-
@_check_if_model_is_ready
193+
@_check_readiness
178194
def _is_multilabel_classification(self):
179195
return self._task_type == Task('MULTILABEL_CLASSIFICATION')
180196

181197
@property
182-
@_check_if_model_is_ready
198+
@_check_readiness
183199
def _is_regression(self):
184200
return self._task_type == Task('REGRESSION')
185201

@@ -193,6 +209,16 @@ def _is_listlike(data):
193209
data = [data]
194210
return is_input_listlike, data
195211

212+
@staticmethod
213+
def _extract_base_predictor(model):
214+
model_name = type(model).__name__
215+
if model_name == 'Pipeline':
216+
return Model._extract_base_predictor(model.steps[-1][1])
217+
elif 'CalibratedClassifier' in model_name:
218+
return Model._extract_base_predictor(model.base_estimator)
219+
else:
220+
return model
221+
196222
# Public
197223
def load(self):
198224
"""Launch model loading in a separated thread
@@ -218,7 +244,7 @@ def is_ready(self):
218244
return self._is_ready
219245

220246
@property
221-
@_check_if_model_is_ready
247+
@_check_readiness
222248
def metadata(self):
223249
"""Get metadata of the model_name.
224250
@@ -232,7 +258,7 @@ def metadata(self):
232258
"""
233259
return self._metadata
234260

235-
@_check_if_model_is_ready
261+
@_check_readiness
236262
def task_type(self, as_text=False):
237263
"""Get task type of the model
238264
@@ -251,7 +277,7 @@ def task_type(self, as_text=False):
251277
"""
252278
return self._task_type.name if as_text else self._task_type
253279

254-
@_check_if_model_is_ready
280+
@_check_readiness
255281
def features(self):
256282
"""Get the features of the model
257283
@@ -270,7 +296,7 @@ def features(self):
270296
return deepcopy(self.metadata['features'])
271297

272298
@property
273-
@_check_if_model_is_ready
299+
@_check_readiness
274300
def info(self):
275301
"""Get model information.
276302
@@ -316,7 +342,7 @@ def info(self):
316342
result['model']['class_names'] = self._get_class_names()
317343
return result
318344

319-
@_check_if_model_is_ready
345+
@_check_readiness
320346
def preprocess(self, input):
321347
"""Preprocess data
322348
@@ -340,7 +366,7 @@ def preprocess(self, input):
340366
else:
341367
return input
342368

343-
@_check_if_model_is_ready
369+
@_check_readiness
344370
def predict(self, features):
345371
"""Make a prediciton
346372
@@ -365,7 +391,8 @@ def predict(self, features):
365391
result = self._model.predict(input)
366392
return result
367393

368-
@_check_if_model_is_ready
394+
@_check_readiness
395+
@_check_task('classification')
369396
def predict_proba(self, features):
370397
"""Make a prediciton
371398
@@ -383,17 +410,14 @@ def predict_proba(self, features):
383410
dict: Predicted class probabilities.
384411
385412
Raises:
386-
RuntimeError: If the model is not ready.
413+
RuntimeError: If the model isn't ready or the task isn't classification.
387414
"""
388-
# Test for model task
389-
if self._is_regression:
390-
raise ValueError("Can't predict probabilities of regression model")
391415
input = self._validate(features)
392416
prediction = self._model.predict_proba(input)
393417
df = pd.DataFrame(prediction, columns=self._get_class_names())
394418
return df.to_dict(orient='records')
395419

396-
@_check_if_model_is_ready
420+
@_check_readiness
397421
def explain(self, features, samples=None):
398422
"""Explain the prediction of a model.
399423

0 commit comments

Comments
 (0)