14
14
# License for the specific language governing permissions and limitations
15
15
# under the License.
16
16
17
- from aiohttp import web
18
- import marshmallow
19
17
from oslo_log import log
20
18
21
19
from deepaas .model import loading
22
20
from deepaas .model .v2 import test
21
+ from deepaas .model .v2 import wrapper
23
22
24
23
LOG = log .getLogger (__name__ )
25
24
@@ -37,7 +36,7 @@ def register_models():
37
36
38
37
try :
39
38
for name , model in loading .get_available_models ("v2" ).items ():
40
- MODELS [name ] = ModelWrapper (name , model )
39
+ MODELS [name ] = wrapper . ModelWrapper (name , model )
41
40
except Exception as e :
42
41
LOG .warning ("Error loading models: %s" , e )
43
42
@@ -53,192 +52,14 @@ def register_models():
53
52
54
53
try :
55
54
for name , model in loading .get_available_models ("v1" ).items ():
56
- MODELS [name ] = ModelWrapper (name , model )
55
+ MODELS [name ] = wrapper . ModelWrapper (name , model )
57
56
except Exception as e :
58
57
LOG .warning ("Error loading models: %s" , e )
59
58
60
59
if not MODELS :
61
60
LOG .info ("No models found with V2 or V1 namespace, loading test model" )
62
- MODELS ["deepaas-test" ] = ModelWrapper ("deepaas-test" , test .TestModel ())
61
+ MODELS ["deepaas-test" ] = wrapper .ModelWrapper (
62
+ "deepaas-test" ,
63
+ test .TestModel ()
64
+ )
63
65
MODELS_LOADED = True
64
-
65
-
66
- def catch_error (f ):
67
- """Decorator to catch errors when executing the underlying methods."""
68
-
69
- def wrap (* args , ** kwargs ):
70
- name = args [0 ].name
71
- try :
72
- return f (* args , ** kwargs )
73
- except AttributeError :
74
- raise web .HTTPNotImplemented (
75
- reason = ("Not implemented by underlying model (loaded '%s')" %
76
- name )
77
- )
78
- except NotImplementedError :
79
- raise web .HTTPNotImplemented (
80
- reason = ("Model '%s' does not implement this functionality" %
81
- name )
82
- )
83
- except Exception as e :
84
- LOG .error ("An exception has happened when calling '%s' method on "
85
- "'%s' model." % (f , name ))
86
- LOG .exception (e )
87
- if isinstance (e , web .HTTPException ):
88
- raise e
89
- else :
90
- raise web .HTTPInternalServerError (reason = e )
91
- return wrap
92
-
93
-
94
- class ModelWrapper (object ):
95
- """Class that will wrap the loaded models before exposing them.
96
-
97
- Whenever a model is loaded it will be wrapped with this class to create a
98
- wrapper object that will handle the calls to the model's methods so as to
99
- handle non-existent method exceptions.
100
-
101
- :param name: Model name
102
- :param model: Model object
103
- :raises HTTPInternalServerError: in case that a model has defined
104
- a reponse schema that is nod JSON schema valid (DRAFT 4)
105
- """
106
- def __init__ (self , name , model_obj ):
107
- self .name = name
108
- self .model_obj = model_obj
109
-
110
- schema = getattr (self .model_obj , "schema" , None )
111
-
112
- if isinstance (schema , dict ):
113
- try :
114
- schema = marshmallow .Schema .from_dict (
115
- schema ,
116
- name = "ModelPredictionResponse"
117
- )
118
- self .has_schema = True
119
- except Exception as e :
120
- LOG .exception (e )
121
- raise web .HTTPInternalServerError (
122
- reason = ("Model defined schema is invalid, "
123
- "check server logs." )
124
- )
125
- elif schema is not None :
126
- try :
127
- if issubclass (schema , marshmallow .Schema ):
128
- self .has_schema = True
129
- except TypeError :
130
- raise web .HTTPInternalServerError (
131
- reason = ("Model defined schema is invalid, "
132
- "check server logs." )
133
- )
134
- else :
135
- self .has_schema = False
136
-
137
- self .response_schema = schema
138
-
139
- def validate_response (self , response ):
140
- """Validate a response against the model's response schema, if set.
141
-
142
- If the wrapped model has defined a ``response`` attribute we will
143
- validate the response that
144
-
145
- :param response: The reponse that will be validated.
146
- :raises exceptions.InternalServerError: in case the reponse cannot be
147
- validated.
148
- """
149
- if self .has_schema is not True :
150
- raise web .HTTPInternalServerError (
151
- reason = ("Trying to validate against a schema, but I do not "
152
- "have one defined" )
153
- )
154
-
155
- try :
156
- self .response_schema ().load (response )
157
- except marshmallow .ValidationError as e :
158
- LOG .exception (e )
159
- raise web .HTTPInternalServerError (
160
- reason = "ERROR validating model response, check server logs."
161
- )
162
- except Exception as e :
163
- LOG .exception (e )
164
- raise web .HTTPInternalServerError (
165
- reason = "Unknown ERROR validating response, check server logs."
166
- )
167
-
168
- return True
169
-
170
- def get_metadata (self ):
171
- """Obtain model's metadata.
172
-
173
- If the model's metadata cannot be obtained because it is not
174
- implemented, we will provide some generic information so that the
175
- call does not fail.
176
-
177
- :returns dict: dictionary containing model's metadata
178
- """
179
- try :
180
- d = self .model_obj .get_metadata ()
181
- except (NotImplementedError , AttributeError ):
182
- d = {
183
- "id" : "0" ,
184
- "name" : self .name ,
185
- "description" : ("Could not load description from "
186
- "underlying model (loaded '%s')" % self .name ),
187
- }
188
- return d
189
-
190
- @catch_error
191
- def predict (self , ** kwargs ):
192
- """Perform a prediction on wrapped model's ``predict`` method.
193
-
194
- :raises HTTPNotImplemented: If the method is not
195
- implemented in the wrapper model.
196
- :raises HTTPInternalServerError: If the call produces
197
- an error
198
- :raises HTTPException: If the call produces an
199
- error, already wrapped as a HTTPException
200
- """
201
- return self .model_obj .predict (** kwargs )
202
-
203
- @catch_error
204
- def train (self , * args , ** kwargs ):
205
- """Perform a training on wrapped model's ``train`` method.
206
-
207
- :raises HTTPNotImplemented: If the method is not
208
- implemented in the wrapper model.
209
- :raises HTTPInternalServerError: If the call produces
210
- an error
211
- :raises HTTPException: If the call produces an
212
- error, already wrapped as a HTTPException
213
- """
214
- return self .model_obj .train (* args , ** kwargs )
215
-
216
- @catch_error
217
- def get_train_args (self ):
218
- """Add training arguments into the training parser.
219
-
220
- :param parser: an argparse like object
221
-
222
- This method will call the wrapped model ``add_training_args``. If the
223
- method does not exist, but the wrapped model implements the DEPRECATED
224
- ``get_train_args`` we will try to load the arguments from there.
225
- """
226
- try :
227
- return self .model_obj .get_train_args ()
228
- except (NotImplementedError , AttributeError ):
229
- return {}
230
-
231
- @catch_error
232
- def get_predict_args (self ):
233
- """Add predict arguments into the predict parser.
234
-
235
- :param parser: an argparse like object
236
-
237
- This method will call the wrapped model ``add_predict_args``. If the
238
- method does not exist, but the wrapped model implements the DEPRECATED
239
- ``get_predict_args`` we will try to load the arguments from there.
240
- """
241
- try :
242
- return self .model_obj .get_predict_args ()
243
- except (NotImplementedError , AttributeError ):
244
- return {}
0 commit comments