1717 SHAP_AVAILABLE = True
1818
1919
20- def _check_readiness (func ):
21- @wraps (func )
22- def wrapper (* args , ** kwargs ):
23- self = args [0 ]
24- if self .is_ready ():
25- return func (* args , ** kwargs )
26- else :
27- raise RuntimeError ('Model is not ready yet.' )
28- return wrapper
29-
30-
31- def _check_task (task ):
20+ def _check (ready = True , explainable = False , task = None ):
3221 def actual_decorator (func ):
3322 @wraps (func )
3423 def wrapper (* args , ** kwargs ):
35- self_task = args [0 ].task_type ()
36- strict = task .upper () != 'CLASSIFICATION'
37- target_task = Task (task )
38- if (strict and (self_task == target_task )) or \
39- (not strict and (self_task >= target_task )):
40- return func (* args , ** kwargs )
41- else :
42- raise RuntimeError ('This method is not available for {} tasks' .format (self_task .name .lower ()))
24+ self = args [0 ]
25+ # Check rediness
26+ if ready and not self .is_ready ():
27+ raise RuntimeError ('Model is not ready yet.' )
28+ # Check explainable
29+ if explainable and not self ._is_explainable :
30+ model_name = type (self ._model ).__name__
31+ raise ValueError ('Model not supported for explanations: {}' .format (model_name ))
32+ # Check for task
33+ if task is not None :
34+ self_task = self .task_type ()
35+ if not getattr (self_task , '__ge__' if task .upper () == 'CLASSIFICATION' else '__eq__' )(Task (task )):
36+ raise RuntimeError ('This method is not available for {} tasks' .format (self_task .name .lower ()))
37+ # Execute function
38+ return func (* args , ** kwargs )
4339 return wrapper
4440 return actual_decorator
4541
@@ -67,6 +63,7 @@ def __repr__(self):
6763
6864class BaseModel (object ):
6965 """Base Class that handles the loaded model."""
66+ family = ''
7067 # Explainable models
7168 _explainable_models = tuple ()
7269
@@ -76,34 +73,33 @@ def __init__(self, file_name):
7673 self ._model = None
7774 self ._metadata = None
7875 self ._task_type = None
76+ self ._is_explainable = False
7977
8078 # Abstract
8179 def _load (self ):
8280 raise NotImplementedError ()
8381
84- @_check_readiness
82+ @_check ()
8583 def _get_predictor (self ):
8684 raise NotImplementedError ()
8785
88- @_check_readiness
89- @_check_task ('classification' )
86+ @_check (task = 'classification' )
9087 def _get_class_names (self ):
9188 raise NotImplementedError ()
9289
93- @_check_readiness
94- def preprocess (self , input ):
90+ @_check ()
91+ def preprocess (self , features ):
9592 raise NotImplementedError ()
9693
97- @_check_readiness
94+ @_check ()
9895 def predict (self , features ):
9996 raise NotImplementedError ()
10097
101- @_check_readiness
102- @_check_task ('classification' )
98+ @_check (task = 'classification' )
10399 def predict_proba (self , features ):
104100 raise NotImplementedError ()
105101
106- @_check_readiness
102+ @_check ( explainable = True )
107103 def explain (self , features , samples = None ):
108104 raise NotImplementedError ()
109105
@@ -131,11 +127,11 @@ def _hydrate(self, model, metadata):
131127 elif len (clf .classes_ ) > 2 :
132128 self ._task_type = Task ('MULTILABEL_CLASSIFICATION' )
133129
134- @_check_readiness
130+ @_check ()
135131 def _feature_names (self ):
136132 return [variable ['name' ] for variable in self .features ()]
137133
138- @_check_readiness
134+ @_check ()
139135 def _validate (self , input ):
140136 if self .metadata .get ('features' ) is None :
141137 raise AttributeError ("Missing key 'features' in model's metadata" )
@@ -175,22 +171,22 @@ def _validate(self, input):
175171 return df
176172
177173 @property
178- @_check_readiness
174+ @_check ()
179175 def _is_classification (self ):
180176 return self ._task_type >= Task ('CLASSIFICATION' )
181177
182178 @property
183- @_check_readiness
179+ @_check ()
184180 def _is_binary_classification (self ):
185181 return self ._task_type == Task ('BINARY_CLASSIFICATION' )
186182
187183 @property
188- @_check_readiness
184+ @_check ()
189185 def _is_multilabel_classification (self ):
190186 return self ._task_type == Task ('MULTILABEL_CLASSIFICATION' )
191187
192188 @property
193- @_check_readiness
189+ @_check ()
194190 def _is_regression (self ):
195191 return self ._task_type == Task ('REGRESSION' )
196192
@@ -229,7 +225,7 @@ def is_ready(self):
229225 return self ._is_ready
230226
231227 @property
232- @_check_readiness
228+ @_check ()
233229 def metadata (self ):
234230 """Get metadata of the model_name.
235231
@@ -243,7 +239,7 @@ def metadata(self):
243239 """
244240 return self ._metadata
245241
246- @_check_readiness
242+ @_check ()
247243 def task_type (self , as_text = False ):
248244 """Get task type of the model
249245
@@ -262,7 +258,7 @@ def task_type(self, as_text=False):
262258 """
263259 return self ._task_type .name if as_text else self ._task_type
264260
265- @_check_readiness
261+ @_check ()
266262 def features (self ):
267263 """Get the features of the model
268264
@@ -281,7 +277,7 @@ def features(self):
281277 return deepcopy (self .metadata ['features' ])
282278
283279 @property
284- @_check_readiness
280+ @_check ()
285281 def info (self ):
286282 """Get model information.
287283
@@ -321,7 +317,8 @@ def info(self):
321317 'type' : str (type (self ._model )),
322318 'predictor_type' : str (type (self ._get_predictor ())),
323319 'is_explainable' : self ._is_explainable ,
324- 'task' : self .task_type (as_text = True )
320+ 'task' : self .task_type (as_text = True ),
321+ 'family' : self .family
325322 }
326323 if self ._is_classification :
327324 result ['model' ]['class_names' ] = self ._get_class_names ()
0 commit comments