1818 SHAP_AVAILABLE = True
1919
2020
21- def _check_if_model_is_ready (func ):
21+ def _check_readiness (func ):
2222 @wraps (func )
2323 def wrapper (* args , ** kwargs ):
2424 self = args [0 ]
@@ -29,6 +29,22 @@ def wrapper(*args, **kwargs):
2929 return wrapper
3030
3131
32+ def _check_task (task ):
33+ def actual_decorator (func ):
34+ @wraps (func )
35+ def wrapper (* args , ** kwargs ):
36+ self_task = args [0 ].task_type ()
37+ strict = task .upper () != 'CLASSIFICATION'
38+ target_task = Task (task )
39+ if (strict and (self_task == target_task )) or \
40+ (not strict and (self_task >= target_task )):
41+ return func (* args , ** kwargs )
42+ else :
43+ raise RuntimeError ('This method is not available for {} tasks' .format (self_task .name .lower ()))
44+ return wrapper
45+ return actual_decorator
46+
47+
3248class Task (int ):
3349 _REGRESSION , _CLASSIFICATION = 0 , 1
3450 _BINARY_CLASSIFICATION , _MULTILABEL_CLASSIFICATION = 2 , 3
@@ -67,21 +83,24 @@ class Model(object):
6783 File path of the serialized model. It must be a file that can be
6884 loaded using :mod:`joblib`
6985 """
86+ # Explainable models
87+ _explainable_models = (
88+ # Sklearn
89+ 'DecisionTreeClassifier' , 'DecisionTreeRegressor' ,
90+ 'RandomForestClassifier' , 'RandomForestRegressor' ,
91+ # XGBoost
92+ 'XGBClassifier' , 'XGBRegressor' , 'Booster' ,
93+ # CatBoost
94+ 'CatBoostClassifier' , 'CatBoostRegressor' ,
95+ # LightGBM
96+ 'LGBMClassifier' , 'LGBMRegressor' )
7097
7198 def __init__ (self , file_name ):
72- def get_last_column (X ):
73- return X [:, - 1 ].reshape (- 1 , 1 )
74-
75- setattr (sys .modules ['__main__' ], 'get_last_column' , get_last_column )
7699 self ._file_name = file_name
77100 self ._is_ready = False
78101 self ._model = None
79102 self ._metadata = None
80103 self ._task_type = None
81- # Explainability
82- self ._shap_models = ['Booster' , 'Booster' ]
83- for m in ('RandomForest' , 'XGB' , 'CatBoost' , 'LGBM' , 'DecisionTree' ):
84- self ._shap_models .extend ([m + 'Classifier' , m + 'Regressor' ])
85104
86105 # Private
87106 def _load (self ):
@@ -91,40 +110,37 @@ def _load(self):
91110 self ._metadata = loaded ['metadata' ]
92111 self ._is_ready = True
93112 # Hydrate class
94- cls = self ._get_predictor ()
113+ clf = self ._get_predictor ()
95114 # SHAP
96- is_shap_model = type (cls ).__name__ in self . _shap_models
97- self ._is_explainable = SHAP_AVAILABLE and is_shap_model
115+ model_name = type (clf ).__name__
116+ self ._is_explainable = SHAP_AVAILABLE and ( model_name in self . _explainable_models )
98117 # Feature importances
99- if hasattr (cls , 'feature_importances_' ):
100- importance = cls .feature_importances_
118+ if hasattr (clf , 'feature_importances_' ):
119+ importance = clf .feature_importances_
101120 for imp , feat in zip (importance , loaded ['metadata' ]['features' ]):
102121 feat ['importance' ] = imp
103122 # Set model task type
104- if not hasattr (cls , 'classes_' ):
123+ if not hasattr (clf , 'classes_' ):
105124 self ._task_type = Task ('REGRESSION' )
106- elif len (cls .classes_ ) <= 2 :
125+ elif len (clf .classes_ ) <= 2 :
107126 self ._task_type = Task ('BINARY_CLASSIFICATION' )
108- elif len (cls .classes_ ) > 2 :
127+ elif len (clf .classes_ ) > 2 :
109128 self ._task_type = Task ('MULTILABEL_CLASSIFICATION' )
110129
111- @_check_if_model_is_ready
130+ @_check_readiness
112131 def _get_predictor (self ):
113- model_name = type (self ._model ).__name__
114- if model_name == 'Pipeline' :
115- return self ._model .steps [- 1 ][1 ]
116- else :
117- return self ._model
132+ return Model ._extract_base_predictor (self ._model )
118133
119- @_check_if_model_is_ready
134+ @_check_readiness
135+ @_check_task ('classification' )
120136 def _get_class_names (self ):
121137 return np .array (self ._get_predictor ().classes_ , str )
122138
123- @_check_if_model_is_ready
139+ @_check_readiness
124140 def _feature_names (self ):
125141 return [variable ['name' ] for variable in self .features ()]
126142
127- @_check_if_model_is_ready
143+ @_check_readiness
128144 def _validate (self , input ):
129145 if self .metadata .get ('features' ) is None :
130146 raise AttributeError ("Missing key 'features' in model's metadata" )
@@ -164,22 +180,22 @@ def _validate(self, input):
164180 return df
165181
166182 @property
167- @_check_if_model_is_ready
183+ @_check_readiness
168184 def _is_classification (self ):
169185 return self ._task_type >= Task ('CLASSIFICATION' )
170186
171187 @property
172- @_check_if_model_is_ready
188+ @_check_readiness
173189 def _is_binary_classification (self ):
174190 return self ._task_type == Task ('BINARY_CLASSIFICATION' )
175191
176192 @property
177- @_check_if_model_is_ready
193+ @_check_readiness
178194 def _is_multilabel_classification (self ):
179195 return self ._task_type == Task ('MULTILABEL_CLASSIFICATION' )
180196
181197 @property
182- @_check_if_model_is_ready
198+ @_check_readiness
183199 def _is_regression (self ):
184200 return self ._task_type == Task ('REGRESSION' )
185201
@@ -193,6 +209,16 @@ def _is_listlike(data):
193209 data = [data ]
194210 return is_input_listlike , data
195211
212+ @staticmethod
213+ def _extract_base_predictor (model ):
214+ model_name = type (model ).__name__
215+ if model_name == 'Pipeline' :
216+ return Model ._extract_base_predictor (model .steps [- 1 ][1 ])
217+ elif 'CalibratedClassifier' in model_name :
218+ return Model ._extract_base_predictor (model .base_estimator )
219+ else :
220+ return model
221+
196222 # Public
197223 def load (self ):
198224 """Launch model loading in a separated thread
@@ -218,7 +244,7 @@ def is_ready(self):
218244 return self ._is_ready
219245
220246 @property
221- @_check_if_model_is_ready
247+ @_check_readiness
222248 def metadata (self ):
223249 """Get metadata of the model_name.
224250
@@ -232,7 +258,7 @@ def metadata(self):
232258 """
233259 return self ._metadata
234260
235- @_check_if_model_is_ready
261+ @_check_readiness
236262 def task_type (self , as_text = False ):
237263 """Get task type of the model
238264
@@ -251,7 +277,7 @@ def task_type(self, as_text=False):
251277 """
252278 return self ._task_type .name if as_text else self ._task_type
253279
254- @_check_if_model_is_ready
280+ @_check_readiness
255281 def features (self ):
256282 """Get the features of the model
257283
@@ -270,7 +296,7 @@ def features(self):
270296 return deepcopy (self .metadata ['features' ])
271297
272298 @property
273- @_check_if_model_is_ready
299+ @_check_readiness
274300 def info (self ):
275301 """Get model information.
276302
@@ -316,7 +342,7 @@ def info(self):
316342 result ['model' ]['class_names' ] = self ._get_class_names ()
317343 return result
318344
319- @_check_if_model_is_ready
345+ @_check_readiness
320346 def preprocess (self , input ):
321347 """Preprocess data
322348
@@ -340,7 +366,7 @@ def preprocess(self, input):
340366 else :
341367 return input
342368
343- @_check_if_model_is_ready
369+ @_check_readiness
344370 def predict (self , features ):
345371 """Make a prediciton
346372
@@ -365,7 +391,8 @@ def predict(self, features):
365391 result = self ._model .predict (input )
366392 return result
367393
368- @_check_if_model_is_ready
394+ @_check_readiness
395+ @_check_task ('classification' )
369396 def predict_proba (self , features ):
370397 """Make a prediciton
371398
@@ -383,17 +410,14 @@ def predict_proba(self, features):
383410 dict: Predicted class probabilities.
384411
385412 Raises:
386- RuntimeError: If the model is not ready.
413+ RuntimeError: If the model isn't ready or the task isn't classification .
387414 """
388- # Test for model task
389- if self ._is_regression :
390- raise ValueError ("Can't predict probabilities of regression model" )
391415 input = self ._validate (features )
392416 prediction = self ._model .predict_proba (input )
393417 df = pd .DataFrame (prediction , columns = self ._get_class_names ())
394418 return df .to_dict (orient = 'records' )
395419
396- @_check_if_model_is_ready
420+ @_check_readiness
397421 def explain (self , features , samples = None ):
398422 """Explain the prediction of a model.
399423
0 commit comments