Skip to content

Commit d57ab1d

Browse files
committed
* ModelFactory with model family
* Generalize json encoder for pandas * Join decorators into one
1 parent 1acaab9 commit d57ab1d

File tree

5 files changed

+66
-65
lines changed

5 files changed

+66
-65
lines changed

python/factory.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44

55
class ModelFactory(object):
6+
available_models = (SklearnModel, )
67

78
@classmethod
89
def create_model(cls, model_name, model_type='SKLEARN_MODEL'):
@@ -17,5 +18,6 @@ def create_model(cls, model_name, model_type='SKLEARN_MODEL'):
1718
raise RuntimeError("Model {} not found".format(model_path))
1819
else:
1920
# Model found! now create an instance
20-
if model_type == 'SKLEARN_MODEL':
21-
return SklearnModel(model_path)
21+
for model_class in cls.available_models:
22+
if model_class.family == model_type:
23+
return model_class(model_path)

python/model/base.py

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,25 @@
1717
SHAP_AVAILABLE = True
1818

1919

20-
def _check_readiness(func):
21-
@wraps(func)
22-
def wrapper(*args, **kwargs):
23-
self = args[0]
24-
if self.is_ready():
25-
return func(*args, **kwargs)
26-
else:
27-
raise RuntimeError('Model is not ready yet.')
28-
return wrapper
29-
30-
31-
def _check_task(task):
20+
def _check(ready=True, explainable=False, task=None):
3221
def actual_decorator(func):
3322
@wraps(func)
3423
def wrapper(*args, **kwargs):
35-
self_task = args[0].task_type()
36-
strict = task.upper() != 'CLASSIFICATION'
37-
target_task = Task(task)
38-
if (strict and (self_task == target_task)) or \
39-
(not strict and (self_task >= target_task)):
40-
return func(*args, **kwargs)
41-
else:
42-
raise RuntimeError('This method is not available for {} tasks'.format(self_task.name.lower()))
24+
self = args[0]
25+
# Check rediness
26+
if ready and not self.is_ready():
27+
raise RuntimeError('Model is not ready yet.')
28+
# Check explainable
29+
if explainable and not self._is_explainable:
30+
model_name = type(self._model).__name__
31+
raise ValueError('Model not supported for explanations: {}'.format(model_name))
32+
# Check for task
33+
if task is not None:
34+
self_task = self.task_type()
35+
if not getattr(self_task, '__ge__' if task.upper() == 'CLASSIFICATION' else '__eq__')(Task(task)):
36+
raise RuntimeError('This method is not available for {} tasks'.format(self_task.name.lower()))
37+
# Execute function
38+
return func(*args, **kwargs)
4339
return wrapper
4440
return actual_decorator
4541

@@ -67,6 +63,7 @@ def __repr__(self):
6763

6864
class BaseModel(object):
6965
"""Base Class that handles the loaded model."""
66+
family = ''
7067
# Explainable models
7168
_explainable_models = tuple()
7269

@@ -76,34 +73,33 @@ def __init__(self, file_name):
7673
self._model = None
7774
self._metadata = None
7875
self._task_type = None
76+
self._is_explainable = False
7977

8078
# Abstract
8179
def _load(self):
8280
raise NotImplementedError()
8381

84-
@_check_readiness
82+
@_check()
8583
def _get_predictor(self):
8684
raise NotImplementedError()
8785

88-
@_check_readiness
89-
@_check_task('classification')
86+
@_check(task='classification')
9087
def _get_class_names(self):
9188
raise NotImplementedError()
9289

93-
@_check_readiness
94-
def preprocess(self, input):
90+
@_check()
91+
def preprocess(self, features):
9592
raise NotImplementedError()
9693

97-
@_check_readiness
94+
@_check()
9895
def predict(self, features):
9996
raise NotImplementedError()
10097

101-
@_check_readiness
102-
@_check_task('classification')
98+
@_check(task='classification')
10399
def predict_proba(self, features):
104100
raise NotImplementedError()
105101

106-
@_check_readiness
102+
@_check(explainable=True)
107103
def explain(self, features, samples=None):
108104
raise NotImplementedError()
109105

@@ -131,11 +127,11 @@ def _hydrate(self, model, metadata):
131127
elif len(clf.classes_) > 2:
132128
self._task_type = Task('MULTILABEL_CLASSIFICATION')
133129

134-
@_check_readiness
130+
@_check()
135131
def _feature_names(self):
136132
return [variable['name'] for variable in self.features()]
137133

138-
@_check_readiness
134+
@_check()
139135
def _validate(self, input):
140136
if self.metadata.get('features') is None:
141137
raise AttributeError("Missing key 'features' in model's metadata")
@@ -175,22 +171,22 @@ def _validate(self, input):
175171
return df
176172

177173
@property
178-
@_check_readiness
174+
@_check()
179175
def _is_classification(self):
180176
return self._task_type >= Task('CLASSIFICATION')
181177

182178
@property
183-
@_check_readiness
179+
@_check()
184180
def _is_binary_classification(self):
185181
return self._task_type == Task('BINARY_CLASSIFICATION')
186182

187183
@property
188-
@_check_readiness
184+
@_check()
189185
def _is_multilabel_classification(self):
190186
return self._task_type == Task('MULTILABEL_CLASSIFICATION')
191187

192188
@property
193-
@_check_readiness
189+
@_check()
194190
def _is_regression(self):
195191
return self._task_type == Task('REGRESSION')
196192

@@ -229,7 +225,7 @@ def is_ready(self):
229225
return self._is_ready
230226

231227
@property
232-
@_check_readiness
228+
@_check()
233229
def metadata(self):
234230
"""Get metadata of the model_name.
235231
@@ -243,7 +239,7 @@ def metadata(self):
243239
"""
244240
return self._metadata
245241

246-
@_check_readiness
242+
@_check()
247243
def task_type(self, as_text=False):
248244
"""Get task type of the model
249245
@@ -262,7 +258,7 @@ def task_type(self, as_text=False):
262258
"""
263259
return self._task_type.name if as_text else self._task_type
264260

265-
@_check_readiness
261+
@_check()
266262
def features(self):
267263
"""Get the features of the model
268264
@@ -281,7 +277,7 @@ def features(self):
281277
return deepcopy(self.metadata['features'])
282278

283279
@property
284-
@_check_readiness
280+
@_check()
285281
def info(self):
286282
"""Get model information.
287283
@@ -321,7 +317,8 @@ def info(self):
321317
'type': str(type(self._model)),
322318
'predictor_type': str(type(self._get_predictor())),
323319
'is_explainable': self._is_explainable,
324-
'task': self.task_type(as_text=True)
320+
'task': self.task_type(as_text=True),
321+
'family': self.family
325322
}
326323
if self._is_classification:
327324
result['model']['class_names'] = self._get_class_names()

python/model/sklearn.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
import pandas as pd
44

5-
from .base import BaseModel, Task, _check_task, _check_readiness
5+
from .base import BaseModel, Task, _check
66

77
try:
88
import shap
@@ -27,6 +27,8 @@ class SklearnModel(BaseModel):
2727
File path of the serialized model. It must be a file that can be
2828
loaded using :mod:`joblib`
2929
"""
30+
family = 'SKLEARN_MODEL'
31+
3032
# Explainable models
3133
_explainable_models = (
3234
# Sklearn
@@ -45,12 +47,11 @@ def _load(self):
4547
loaded = joblib.load(self._file_name)
4648
self._hydrate(loaded['model'], loaded['metadata'])
4749

48-
@_check_readiness
50+
@_check()
4951
def _get_predictor(self):
5052
return SklearnModel._extract_base_predictor(self._model)
5153

52-
@_check_readiness
53-
@_check_task('classification')
54+
@_check(task='classification')
5455
def _get_class_names(self):
5556
return np.array(self._get_predictor().classes_, str)
5657

@@ -65,14 +66,15 @@ def _extract_base_predictor(model):
6566
else:
6667
return model
6768

68-
@_check_readiness
69-
def preprocess(self, input):
69+
# Public
70+
@_check()
71+
def preprocess(self, features):
7072
"""Preprocess data
7173
7274
This function is used before prediction or interpretation.
7375
7476
Args:
75-
input (dict):
77+
features (dict):
7678
The expected object must contain one key per feature.
7779
Example: `{'feature1': 5, 'feature2': 'A', 'feature3': 10}`
7880
@@ -84,12 +86,13 @@ def preprocess(self, input):
8486
Raises:
8587
RuntimeError: If the model is not ready.
8688
"""
89+
input = self._validate(features)
8790
if hasattr(self._model, 'transform'):
8891
return self._model.transform(input)
8992
else:
9093
return input
9194

92-
@_check_readiness
95+
@_check()
9396
def predict(self, features):
9497
"""Make a prediction
9598
@@ -114,8 +117,7 @@ def predict(self, features):
114117
result = self._model.predict(input)
115118
return result
116119

117-
@_check_readiness
118-
@_check_task('classification')
120+
@_check(task='classification')
119121
def predict_proba(self, features):
120122
"""Make a prediction
121123
@@ -140,7 +142,7 @@ def predict_proba(self, features):
140142
df = pd.DataFrame(prediction, columns=self._get_class_names())
141143
return df.to_dict(orient='records')
142144

143-
@_check_readiness
145+
@_check(explainable=True)
144146
def explain(self, features, samples=None):
145147
"""Explain the prediction of a model.
146148
@@ -173,13 +175,8 @@ def explain(self, features, samples=None):
173175
explanations or the model is not already loaded.
174176
Or if the explainer outputs an unknown object
175177
"""
176-
if not self._is_explainable:
177-
model_name = type(self._model).__name__
178-
msg = 'Model not supported for explanations: {}'.format(model_name)
179-
raise ValueError(msg)
180178
# Process input
181-
input = self._validate(features)
182-
preprocessed = self.preprocess(input)
179+
preprocessed = self.preprocess(features)
183180
# Define parameters
184181
if samples is None:
185182
params = {

python/utils/encoder.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
11
import flask
22
import json
33
import numpy as np
4+
import pandas as pd
45

56
from functools import wraps
67

78

8-
class NumpyEncoder(flask.json.JSONEncoder):
9-
"""Encoder of numpy primitives into JSON strings"""
9+
class ExtendedEncoder(flask.json.JSONEncoder):
    """Encoder of numpy primitives and Pandas objects into JSON strings.

    Conversions applied by :meth:`default`:

    * numpy flexible scalars: ``np.void`` becomes ``None``, any other
      flexible scalar becomes ``obj.tolist()``.
    * numpy arrays, integers and inexact numbers: ``obj.tolist()``.
    * ``pandas.DataFrame``: a list with one dict per row.
    * ``pandas.Series``: converted to a one-column frame, then encoded like
      a ``DataFrame``.
    """
    # Numpy types that are encoded through their ``tolist()`` output
    primitives = (np.ndarray, np.integer, np.inexact)

    def default(self, obj):
        """Return a JSON-serializable replacement for ``obj``.

        Args:
            obj: Object the JSON encoder could not serialize natively.

        Returns:
            The converted value (see the class docstring for the mapping).

        Raises:
            TypeError: From ``json.JSONEncoder.default`` when ``obj`` is not
                one of the supported numpy / Pandas types.
        """
        if isinstance(obj, np.flexible):
            return None if isinstance(obj, np.void) else obj.tolist()
        elif isinstance(obj, self.primitives):
            return obj.tolist()
        elif isinstance(obj, pd.DataFrame):
            return obj.to_dict('records')
        elif isinstance(obj, pd.Series):
            # Route the one-column frame back through this method so the
            # DataFrame branch handles it. Handing the frame straight to
            # ``json.JSONEncoder.default`` always raised TypeError.
            return self.default(obj.to_frame())
        return json.JSONEncoder.default(self, obj)
1823

1924

@@ -26,6 +31,6 @@ def decorated_function(*args, **kwargs):
2631
if isinstance(r, flask.Response):
2732
return r
2833
else:
29-
return flask.Response(json.dumps(r, cls=NumpyEncoder), status=200,
34+
return flask.Response(json.dumps(r, cls=ExtendedEncoder), status=200,
3035
mimetype='application/json; charset=utf-8')
3136
return decorated_function

service.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from time import time
1313

14-
from python.utils.encoder import NumpyEncoder, returns_json
14+
from python.utils.encoder import ExtendedEncoder, returns_json
1515
from python.factory import ModelFactory
1616

1717
# Version of this APP template
@@ -26,7 +26,7 @@
2626
app = flask.Flask(__name__)
2727
# Customize Flask Application
2828
app.logger.setLevel(logging.DEBUG if DEBUG else logging.ERROR)
29-
app.json_encoder = NumpyEncoder
29+
app.json_encoder = ExtendedEncoder
3030
# Create Model instance
3131
model = ModelFactory.create_model(MODEL_NAME, MODEL_TYPE)
3232
# load saved model
@@ -183,7 +183,7 @@ def service_info():
183183
'version-template': __version__,
184184
'running-since': SERVICE_START_TIMESTAMP,
185185
'serving-model-file': MODEL_NAME,
186-
'serving-model-type': MODEL_TYPE,
186+
'serving-model-family': model.family,
187187
'debug': DEBUG}
188188
return info
189189

0 commit comments

Comments
 (0)