Skip to content

Commit 77f36db

Browse files
committed
implementation of python async engine, which decouples receiving requests and sending responses in the python engine
1 parent 46756d5 commit 77f36db

File tree

18 files changed

+911
-313
lines changed

18 files changed

+911
-313
lines changed

engines/python/setup/djl_python/arg_parser.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ def python_engine_args():
9191
type=str,
9292
default="info",
9393
help="log level to use for djl_python logging")
94+
parser.add_argument(
95+
'--async-mode',
96+
required=False,
97+
action=argparse.BooleanOptionalAction,
98+
help="whether to use async python engine for comms")
9499
return parser
95100

96101
@staticmethod

engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
# the specific language governing permissions and limitations under the License.
1313
import ast
1414
import logging
15-
from typing import Optional, Any, Dict, Tuple, Literal
15+
from typing import Optional, Any, Dict, Tuple, Literal, Union
1616
from pydantic import field_validator, model_validator, ConfigDict, Field
17-
from vllm import EngineArgs
17+
from vllm import EngineArgs, AsyncEngineArgs
1818
from vllm.utils import FlexibleArgumentParser
1919
from vllm.engine.arg_utils import StoreBoolean
2020

@@ -219,18 +219,21 @@ def generate_vllm_engine_arg_dict(self,
219219
vllm_engine_args.update(passthrough_vllm_engine_args)
220220
return vllm_engine_args
221221

222-
def get_engine_args(self) -> EngineArgs:
222+
def get_engine_args(self,
223+
async_engine=False
224+
) -> Union[EngineArgs, AsyncEngineArgs]:
223225
additional_vllm_engine_args = self.get_additional_vllm_engine_args()
224226
self.handle_lmi_vllm_config_conflicts(additional_vllm_engine_args)
225227
vllm_engine_arg_dict = self.generate_vllm_engine_arg_dict(
226228
additional_vllm_engine_args)
227229
logging.debug(
228230
f"Construction vLLM engine args from the following DJL configs: {vllm_engine_arg_dict}"
229231
)
230-
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
232+
arg_cls = AsyncEngineArgs if async_engine else EngineArgs
233+
parser = arg_cls.add_cli_args(FlexibleArgumentParser())
231234
args_list = construct_vllm_args_list(vllm_engine_arg_dict, parser)
232235
args = parser.parse_args(args=args_list)
233-
engine_args = EngineArgs.from_cli_args(args)
236+
engine_args = arg_cls.from_cli_args(args)
234237
# we have to do this separately because vllm converts it into a string
235238
engine_args.long_lora_scaling_factors = self.long_lora_scaling_factors
236239
# These neuron configs are not implemented in the vllm arg parser
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import asyncio
2+
import logging
3+
import time
4+
import traceback
5+
from concurrent.futures import ThreadPoolExecutor
6+
from functools import partial
7+
from threading import Thread
8+
from queue import Queue
9+
from asyncio.queues import Queue as AsyncQueue
10+
11+
from djl_python.inputs import Input
12+
from djl_python.outputs import Output
13+
from djl_python.python_sync_engine import PythonSyncEngine
14+
15+
SOCKET_ACCEPT_TIMEOUT = 30.0
16+
17+
18+
class PythonAsyncEngine(PythonSyncEngine):
    """
    Backend engine to run python code in decoupled/async mode.

    Requests are forwarded from the model server and submitted to the handler
    as they arrive; the handler pushes responses onto an asyncio queue, and a
    dedicated thread drains that queue and sends them back to the frontend.
    This assumes the frontend maintains tracking of requests and can attribute
    a response to a given request.
    """

    def __init__(self, args, service):
        super().__init__(args, service)
        # Filled by invoke_handler_async, drained by send_responses.
        self.output_queue = AsyncQueue()
        # Tracebacks from dead worker threads; inspected by run_server on exit.
        self.exception_queue = Queue()
        # The running event loop; assigned in run_server before the worker
        # threads start, and used by them to schedule coroutines.
        self.loop = None

    def receive_requests(self):
        """Blocking thread loop: read requests off the socket, submit to the handler."""
        logging.info("receive requests loop started")
        while True:
            logging.info("receive loop start")
            inputs, function_name = self._prepare_inputs()

            logging.info("submitting inference task to handler")
            # Fire-and-forget: the response is delivered through output_queue,
            # not through this call's return value.
            asyncio.run_coroutine_threadsafe(
                self.invoke_handler_async(function_name, inputs), self.loop)

    async def invoke_handler_async(self, function_name: str, inputs: Input):
        """Run the async handler and enqueue its Output (or an error Output)."""
        try:
            logging.info("new async inference call to handler ")
            outputs = await self.service.invoke_handler_async(
                function_name, inputs)
            logging.info("async inference call returned")
            if outputs is None:
                outputs = Output(code=204, message="No content")
            elif not isinstance(outputs, Output):
                outputs = Output().error(
                    f"Invalid output type from handler: {type(outputs)}")
        except Exception as e:
            logging.exception("Failed invoke service.invoke_handler_async()")
            # OOM-style failures map to 507 so the frontend can tell capacity
            # errors apart from handler bugs.
            if (type(e).__name__ == "OutOfMemoryError"
                    or type(e).__name__ == "MemoryError"
                    or "No available memory for the cache blocks" in str(e)
                    or "CUDA error: out of memory" in str(e)):
                outputs = Output(code=507, message=str(e))
            else:
                outputs = Output().error(str(e))
        logging.info("putting result of inference to output queue")
        await self.output_queue.put(outputs)

    def send_responses(self):
        """Blocking thread loop: pop finished Outputs and write them to the socket."""
        while True:
            # Bridge from the asyncio queue into this plain thread.
            future = asyncio.run_coroutine_threadsafe(self.output_queue.get(),
                                                      self.loop)
            logging.info("waiting for new inference response")
            output: Output = future.result()
            logging.info(
                f"inference response received, sending back: {output}")
            output.send(self.cl_socket)

    def run_server(self):
        """
        Start the event loop plus receiver/sender threads.

        Blocks until a worker thread dies; returns the captured traceback
        string on failure, or None on a clean exit.
        """

        async def main():
            self.loop = asyncio.get_running_loop()
            self._create_cl_socket()

            def catch_all(func):
                # Record any thread-killing exception so run_server can report it.
                try:
                    func()
                except Exception:
                    # Fixed: original logged the literal string "{func} failed"
                    # (missing f-prefix), hiding which thread died.
                    logging.exception("%s failed", func)
                    self.exception_queue.put(str(traceback.format_exc()))

            threads = [
                Thread(target=partial(catch_all, self.receive_requests)),
                Thread(target=partial(catch_all, self.send_responses)),
            ]

            for t in threads:
                t.start()

            def check_threads():
                # Exit as soon as either worker thread terminates.
                while True:
                    if not all(t.is_alive() for t in threads):
                        return
                    time.sleep(1)

            # Park the loop on a watchdog in an executor so scheduled handler
            # coroutines keep running while we wait for a thread to die.
            with ThreadPoolExecutor(1) as executor:
                await self.loop.run_in_executor(executor, check_threads)

        # asyncio.run replaces the deprecated
        # asyncio.get_event_loop().run_until_complete(...) pattern.
        asyncio.run(main())
        logging.info("djl async engine terminated")
        if not self.exception_queue.empty():
            return self.exception_queue.get()
        return None
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import os
2+
import socket
3+
import signal
4+
import logging
5+
from typing import Tuple
6+
7+
from djl_python.service_loader import get_annotated_function, load_model_service, has_function_in_module
8+
from djl_python.inputs import Input
9+
from djl_python.outputs import Output
10+
11+
SOCKET_ACCEPT_TIMEOUT = 30.0
12+
13+
14+
class PythonSyncEngine(object):
    """
    Backend engine to run python code.

    Listens on a unix or tcp socket for requests from the model server,
    invokes the configured handler synchronously for each request, and
    writes the handler's Output back on the same socket.
    """

    def __init__(self, args, service):
        # Support MPI environment args
        if os.getenv('OMPI_COMM_WORLD_SIZE'):
            os.environ["WORLD_SIZE"] = os.getenv('OMPI_COMM_WORLD_SIZE')
        if os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'):
            os.environ["LOCAL_RANK"] = os.getenv('OMPI_COMM_WORLD_LOCAL_RANK')
        rank = os.environ.get("OMPI_COMM_WORLD_RANK")
        if rank:
            os.environ["RANK"] = rank

        self.model_dir = args.model_dir
        self.sock_type = args.sock_type
        self.sock_name = args.sock_name
        self.port = args.port
        self.service = service
        self.device_id = args.device_id
        self.tensor_parallel_degree = args.tensor_parallel_degree
        self.pipeline_parallel_degree = args.pipeline_parallel_degree
        self.cluster_size = args.cluster_size
        self.entry_point = args.entry_point
        self.recommended_entry_point = args.recommended_entry_point
        # Optional user-supplied formatters discovered by annotation in model_dir.
        self.output_formatter = get_annotated_function(args.model_dir,
                                                       "is_output_formatter")
        self.input_formatter = get_annotated_function(args.model_dir,
                                                      "is_input_formatter")
        # Entry-point fallback check (see _prepare_inputs) runs only once.
        self.is_entry_point_verified = False

        if self.sock_type == "unix":
            if self.sock_name is None:
                raise ValueError("Missing sock-name argument.")
            # One socket file per MPI rank.
            self.sock_name = f"{args.sock_name}.{rank}" if rank else args.sock_name

            self.clean_up()
        elif self.sock_type == "tcp":
            if self.sock_name is None:
                self.sock_name = "0.0.0.0"
            if self.port is None:
                raise ValueError("Missing port argument.")
            # Offset the port by rank so each MPI worker listens on its own port.
            self.port = int(self.port) + int(rank) if rank else self.port
        else:
            raise ValueError(f"Invalid socket-type: {self.sock_type}.")

        socket_family = socket.AF_INET if self.sock_type == "tcp" else socket.AF_UNIX
        self.sock = socket.socket(socket_family, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.sock.settimeout(SOCKET_ACCEPT_TIMEOUT)
        # Populated by _create_cl_socket once the model server connects.
        self.cl_socket = None

    def clean_up(self):
        """Kill any dangling process from a previous run and remove stale socket/pid files."""
        pid_file = f"{self.sock_name}.pid"
        if os.path.exists(pid_file):
            with open(pid_file, "r") as f:
                pid = f.readline()
                if pid:
                    try:
                        os.kill(int(pid), signal.SIGKILL)
                        logging.warning(
                            f"{self.sock_name} - kill dangling process: {pid}")
                    except ProcessLookupError:
                        pass

        # Record our own pid so the next incarnation can clean us up.
        with open(pid_file, "w") as f:
            f.write(str(os.getpid()))

        if os.path.exists(self.sock_name):
            os.remove(self.sock_name)

    def _prepare_inputs(self) -> Tuple[Input, str]:
        """Read one request from the socket and return (inputs, handler function name).

        Injects engine-level properties (parallelism degrees, device id,
        formatters) into the request properties. On the first request, falls
        back to the LMI recommended entry point if the configured module does
        not define the requested handler function.
        """
        inputs = Input()
        inputs.read(self.cl_socket)
        prop = inputs.get_properties()
        if self.tensor_parallel_degree:
            prop["tensor_parallel_degree"] = self.tensor_parallel_degree
        if self.pipeline_parallel_degree:
            prop["pipeline_parallel_degree"] = self.pipeline_parallel_degree
        if self.cluster_size:
            prop["cluster_size"] = self.cluster_size
        prop["device_id"] = self.device_id

        if "output_formatter" in prop:
            if hasattr(self.service, prop["output_formatter"]):
                # TODO: custom output_formatter in serving.properties is deprecated. Remove when users are migrated.
                prop["output_formatter"] = getattr(self.service,
                                                   prop["output_formatter"])
        elif self.output_formatter:
            prop["output_formatter"] = self.output_formatter

        if self.input_formatter:
            prop["input_formatter"] = self.input_formatter
        function_name = inputs.get_function_name()
        if not self.is_entry_point_verified:
            if self.recommended_entry_point:
                if not has_function_in_module(self.service.module,
                                              function_name):
                    # Swap the service for the recommended entry point module.
                    self.service = load_model_service(
                        self.model_dir, self.recommended_entry_point,
                        self.device_id)
                    logging.info(
                        f"{self.entry_point} file has no handler function {function_name}."
                        f"Hence choosing the LMI recommended entry point {self.recommended_entry_point}"
                    )
            self.is_entry_point_verified = True
        return inputs, function_name

    def _create_cl_socket(self):
        """Bind, listen, and accept a single client connection from the model server."""
        if self.sock_type == "unix":
            self.sock.bind(self.sock_name)
        else:
            logging.info(
                f"Socket bind on address: {self.sock_name}:{self.port}")
            self.sock.bind((self.sock_name, int(self.port)))

        self.sock.listen(128)
        logging.info("Python engine started.")

        (cl_socket, _) = self.sock.accept()
        # workaround error(35, 'Resource temporarily unavailable') on OSX
        cl_socket.setblocking(True)
        self.cl_socket = cl_socket

    def run_server(self):
        """
        Run the backend worker process and listen on a socket
        :return:
        """
        self._create_cl_socket()

        while True:
            inputs, function_name = self._prepare_inputs()
            try:
                outputs = self.service.invoke_handler(function_name, inputs)
                if outputs is None:
                    outputs = Output(code=204, message="No content")
                elif not isinstance(outputs, Output):
                    outputs = Output().error(
                        f"Invalid output type: {type(outputs)}")
            except Exception as e:
                logging.exception("Failed invoke service.invoke_handler()")
                # OOM-style failures map to 507 so the server can distinguish
                # capacity errors from handler bugs.
                if (type(e).__name__ == "OutOfMemoryError"
                        or type(e).__name__ == "MemoryError"
                        or "No available memory for the cache blocks" in str(e)
                        or "CUDA error: out of memory" in str(e)):
                    outputs = Output(code=507, message=str(e))
                else:
                    outputs = Output().error(str(e))

            outputs.send(self.cl_socket)
            logging.debug("Outputs is sent to DJL engine.")
            try:
                outputs.execute_finalize()
            except Exception as e:
                logging.exception(f"Failed on finalize function: {e}")

engines/python/setup/djl_python/service_loader.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,10 @@ def invoke_handler(self, function_name, inputs):
2929
inputs.properties["model_dir"] = self.model_dir
3030
return getattr(self.module, function_name)(inputs)
3131

32+
async def invoke_handler_async(self, function_name, inputs):
    """Resolve *function_name* on the loaded module and await it on *inputs*."""
    # Expose the model directory to the handler via the request properties.
    inputs.properties["model_dir"] = self.model_dir
    handler = getattr(self.module, function_name)
    return await handler(inputs)
35+
3236

3337
def load_model_service(model_dir, entry_point, device_id):
3438
manifest_file = os.path.join(model_dir, "MAR-INF/MANIFEST.json")

0 commit comments

Comments
 (0)