16
16
17
17
import asyncio
18
18
import collections
19
- import concurrent .futures
20
19
import contextlib
21
20
import datetime
22
21
import functools
22
+ import io
23
23
import multiprocessing
24
24
import multiprocessing .pool
25
25
import os
50
50
51
51
.. py:attribute:: filename
52
52
53
- Complete file path to the temporary file in the filesyste ,
53
+ Complete file path to the temporary file in the filesystem ,
54
54
55
55
.. py:attribute:: content_type
56
56
61
61
Filename of the original file being uploaded.
62
62
"""
63
63
64
ReturnedFile = collections.namedtuple(
    "ReturnedFile",
    ["name", "filename", "content_type", "original_filename"],
)
"""Class to pass the files returned from predict in a picklable way.

.. py:attribute:: name

    Name of the argument where this file is being sent.

.. py:attribute:: filename

    Complete file path to the temporary file in the filesystem.

.. py:attribute:: content_type

    Content-type of the uploaded file.

.. py:attribute:: original_filename

    Filename of the original file being uploaded.
"""
86
+
87
+
64
88
# set defaults to None, mainly for compatibility (vkoz)
65
89
UploadedFile .__new__ .__defaults__ = (None , None , None , None )
90
+ ReturnedFile .__new__ .__defaults__ = (None , None , None , None )
66
91
67
92
68
93
class ModelWrapper (object ):
@@ -75,7 +100,7 @@ class ModelWrapper(object):
75
100
:param name: Model name
76
101
:param model: Model object
77
102
:raises HTTPInternalServerError: in case that a model has defined
78
- a reponse schema that is nod JSON schema valid (DRAFT 4)
103
+ a response schema that is not JSON schema valid (DRAFT 4)
79
104
"""
80
105
def __init__ (self , name , model_obj , app ):
81
106
self .name = name
@@ -84,11 +109,8 @@ def __init__(self, name, model_obj, app):
84
109
85
110
self ._loop = asyncio .get_event_loop ()
86
111
87
- self ._predict_workers = CONF .predict_workers
88
- self ._predict_executor = self ._init_predict_executor ()
89
-
90
- self ._train_workers = CONF .train_workers
91
- self ._train_executor = self ._init_train_executor ()
112
+ self ._workers = CONF .workers
113
+ self ._executor = self ._init_executor ()
92
114
93
115
self ._setup_cleanup ()
94
116
@@ -125,16 +147,10 @@ def _setup_cleanup(self):
125
147
self ._app .on_cleanup .append (self ._close_executors )
126
148
127
149
async def _close_executors(self, app):
    """aiohttp ``on_cleanup`` handler: shut down the shared worker pool.

    :param app: the aiohttp application being cleaned up (unused here,
        but required by the ``on_cleanup`` callback signature)
    """
    self._executor.shutdown()
135
151
136
def _init_executor(self):
    """Create the cancellable worker pool used for predict and train.

    Pool size comes from the ``workers`` configuration option captured
    in ``self._workers``.
    """
    return CancellablePool(max_workers=self._workers)
140
156
@@ -168,7 +184,7 @@ def validate_response(self, response):
168
184
If the wrapped model has defined a ``response`` attribute we will
169
185
validate the response that
170
186
171
- :param response: The reponse that will be validated.
187
+ :param response: The response that will be validated.
172
188
:raises exceptions.InternalServerError: in case the response cannot be
173
189
validated.
174
190
"""
@@ -213,18 +229,10 @@ def get_metadata(self):
213
229
}
214
230
return d
215
231
216
- def _run_in_predict_pool (self , func , * args , ** kwargs ):
217
- async def task (fn ):
218
- return await self ._loop .run_in_executor (self ._predict_executor , fn )
219
-
220
- return self ._loop .create_task (
221
- task (functools .partial (func , * args , ** kwargs ))
222
- )
223
-
224
- def _run_in_train_pool (self , func , * args , ** kwargs ):
232
+ def _run_in_pool (self , func , * args , ** kwargs ):
225
233
fn = functools .partial (func , * args , ** kwargs )
226
234
ret = self ._loop .create_task (
227
- self ._train_executor .apply (fn )
235
+ self ._executor .apply (fn )
228
236
)
229
237
return ret
230
238
@@ -243,17 +251,27 @@ async def warm(self):
243
251
LOG .debug ("Cannot warm (initialize) model '%s'" % self .name )
244
252
return
245
253
246
- run = self ._loop .run_in_executor
247
- executor = self ._predict_executor
248
- n = self ._predict_workers
249
254
try :
255
+ n = self ._workers
250
256
LOG .debug ("Warming '%s' model with %s workers" % (self .name , n ))
251
- fs = [run ( executor , func ) for i in range (0 , n )]
257
+ fs = [self . _run_in_pool ( func ) for _ in range (0 , n )]
252
258
await asyncio .gather (* fs )
253
259
LOG .debug ("Model '%s' has been warmed" % self .name )
254
260
except NotImplementedError :
255
261
LOG .debug ("Cannot warm (initialize) model '%s'" % self .name )
256
262
263
+ @staticmethod
264
+ def predict_wrap (predict_func , * args , ** kwargs ):
265
+ """Wrapper function to allow returning files from predict
266
+ This wrapper exists because buffer objects are not pickable,
267
+ thus cannot be returned from the executor.
268
+ """
269
+ ret = predict_func (* args , ** kwargs )
270
+ if isinstance (ret , io .BufferedReader ):
271
+ ret = ReturnedFile (filename = ret .name )
272
+
273
+ return ret
274
+
257
275
def predict (self , * args , ** kwargs ):
258
276
"""Perform a prediction on wrapped model's ``predict`` method.
259
277
@@ -280,8 +298,8 @@ def predict(self, *args, **kwargs):
280
298
# FIXME(aloga); cleanup of tmpfile here
281
299
282
300
with self ._catch_error ():
283
- return self ._run_in_predict_pool (
284
- self .model_obj .predict , * args , ** kwargs
301
+ return self ._run_in_pool (
302
+ self .predict_wrap , self . model_obj .predict , * args , ** kwargs
285
303
)
286
304
287
305
def train (self , * args , ** kwargs ):
@@ -296,7 +314,7 @@ def train(self, *args, **kwargs):
296
314
"""
297
315
298
316
with self ._catch_error ():
299
- return self ._run_in_train_pool (
317
+ return self ._run_in_pool (
300
318
self .model_obj .train , * args , ** kwargs
301
319
)
302
320
0 commit comments