Skip to content

Commit 8cc787f

Browse files
committed
Move model wrapper into its own module
1 parent 067adfa commit 8cc787f

File tree

2 files changed

+210
-186
lines changed

2 files changed

+210
-186
lines changed

deepaas/model/v2/__init__.py

Lines changed: 7 additions & 186 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,11 @@
1414
# License for the specific language governing permissions and limitations
1515
# under the License.
1616

17-
from aiohttp import web
18-
import marshmallow
1917
from oslo_log import log
2018

2119
from deepaas.model import loading
2220
from deepaas.model.v2 import test
21+
from deepaas.model.v2 import wrapper
2322

2423
LOG = log.getLogger(__name__)
2524

@@ -37,7 +36,7 @@ def register_models():
3736

3837
try:
3938
for name, model in loading.get_available_models("v2").items():
40-
MODELS[name] = ModelWrapper(name, model)
39+
MODELS[name] = wrapper.ModelWrapper(name, model)
4140
except Exception as e:
4241
LOG.warning("Error loading models: %s", e)
4342

@@ -53,192 +52,14 @@ def register_models():
5352

5453
try:
5554
for name, model in loading.get_available_models("v1").items():
56-
MODELS[name] = ModelWrapper(name, model)
55+
MODELS[name] = wrapper.ModelWrapper(name, model)
5756
except Exception as e:
5857
LOG.warning("Error loading models: %s", e)
5958

6059
if not MODELS:
6160
LOG.info("No models found with V2 or V1 namespace, loading test model")
62-
MODELS["deepaas-test"] = ModelWrapper("deepaas-test", test.TestModel())
61+
MODELS["deepaas-test"] = wrapper.ModelWrapper(
62+
"deepaas-test",
63+
test.TestModel()
64+
)
6365
MODELS_LOADED = True
64-
65-
66-
def catch_error(f):
67-
"""Decorator to catch errors when executing the underlying methods."""
68-
69-
def wrap(*args, **kwargs):
70-
name = args[0].name
71-
try:
72-
return f(*args, **kwargs)
73-
except AttributeError:
74-
raise web.HTTPNotImplemented(
75-
reason=("Not implemented by underlying model (loaded '%s')" %
76-
name)
77-
)
78-
except NotImplementedError:
79-
raise web.HTTPNotImplemented(
80-
reason=("Model '%s' does not implement this functionality" %
81-
name)
82-
)
83-
except Exception as e:
84-
LOG.error("An exception has happened when calling '%s' method on "
85-
"'%s' model." % (f, name))
86-
LOG.exception(e)
87-
if isinstance(e, web.HTTPException):
88-
raise e
89-
else:
90-
raise web.HTTPInternalServerError(reason=e)
91-
return wrap
92-
93-
94-
class ModelWrapper(object):
95-
"""Class that will wrap the loaded models before exposing them.
96-
97-
Whenever a model is loaded it will be wrapped with this class to create a
98-
wrapper object that will handle the calls to the model's methods so as to
99-
handle non-existent method exceptions.
100-
101-
:param name: Model name
102-
:param model: Model object
103-
:raises HTTPInternalServerError: in case that a model has defined
104-
a reponse schema that is nod JSON schema valid (DRAFT 4)
105-
"""
106-
def __init__(self, name, model_obj):
107-
self.name = name
108-
self.model_obj = model_obj
109-
110-
schema = getattr(self.model_obj, "schema", None)
111-
112-
if isinstance(schema, dict):
113-
try:
114-
schema = marshmallow.Schema.from_dict(
115-
schema,
116-
name="ModelPredictionResponse"
117-
)
118-
self.has_schema = True
119-
except Exception as e:
120-
LOG.exception(e)
121-
raise web.HTTPInternalServerError(
122-
reason=("Model defined schema is invalid, "
123-
"check server logs.")
124-
)
125-
elif schema is not None:
126-
try:
127-
if issubclass(schema, marshmallow.Schema):
128-
self.has_schema = True
129-
except TypeError:
130-
raise web.HTTPInternalServerError(
131-
reason=("Model defined schema is invalid, "
132-
"check server logs.")
133-
)
134-
else:
135-
self.has_schema = False
136-
137-
self.response_schema = schema
138-
139-
def validate_response(self, response):
140-
"""Validate a response against the model's response schema, if set.
141-
142-
If the wrapped model has defined a ``response`` attribute we will
143-
validate the response that
144-
145-
:param response: The reponse that will be validated.
146-
:raises exceptions.InternalServerError: in case the reponse cannot be
147-
validated.
148-
"""
149-
if self.has_schema is not True:
150-
raise web.HTTPInternalServerError(
151-
reason=("Trying to validate against a schema, but I do not "
152-
"have one defined")
153-
)
154-
155-
try:
156-
self.response_schema().load(response)
157-
except marshmallow.ValidationError as e:
158-
LOG.exception(e)
159-
raise web.HTTPInternalServerError(
160-
reason="ERROR validating model response, check server logs."
161-
)
162-
except Exception as e:
163-
LOG.exception(e)
164-
raise web.HTTPInternalServerError(
165-
reason="Unknown ERROR validating response, check server logs."
166-
)
167-
168-
return True
169-
170-
def get_metadata(self):
171-
"""Obtain model's metadata.
172-
173-
If the model's metadata cannot be obtained because it is not
174-
implemented, we will provide some generic information so that the
175-
call does not fail.
176-
177-
:returns dict: dictionary containing model's metadata
178-
"""
179-
try:
180-
d = self.model_obj.get_metadata()
181-
except (NotImplementedError, AttributeError):
182-
d = {
183-
"id": "0",
184-
"name": self.name,
185-
"description": ("Could not load description from "
186-
"underlying model (loaded '%s')" % self.name),
187-
}
188-
return d
189-
190-
@catch_error
191-
def predict(self, **kwargs):
192-
"""Perform a prediction on wrapped model's ``predict`` method.
193-
194-
:raises HTTPNotImplemented: If the method is not
195-
implemented in the wrapper model.
196-
:raises HTTPInternalServerError: If the call produces
197-
an error
198-
:raises HTTPException: If the call produces an
199-
error, already wrapped as a HTTPException
200-
"""
201-
return self.model_obj.predict(**kwargs)
202-
203-
@catch_error
204-
def train(self, *args, **kwargs):
205-
"""Perform a training on wrapped model's ``train`` method.
206-
207-
:raises HTTPNotImplemented: If the method is not
208-
implemented in the wrapper model.
209-
:raises HTTPInternalServerError: If the call produces
210-
an error
211-
:raises HTTPException: If the call produces an
212-
error, already wrapped as a HTTPException
213-
"""
214-
return self.model_obj.train(*args, **kwargs)
215-
216-
@catch_error
217-
def get_train_args(self):
218-
"""Add training arguments into the training parser.
219-
220-
:param parser: an argparse like object
221-
222-
This method will call the wrapped model ``add_training_args``. If the
223-
method does not exist, but the wrapped model implements the DEPRECATED
224-
``get_train_args`` we will try to load the arguments from there.
225-
"""
226-
try:
227-
return self.model_obj.get_train_args()
228-
except (NotImplementedError, AttributeError):
229-
return {}
230-
231-
@catch_error
232-
def get_predict_args(self):
233-
"""Add predict arguments into the predict parser.
234-
235-
:param parser: an argparse like object
236-
237-
This method will call the wrapped model ``add_predict_args``. If the
238-
method does not exist, but the wrapped model implements the DEPRECATED
239-
``get_predict_args`` we will try to load the arguments from there.
240-
"""
241-
try:
242-
return self.model_obj.get_predict_args()
243-
except (NotImplementedError, AttributeError):
244-
return {}

0 commit comments

Comments
 (0)