
Commit 55a9e6d

ignacioalvarolopez authored and committed
Unify train and predict pools to fix GPU OOM
This fixes GPU out-of-memory problems that happened when we had two
different pools (one for predict, one for train). When we ran train and
then predict sequentially (or vice versa), each pool tried to claim the
whole GPU, so out-of-memory errors occurred. This does not fix
out-of-memory errors when running parallel tasks on the GPU (those also
happened before). CPU deployments should not be affected.

Fixes ai4os#87
Sem-Ver: bugfix
1 parent 4dfa03d commit 55a9e6d
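
To make the failure mode concrete, here is a minimal, self-contained sketch (illustration only, not code from this repository; init_worker and task are hypothetical stand-ins for framework initialization and model calls). Each worker process in a pool claims the device once when it starts, so two pools meant two full-GPU claims, while one shared pool claims the device exactly once and serializes train and predict through the same worker:

    import multiprocessing

    GPU_MEMORY_GB = 8
    _state = {}  # per-process stand-in for the GPU allocation


    def init_worker():
        # Stand-in for framework initialization: frameworks such as
        # TensorFlow reserve (nearly) the whole device per process by
        # default.
        _state["gpu_gb"] = GPU_MEMORY_GB


    def task(name):
        return "%s ran with %s GB reserved" % (name, _state["gpu_gb"])


    if __name__ == "__main__":
        # Before this commit: one pool for predict and one for train, i.e.
        # two worker processes, each trying to reserve the full GPU.
        # After: a single pool, so the device is claimed exactly once and
        # sequential train/predict calls reuse the same worker.
        pool = multiprocessing.Pool(1, initializer=init_worker)
        print(pool.apply(task, ("train",)))
        print(pool.apply(task, ("predict",)))
        pool.close()
        pool.join()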

File tree

5 files changed: +67, -47 lines

deepaas/api/v2/predict.py

Lines changed: 4 additions & 1 deletion

@@ -72,7 +72,10 @@ async def post(self, request, wsk_args=None):
         task = self.model_obj.predict(**args)
         await task

-        ret = task.result()
+        ret = task.result()['output']
+
+        if isinstance(ret, model.v2.wrapper.ReturnedFile):
+            ret = open(ret.filename, 'rb')

         accept = args.get("accept", "application/json")
         if accept != "application/json":
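
For context, a hedged sketch of a model whose predict() returns a file, exercising the path added above (ExampleModel is hypothetical, not part of this commit). The open buffer never crosses the process boundary: predict_wrap in deepaas/model/v2/wrapper.py converts it into a picklable ReturnedFile carrying only the path, which the handler above then reopens:

    import tempfile


    class ExampleModel(object):
        """Hypothetical model whose predict() returns binary file output."""

        def predict(self, **kwargs):
            # Write the result to a temporary file and return an open
            # buffer; the wrapper replaces it with a
            # ReturnedFile(filename=tmp.name) so only the path is pickled
            # back to the parent process.
            tmp = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
            tmp.write(b"...binary prediction output...")
            tmp.close()
            return open(tmp.name, "rb")  # an io.BufferedReader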

deepaas/config.py

Lines changed: 4 additions & 10 deletions

@@ -42,19 +42,13 @@
 "/debug" endpoint. Default is to not provide this information. This will not
 provide logging information about the API itself.
 """),
-    cfg.IntOpt('predict-workers',
+    cfg.IntOpt('workers',
                short='p',
                default=1,
                help="""
-Specify the number of workers to spawn for prediction tasks. If using a CPU you
-probably want to increase this number, if using a GPU probably you want to
-leave it to 1. (defaults to 1)
-"""),
-    cfg.IntOpt('train-workers',
-               default=1,
-               help="""
-Specify the number of workers to spawn for training tasks. Unless you know what
-you are doing you should leave this number to 1. (defaults to 1)
+Specify the number of workers to spawn. If using a CPU you probably want to
+increase this number, if using a GPU probably you want to leave it to 1.
+(defaults to 1)
 """),
     cfg.IntOpt('client-max-size',
                default=0,
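
A minimal sketch of how the unified option parses with oslo.config (standalone registration shown purely for illustration; in DEEPaaS the option is registered by the module above):

    from oslo_config import cfg

    opts = [
        cfg.IntOpt('workers', short='p', default=1,
                   help="Number of workers to spawn."),
    ]

    conf = cfg.ConfigOpts()
    conf.register_cli_opts(opts)

    # A CPU deployment passing --workers 4 (or -p 4) on the command line:
    conf(args=['--workers', '4'])
    print(conf.workers)  # -> 4; without the flag the default of 1 applies,
                         # which is the sensible setting on a GPU.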

deepaas/model/v2/wrapper.py

Lines changed: 53 additions & 35 deletions

@@ -16,10 +16,10 @@

 import asyncio
 import collections
-import concurrent.futures
 import contextlib
 import datetime
 import functools
+import io
 import multiprocessing
 import multiprocessing.pool
 import os

@@ -50,7 +50,7 @@

 .. py:attribute:: filename

-   Complete file path to the temporary file in the filesyste,
+   Complete file path to the temporary file in the filesystem,

 .. py:attribute:: content_type

@@ -61,8 +61,33 @@

    Filename of the original file being uploaded.
 """

+ReturnedFile = collections.namedtuple("ReturnedFile", ("name",
+                                                       "filename",
+                                                       "content_type",
+                                                       "original_filename"))
+"""Class to pass the files returned from predict in a pickable way
+
+.. py:attribute:: name
+
+   Name of the argument where this file is being sent.
+
+.. py:attribute:: filename
+
+   Complete file path to the temporary file in the filesystem,
+
+.. py:attribute:: content_type
+
+   Content-type of the uploaded file
+
+.. py:attribute:: original_filename
+
+   Filename of the original file being uploaded.
+"""
+
+
 # set defaults to None, mainly for compatibility (vkoz)
 UploadedFile.__new__.__defaults__ = (None, None, None, None)
+ReturnedFile.__new__.__defaults__ = (None, None, None, None)


 class ModelWrapper(object):

@@ -75,7 +100,7 @@ class ModelWrapper(object):
     :param name: Model name
     :param model: Model object
     :raises HTTPInternalServerError: in case that a model has defined
-        a reponse schema that is nod JSON schema valid (DRAFT 4)
+        a response schema that is not JSON schema valid (DRAFT 4)
     """
     def __init__(self, name, model_obj, app):
         self.name = name

@@ -84,11 +109,8 @@ def __init__(self, name, model_obj, app):

         self._loop = asyncio.get_event_loop()

-        self._predict_workers = CONF.predict_workers
-        self._predict_executor = self._init_predict_executor()
-
-        self._train_workers = CONF.train_workers
-        self._train_executor = self._init_train_executor()
+        self._workers = CONF.workers
+        self._executor = self._init_executor()

         self._setup_cleanup()

@@ -125,16 +147,10 @@ def _setup_cleanup(self):
         self._app.on_cleanup.append(self._close_executors)

     async def _close_executors(self, app):
-        self._train_executor.shutdown()
-        self._predict_executor.shutdown()
-
-    def _init_predict_executor(self):
-        n = self._predict_workers
-        executor = concurrent.futures.ThreadPoolExecutor(max_workers=n)
-        return executor
+        self._executor.shutdown()

-    def _init_train_executor(self):
-        n = self._train_workers
+    def _init_executor(self):
+        n = self._workers
         executor = CancellablePool(max_workers=n)
         return executor

@@ -168,7 +184,7 @@ def validate_response(self, response):
         If the wrapped model has defined a ``response`` attribute we will
         validate the response that

-        :param response: The reponse that will be validated.
+        :param response: The response that will be validated.
         :raises exceptions.InternalServerError: in case the reponse cannot be
             validated.
         """

@@ -213,18 +229,10 @@ def get_metadata(self):
         }
         return d

-    def _run_in_predict_pool(self, func, *args, **kwargs):
-        async def task(fn):
-            return await self._loop.run_in_executor(self._predict_executor, fn)
-
-        return self._loop.create_task(
-            task(functools.partial(func, *args, **kwargs))
-        )
-
-    def _run_in_train_pool(self, func, *args, **kwargs):
+    def _run_in_pool(self, func, *args, **kwargs):
         fn = functools.partial(func, *args, **kwargs)
         ret = self._loop.create_task(
-            self._train_executor.apply(fn)
+            self._executor.apply(fn)
         )
         return ret

@@ -243,17 +251,27 @@ async def warm(self):
             LOG.debug("Cannot warm (initialize) model '%s'" % self.name)
             return

-        run = self._loop.run_in_executor
-        executor = self._predict_executor
-        n = self._predict_workers
         try:
+            n = self._workers
             LOG.debug("Warming '%s' model with %s workers" % (self.name, n))
-            fs = [run(executor, func) for i in range(0, n)]
+            fs = [self._run_in_pool(func) for _ in range(0, n)]
             await asyncio.gather(*fs)
             LOG.debug("Model '%s' has been warmed" % self.name)
         except NotImplementedError:
             LOG.debug("Cannot warm (initialize) model '%s'" % self.name)

+    @staticmethod
+    def predict_wrap(predict_func, *args, **kwargs):
+        """Wrapper function to allow returning files from predict
+        This wrapper exists because buffer objects are not pickable,
+        thus cannot be returned from the executor.
+        """
+        ret = predict_func(*args, **kwargs)
+        if isinstance(ret, io.BufferedReader):
+            ret = ReturnedFile(filename=ret.name)
+
+        return ret
+
     def predict(self, *args, **kwargs):
         """Perform a prediction on wrapped model's ``predict`` method.

@@ -280,8 +298,8 @@ def predict(self, *args, **kwargs):
         # FIXME(aloga); cleanup of tmpfile here

         with self._catch_error():
-            return self._run_in_predict_pool(
-                self.model_obj.predict, *args, **kwargs
+            return self._run_in_pool(
+                self.predict_wrap, self.model_obj.predict, *args, **kwargs
             )

     def train(self, *args, **kwargs):

@@ -296,7 +314,7 @@ def train(self, *args, **kwargs):
         """

         with self._catch_error():
-            return self._run_in_train_pool(
+            return self._run_in_pool(
                 self.model_obj.train, *args, **kwargs
             )

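
The rationale for the predict_wrap shim above can be shown in isolation: results returned from a process pool must be picklable, and open file buffers are not. A standalone demonstration (not repository code):

    import io
    import pickle
    import tempfile

    tmp = tempfile.NamedTemporaryFile(delete=False)
    tmp.close()
    buf = open(tmp.name, "rb")
    assert isinstance(buf, io.BufferedReader)

    try:
        pickle.dumps(buf)  # what returning the buffer from the pool implies
    except TypeError as exc:
        print("not picklable:", exc)

    # Hence predict_wrap swaps the buffer for ReturnedFile(filename=buf.name),
    # a plain namedtuple that pickles fine; the parent process reopens the path.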

deepaas/tests/test_v2_models.py

Lines changed: 1 addition & 1 deletion

@@ -162,7 +162,7 @@ async def test_dummy_model_with_wrapper(self, m_clean):
         w = v2_wrapper.ModelWrapper("foo", v2_test.TestModel(), self.app)
         task = w.predict()
         await task
-        ret = task.result()
+        ret = task.result()['output']
         self.assertDictEqual(
             {'date': '2019-01-1',
              'labels': [{'label': 'foo', 'probability': 1.0}]},
Release note (new file)

Lines changed: 5 additions & 0 deletions

@@ -0,0 +1,5 @@
+---
+fixes:
+  - |
+    Fix [#87](https://github.com/indigo-dc/DEEPaaS/issues/87) out-of-memory
+    errors due to the usage of two different executor pools.
