Skip to content

Commit 54ceee0

Browse files
authored
refactor: riseapi mode response format (#1585)
1 parent 7009be5 commit 54ceee0

File tree

4 files changed

+38
-14
lines changed

4 files changed

+38
-14
lines changed

deeppavlov/models/api_requester/api_requester.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,16 @@ class ApiRequester(Component):
3333
3434
Attributes:
3535
url: url of the API.
36-
out: count of expected returned values.
36+
out_count: count of expected returned values.
3737
param_names: list of parameter names for API requests.
3838
debatchify: if True, single instances will be sent to the API endpoint instead of batches.
3939
"""
4040

41-
def __init__(self, url: str, out: [int, list], param_names: [list, tuple] = (), debatchify: bool = False,
41+
def __init__(self, url: str, out: [int, list], param_names: [list, tuple] = None, debatchify: bool = False,
4242
*args, **kwargs):
4343
self.url = url
44+
if param_names is None:
45+
param_names = kwargs.get('in', ())
4446
self.param_names = param_names
4547
self.out_count = out if isinstance(out, int) else len(out)
4648
self.debatchify = debatchify
@@ -70,13 +72,11 @@ async def collect():
7072

7173
loop = asyncio.get_event_loop()
7274
response = loop.run_until_complete(collect())
73-
75+
if self.out_count > 1:
76+
response = list(zip(*response))
7477
else:
7578
response = requests.post(self.url, json=data).json()
7679

77-
if self.out_count > 1:
78-
response = list(zip(*response))
79-
8080
return response
8181

8282
async def get_async_response(self, data: dict, batch_size: int) -> AsyncIterable:

deeppavlov/utils/server/server.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414

1515
import asyncio
16+
import os
1617
from collections import namedtuple
1718
from logging import getLogger
1819
from pathlib import Path
@@ -44,6 +45,13 @@
4445
log = getLogger(__name__)
4546
dialog_logger = DialogLogger(logger_name='rest_api')
4647

48+
COMPATIBILITY_MODE = os.getenv('COMPATIBILITY_MODE', False)
49+
50+
if COMPATIBILITY_MODE is not False:
51+
log.warning('DeepPavlov riseapi mode will use the old model response data format used up and including 1.0.0rc1.\n'
52+
'COMPATIBILITY_MODE will be removed in the DeepPavlov 1.2.0.\n'
53+
'Please, update your client code according to the new format.')
54+
4755
app = FastAPI()
4856

4957
app.add_middleware(
@@ -144,9 +152,13 @@ def interact(model: Chainer, payload: Dict[str, Optional[List]]) -> List:
144152
model_args = [arg or [None] * batch_size for arg in model_args]
145153

146154
prediction = model(*model_args)
147-
if len(model.out_params) == 1:
148-
prediction = [prediction]
149-
prediction = list(zip(*prediction))
155+
156+
# TODO: remove in 1.2.0
157+
if COMPATIBILITY_MODE is not False:
158+
if len(model.out_params) == 1:
159+
prediction = [prediction]
160+
prediction = list(zip(*prediction))
161+
150162
result = jsonify_data(prediction)
151163
dialog_logger.log_out(result)
152164
return result
@@ -204,8 +216,13 @@ async def probe(item: Batch) -> List[str]:
204216
return await loop.run_in_executor(None, test_interact, model, item.dict())
205217

206218
@app.get('/api', summary='Model argument names')
207-
async def api() -> List[str]:
208-
return model_args_names
219+
async def api() -> Dict[str, List[str]]:
220+
if COMPATIBILITY_MODE is not False:
221+
return model_args_names
222+
return {
223+
'in': model.in_x,
224+
'out': model.out_params
225+
}
209226

210227
uvicorn.run(app, host=host, port=port, log_config=log_config, ssl_version=ssl_config.version,
211228
ssl_keyfile=ssl_config.keyfile, ssl_certfile=ssl_config.certfile, timeout_keep_alive=20)

docs/integrations/rest_api.rst

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,13 @@ The command will print the used host and port. Default web service properties
2424
(host, port, POST request arguments) can be modified via changing
2525
``deeppavlov/utils/settings/server_config.json`` file.
2626

27+
.. warning::
28+
29+
Starting from the 1.0.0rc2 model response format in riseapi mode matches :class:`~deeppavlov.core.common.chainer.Chainer`
30+
response format. To start model with the old format, give the ``COMPATIBILITY_MODE`` environment variable any
31+
non-empty value (e.g. ``COMPATIBILITY_MODE=true python -m deeppavlov riseapi ...``).
32+
``COMPATIBILITY_MODE`` will be removed in DeepPavlov 1.2.0.
33+
2734
API routes
2835
----------
2936

@@ -40,8 +47,8 @@ server will send a response ``["Test passed"]`` if it is working. Requests to
4047

4148
/api
4249
""""
43-
To get model argument names send GET request to ``<host>:<port>/api``. Server
44-
will return list with argument names.
50+
To get model argument and response names send GET request to ``<host>:<port>/api``. Server
51+
will return dict with model input and output names.
4552

4653
.. _rest_api_docs:
4754

tests/test_quick_start.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -444,7 +444,7 @@ def infer_api(config_path):
444444
response_code = get_response.status_code
445445
assert response_code == 200, f"GET /api request returned error code {response_code} with {config_path}"
446446

447-
model_args_names = get_response.json()
447+
model_args_names = get_response.json()['in']
448448
post_payload = dict()
449449
for arg_name in model_args_names:
450450
arg_value = ' '.join(['qwerty'] * 10)

0 commit comments

Comments
 (0)