@@ -63,6 +63,34 @@ def register_models():
63
63
MODELS_LOADED = True
64
64
65
65
66
+ def catch_error (f ):
67
+ """Decorator to catch errors when executing the underlying methods."""
68
+
69
+ def wrap (* args , ** kwargs ):
70
+ name = args [0 ].name
71
+ try :
72
+ return f (* args , ** kwargs )
73
+ except AttributeError :
74
+ raise web .HTTPNotImplemented (
75
+ reason = ("Not implemented by underlying model (loaded '%s')" %
76
+ name )
77
+ )
78
+ except NotImplementedError :
79
+ raise web .HTTPNotImplemented (
80
+ reason = ("Model '%s' does not implement this functionality" %
81
+ name )
82
+ )
83
+ except Exception as e :
84
+ LOG .error ("An exception has happened when calling '%s' method on "
85
+ "'%s' model." % (f , name ))
86
+ LOG .exception (e )
87
+ if isinstance (e , web .HTTPException ):
88
+ raise e
89
+ else :
90
+ raise web .HTTPInternalServerError (reason = e )
91
+ return wrap
92
+
93
+
66
94
class ModelWrapper (object ):
67
95
"""Class that will wrap the loaded models before exposing them.
68
96
@@ -75,11 +103,11 @@ class ModelWrapper(object):
75
103
:raises HTTPInternalServerError: in case that a model has defined
76
104
a reponse schema that is nod JSON schema valid (DRAFT 4)
77
105
"""
78
- def __init__ (self , name , model ):
106
+ def __init__ (self , name , model_obj ):
79
107
self .name = name
80
- self .model = model
108
+ self .model_obj = model_obj
81
109
82
- schema = getattr (self .model , "schema" , None )
110
+ schema = getattr (self .model_obj , "schema" , None )
83
111
84
112
if isinstance (schema , dict ):
85
113
try :
@@ -139,29 +167,6 @@ def validate_response(self, response):
139
167
140
168
return True
141
169
142
- def _call_method (self , method , * args , ** kwargs ):
143
- try :
144
- meth = getattr (self .model , method )
145
- return meth (* args , ** kwargs )
146
- except AttributeError :
147
- raise web .HTTPNotImplemented (
148
- reason = ("Not implemented by underlying model (loaded '%s')" %
149
- self .name )
150
- )
151
- except NotImplementedError :
152
- raise web .HTTPNotImplemented (
153
- reason = ("Model '%s' does not implement this functionality" %
154
- self .name )
155
- )
156
- except Exception as e :
157
- LOG .error ("An exception has happened when calling '%s' method on "
158
- "'%s' model." % (method , self .name ))
159
- LOG .exception (e )
160
- if isinstance (e , web .HTTPException ):
161
- raise e
162
- else :
163
- raise web .HTTPInternalServerError (reason = e )
164
-
165
170
def get_metadata (self ):
166
171
"""Obtain model's metadata.
167
172
@@ -172,7 +177,7 @@ def get_metadata(self):
172
177
:returns dict: dictionary containing model's metadata
173
178
"""
174
179
try :
175
- d = self .model .get_metadata ()
180
+ d = self .model_obj .get_metadata ()
176
181
except (NotImplementedError , AttributeError ):
177
182
d = {
178
183
"id" : "0" ,
@@ -182,14 +187,7 @@ def get_metadata(self):
182
187
}
183
188
return d
184
189
185
- @property
186
- def response (self ):
187
- """Wrapped model's response schema.
188
-
189
- Check :py:attr:`deepaas.v2.base.BaseModel.response` for more details.
190
- """
191
- return getattr (self .model , "response" , None )
192
-
190
+ @catch_error
193
191
def predict (self , ** kwargs ):
194
192
"""Perform a prediction on wrapped model's ``predict`` method.
195
193
@@ -200,8 +198,9 @@ def predict(self, **kwargs):
200
198
:raises HTTPException: If the call produces an
201
199
error, already wrapped as a HTTPException
202
200
"""
203
- return self ._call_method ( " predict" , ** kwargs )
201
+ return self .model_obj . predict ( ** kwargs )
204
202
203
+ @catch_error
205
204
def train (self , * args , ** kwargs ):
206
205
"""Perform a training on wrapped model's ``train`` method.
207
206
@@ -212,8 +211,9 @@ def train(self, *args, **kwargs):
212
211
:raises HTTPException: If the call produces an
213
212
error, already wrapped as a HTTPException
214
213
"""
215
- return self ._call_method ( " train" , * args , ** kwargs )
214
+ return self .model_obj . train ( * args , ** kwargs )
216
215
216
+ @catch_error
217
217
def get_train_args (self ):
218
218
"""Add training arguments into the training parser.
219
219
@@ -224,10 +224,11 @@ def get_train_args(self):
224
224
``get_train_args`` we will try to load the arguments from there.
225
225
"""
226
226
try :
227
- return self ._call_method ( " get_train_args" )
228
- except web . HTTPNotImplemented :
227
+ return self .model_obj . get_train_args ( )
228
+ except ( NotImplementedError , AttributeError ) :
229
229
return {}
230
230
231
+ @catch_error
231
232
def get_predict_args (self ):
232
233
"""Add predict arguments into the predict parser.
233
234
@@ -238,6 +239,6 @@ def get_predict_args(self):
238
239
``get_predict_args`` we will try to load the arguments from there.
239
240
"""
240
241
try :
241
- return self ._call_method ( " get_predict_args" )
242
- except web . HTTPNotImplemented :
242
+ return self .model_obj . get_predict_args ( )
243
+ except ( NotImplementedError , AttributeError ) :
243
244
return {}
0 commit comments