Skip to content

Commit ece48a3

Browse files
committed
Run training and prediction as tasks
This change implements the execution of the training and prediction calls as tasks, that are spawned in different executors using asyncio. This means that the ModelWrapper returns tasks for the predict and train methods, that are not anymore async methods. The callee must then handle the task (e.g. await for it to be completed) before rendering the result. The approach we follow is different for each call, as follows: - We use separate executor pools for trainings and predictions. The number of workers for each can be configured through the command line, although in most cases it does not make too much sense to provide a larger worker pool for trainings. - Prediction calls are spawned using a ThreadPoolExecutor. We have found several difficulties in spawning them as a processes using the ProcessPoolExecutor. Until we find the root cause of this we stick to using threads. - The prediction pool can be "warmed" meaning that the model can be initialized before the API starts. This is usefult for avoiding uncessary waits for the clients interacting with the API. - Even though Predictions are tasks, we do not allow to cancel them, since the we cannot cancel threads. However we use tasks to allow for async calls, that will be implemented in a separate patch - Trainings use a special CancellablePool, that uses processes under the hood. We use this special pool in order to be able to cancel the tasks, as the multiprocessing.futures.ProcessPoolExecutor does not allow to cancel tasks spawned there. - We implement a full fledged /train endpoint, where we can POST, DELETE, GET, etc. in order to create, delete and get information about a running training. Sem-Ver: api-break
1 parent 8f4fa64 commit ece48a3

File tree

14 files changed

+310
-71
lines changed

14 files changed

+310
-71
lines changed

deepaas/api/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ async def get_app(doc="/docs"):
5656

5757
LOG.info("Serving loaded V1 models: %s", list(model.V1_MODELS.keys()))
5858

59-
model.register_v2_models()
59+
model.register_v2_models(APP)
6060

6161
v2app = v2.get_app()
6262
APP.add_subapp("/v2", v2app)

deepaas/api/v2/predict.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,10 @@ def __init__(self, model_name, model_obj):
5757
@aiohttp_apispec.response_schema(responses.Failure(), 400)
5858
@aiohttpparser.parser.use_args(args)
5959
async def post(self, request, args):
60-
ret = await self.model_obj.predict(**args)
60+
task = self.model_obj.predict(**args)
61+
await task
62+
63+
ret = task.result()
6164

6265
if self.model_obj.has_schema:
6366
self.model_obj.validate_response(ret)

deepaas/api/v2/responses.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import marshmallow
1818
from marshmallow import fields
19+
from marshmallow import validate
1920

2021

2122
class Location(marshmallow.Schema):
@@ -56,3 +57,19 @@ class ModelMeta(marshmallow.Schema):
5657
version = fields.Str(required=False, description='Model version')
5758
url = fields.Str(required=False, description='Model url')
5859
links = fields.List(fields.Nested(Location))
60+
61+
62+
class Training(marshmallow.Schema):
63+
uuid = fields.UUID(required=True, description='Training identifier')
64+
date = fields.DateTime(required=True, description='Training start time')
65+
status = fields.Str(
66+
required=True,
67+
description='Training status',
68+
enum=["running", "error", "completed", "cancelled"],
69+
validate=validate.OneOf(["running", "error", "completed", "cancelled"])
70+
)
71+
message = fields.Str(description="Optional message explaining status")
72+
73+
74+
class TrainingList(marshmallow.Schema):
75+
trainings = fields.List(fields.Nested(Training))

deepaas/api/v2/train.py

Lines changed: 90 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,21 @@
1414
# License for the specific language governing permissions and limitations
1515
# under the License.
1616

17+
import asyncio
18+
import datetime
19+
import uuid
20+
1721
from aiohttp import web
1822
import aiohttp_apispec
23+
from oslo_log import log
1924
from webargs import aiohttpparser
2025
import webargs.core
2126

27+
from deepaas.api.v2 import responses
2228
from deepaas import model
2329

30+
LOG = log.getLogger("deepaas.api.v2.train")
31+
2432

2533
def setup_routes(app):
2634
# In the next lines we iterate over the loaded models and create the
@@ -36,6 +44,29 @@ class Handler(object):
3644
def __init__(self, model_name, model_obj):
3745
self.model_name = model_name
3846
self.model_obj = model_obj
47+
self._trainings = {}
48+
49+
def build_train_response(self, uuid_):
50+
training = self._trainings.get(uuid_, None)
51+
if training:
52+
ret = {}
53+
ret["date"] = training["date"]
54+
ret["uuid"] = uuid_
55+
56+
if training["task"].cancelled():
57+
ret["status"] = "cancelled"
58+
elif training["task"].done():
59+
exc = training["task"].exception()
60+
if exc:
61+
ret["status"] = "error"
62+
ret["message"] = "%s" % exc
63+
else:
64+
ret["status"] = "done"
65+
else:
66+
ret["status"] = "running"
67+
return ret
68+
else:
69+
return None
3970

4071
@aiohttp_apispec.docs(
4172
tags=["models"],
@@ -44,10 +75,66 @@ def __init__(self, model_name, model_obj):
4475
@aiohttp_apispec.querystring_schema(args)
4576
@aiohttpparser.parser.use_args(args)
4677
async def post(self, request, args):
47-
ret = await self.model_obj.train(**args)
48-
# FIXME(aloga): what are we returning here? We need to take
49-
# care of these responses as well.
78+
uuid_ = uuid.uuid4().hex
79+
train_task = self.model_obj.train(**args)
80+
self._trainings[uuid_] = {
81+
"date": str(datetime.datetime.now()),
82+
"task": train_task,
83+
}
84+
ret = self.build_train_response(uuid_)
5085
return web.json_response(ret)
5186

87+
@aiohttp_apispec.docs(
88+
tags=["models"],
89+
summary="Cancel a running training"
90+
)
91+
async def delete(self, request):
92+
uuid_ = request.match_info["uuid"]
93+
training = self._trainings.get(uuid_, None)
94+
if training:
95+
training["task"].cancel()
96+
try:
97+
await asyncio.wait_for(training["task"], 5)
98+
except asyncio.TimeoutError:
99+
pass
100+
LOG.info("Training %s has been cancelled" % uuid_)
101+
ret = self.build_train_response(uuid_)
102+
return web.json_response(ret)
103+
else:
104+
raise web.HTTPNotFound()
105+
106+
@aiohttp_apispec.docs(
107+
tags=["models"],
108+
summary="Get a list of trainings (running or completed)"
109+
)
110+
@aiohttp_apispec.response_schema(responses.TrainingList(), 200)
111+
async def index(self, request):
112+
113+
ret = []
114+
for uuid_, training in self._trainings.items():
115+
aux = self.build_train_response(uuid_)
116+
ret.append(aux)
117+
118+
return web.json_response(ret)
119+
120+
@aiohttp_apispec.docs(
121+
tags=["models"],
122+
summary="Get status of a training"
123+
)
124+
@aiohttp_apispec.response_schema(responses.Training(), 200)
125+
async def get(self, request):
126+
uuid_ = request.match_info["uuid"]
127+
ret = self.build_train_response(uuid_)
128+
if ret:
129+
return web.json_response(ret)
130+
else:
131+
raise web.HTTPNotFound()
132+
52133
hdlr = Handler(model_name, model_obj)
53134
app.router.add_post("/models/%s/train" % model_name, hdlr.post)
135+
app.router.add_get("/models/%s/train" % model_name, hdlr.index)
136+
app.router.add_get("/models/%s/train/{uuid}" % model_name, hdlr.get)
137+
app.router.add_delete(
138+
"/models/%s/train/{uuid}" % model_name,
139+
hdlr.delete
140+
)

deepaas/config.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,19 @@
4242
"/debug" endpoint. Default is to not provide this information. This will not
4343
provide logging information about the API itself.
4444
"""),
45-
cfg.IntOpt('model-workers',
46-
short='n',
45+
cfg.IntOpt('predict-workers',
46+
short='p',
4747
default=1,
4848
help="""
49-
Specify the number of workers *per model* that we will initialize. If using a
50-
CPU you probably want to increase this number, if using a GPU probably you want
51-
to leave it to 1. (defaults to 1)
49+
Specify the number of workers to spawn for prediction tasks. If using a CPU you
50+
probably want to increase this number, if using a GPU probably you want to
51+
leave it to 1. (defaults to 1)
52+
"""),
53+
cfg.IntOpt('train-workers',
54+
default=1,
55+
help="""
56+
Specify the number of workers to spawn for training tasks. Unless you know what
57+
you are doing you should leave this number to 1. (defaults to 1)
5258
"""),
5359
]
5460

deepaas/model/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,11 @@ def register_v1_models():
3131
return v1.register_models()
3232

3333

34-
def register_v2_models():
34+
def register_v2_models(app):
3535
"""Register V2 models.
3636
3737
This method has to be called before the API is spawned, so that we
3838
can look up the correct entry points and load the defined models.
3939
"""
4040

41-
return v2.register_models()
41+
return v2.register_models(app)

deepaas/model/v2/__init__.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
MODELS_LOADED = False
2828

2929

30-
def register_models():
30+
def register_models(app):
3131
global MODELS
3232
global MODELS_LOADED
3333

@@ -36,7 +36,7 @@ def register_models():
3636

3737
try:
3838
for name, model in loading.get_available_models("v2").items():
39-
MODELS[name] = wrapper.ModelWrapper(name, model)
39+
MODELS[name] = wrapper.ModelWrapper(name, model, app)
4040
except Exception as e:
4141
LOG.warning("Error loading models: %s", e)
4242

@@ -52,14 +52,15 @@ def register_models():
5252

5353
try:
5454
for name, model in loading.get_available_models("v1").items():
55-
MODELS[name] = wrapper.ModelWrapper(name, model)
55+
MODELS[name] = wrapper.ModelWrapper(name, model, app)
5656
except Exception as e:
5757
LOG.warning("Error loading models: %s", e)
5858

5959
if not MODELS:
6060
LOG.info("No models found with V2 or V1 namespace, loading test model")
6161
MODELS["deepaas-test"] = wrapper.ModelWrapper(
6262
"deepaas-test",
63-
test.TestModel()
63+
test.TestModel(),
64+
app
6465
)
6566
MODELS_LOADED = True

deepaas/model/v2/test.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# License for the specific language governing permissions and limitations
1515
# under the License.
1616

17+
import time
18+
1719
from oslo_log import log
1820
from webargs import fields
1921
from webargs import validate
@@ -56,8 +58,11 @@ def predict(self, **kwargs):
5658
}
5759

5860
def train(self, *args, **kwargs):
61+
sleep = kwargs.get("sleep", 1)
5962
LOG.debug("Got the following arguments: %s", args)
6063
LOG.debug("Got the following kw arguments: %s", kwargs)
64+
LOG.debug("Starting training, ending in %is" % sleep)
65+
time.sleep(sleep)
6166

6267
def get_predict_args(self):
6368
return {
@@ -81,7 +86,7 @@ def get_predict_args(self):
8186

8287
def get_train_args(self):
8388
return {
84-
"parameter_one": fields.Int(
89+
"sleep": fields.Int(
8590
required=True,
8691
descripton='This is a integer parameter, and it is '
8792
'a required one.'

0 commit comments

Comments
 (0)