Skip to content

Commit fcec69d

Browse files
add rest api to triton inframework (#377)
Signed-off-by: Onur Yilmaz <oyilmaz@nvidia.com>
1 parent b1d2573 commit fcec69d

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

scripts/deploy/nlp/deploy_inframework_triton.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import sys
1818

1919
import torch
20+
import uvicorn
2021

2122
from nemo_deploy import DeployPyTriton
2223

@@ -56,6 +57,20 @@ def get_args(argv):
5657
type=int,
5758
help="Version for the service",
5859
)
60+
parser.add_argument(
61+
"-sp",
62+
"--server_port",
63+
default=8080,
64+
type=int,
65+
help="Port for the REST server to listen for requests",
66+
)
67+
parser.add_argument(
68+
"-sa",
69+
"--server_address",
70+
default="0.0.0.0",
71+
type=str,
72+
help="HTTP address for the REST server",
73+
)
5974
parser.add_argument(
6075
"-trp",
6176
"--triton_port",
@@ -275,16 +290,27 @@ def nemo_deploy(argv):
275290

276291
LOGGER.info("Triton deploy function will be called.")
277292
nm.deploy()
293+
nm.run()
278294
except Exception as error:
279295
LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error))
280296
return
281297

282298
try:
299+
# start fastapi server which acts as a proxy to Pytriton server. Applies to PyTriton backend only.
300+
try:
301+
LOGGER.info("REST service will be started.")
302+
uvicorn.run(
303+
"nemo_deploy.service.fastapi_interface_to_pytriton:app",
304+
host=args.server_address,
305+
port=args.server_port,
306+
reload=True,
307+
)
308+
except Exception as error:
309+
LOGGER.error("Error message has occurred during REST service start. Error message: " + str(error))
283310
LOGGER.info("Model serving on Triton will be started.")
284311
nm.serve()
285312
except Exception as error:
286313
LOGGER.error("Error message has occurred during deploy function. Error message: " + str(error))
287-
return
288314

289315
torch.distributed.broadcast(torch.tensor([1], dtype=torch.long, device="cuda"), src=0)
290316

0 commit comments

Comments
 (0)