Skip to content

Commit 1bcd17b

Browse files
committed
Implementation of the Python async engine, which decouples receiving requests from sending responses in the Python engine.
1 parent 6002a0d commit 1bcd17b

File tree

22 files changed

+1110
-317
lines changed

22 files changed

+1110
-317
lines changed

engines/python/setup/djl_python/arg_parser.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ def python_engine_args():
9191
type=str,
9292
default="info",
9393
help="log level to use for djl_python logging")
94+
parser.add_argument(
95+
'--async-mode',
96+
required=False,
97+
action=argparse.BooleanOptionalAction,
98+
help="whether to use async python engine for comms")
9499
return parser
95100

96101
@staticmethod

engines/python/setup/djl_python/outputs.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ def add_property(self, key, val):
9090
self.properties[key] = val
9191
return self
9292

93+
def get_properties(self):
    """Return the underlying properties dictionary of this output."""
    return self.properties
95+
96+
def get_property(self, key):
    """Return the property value for *key*, or None when the key is absent."""
    props = self.properties
    return props.get(key)
98+
9399
def add(self, value, key=None, batch_index=None):
94100
if key is not None and type(key) is not str:
95101
logging.warning(f"Output key should be str type, got {type(key)}")
@@ -182,7 +188,7 @@ def send(self, cl_socket):
182188
msg += struct.pack('>h', len(self.properties))
183189
for k, v in self.properties.items():
184190
self.write_utf8(msg, k)
185-
self.write_utf8(msg, v)
191+
self.write_utf8(msg, str(v))
186192

187193
if self.stream_content is None:
188194
size = self.content.size()

engines/python/setup/djl_python/properties_manager/vllm_rb_properties.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212
# the specific language governing permissions and limitations under the License.
1313
import ast
1414
import logging
15-
from typing import Optional, Any, Dict, Tuple, Literal
15+
from typing import Optional, Any, Dict, Tuple, Literal, Union
1616
from pydantic import field_validator, model_validator, ConfigDict, Field
17-
from vllm import EngineArgs
17+
from vllm import EngineArgs, AsyncEngineArgs
1818
from vllm.utils import FlexibleArgumentParser
1919
from vllm.engine.arg_utils import StoreBoolean
2020

@@ -219,18 +219,21 @@ def generate_vllm_engine_arg_dict(self,
219219
vllm_engine_args.update(passthrough_vllm_engine_args)
220220
return vllm_engine_args
221221

222-
def get_engine_args(self) -> EngineArgs:
222+
def get_engine_args(self,
223+
async_engine=False
224+
) -> Union[EngineArgs, AsyncEngineArgs]:
223225
additional_vllm_engine_args = self.get_additional_vllm_engine_args()
224226
self.handle_lmi_vllm_config_conflicts(additional_vllm_engine_args)
225227
vllm_engine_arg_dict = self.generate_vllm_engine_arg_dict(
226228
additional_vllm_engine_args)
227229
logging.debug(
228230
f"Construction vLLM engine args from the following DJL configs: {vllm_engine_arg_dict}"
229231
)
230-
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
232+
arg_cls = AsyncEngineArgs if async_engine else EngineArgs
233+
parser = arg_cls.add_cli_args(FlexibleArgumentParser())
231234
args_list = construct_vllm_args_list(vllm_engine_arg_dict, parser)
232235
args = parser.parse_args(args=args_list)
233-
engine_args = EngineArgs.from_cli_args(args)
236+
engine_args = arg_cls.from_cli_args(args)
234237
# we have to do this separately because vllm converts it into a string
235238
engine_args.long_lora_scaling_factors = self.long_lora_scaling_factors
236239
# These neuron configs are not implemented in the vllm arg parser
Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
6+
# except in compliance with the License. A copy of the License is located at
7+
#
8+
# http://aws.amazon.com/apache2.0/
9+
#
10+
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
11+
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
12+
# the specific language governing permissions and limitations under the License.
13+
14+
import asyncio
15+
import logging
16+
import time
17+
import traceback
18+
from concurrent.futures import ThreadPoolExecutor
19+
from functools import partial
20+
from threading import Thread
21+
from queue import Queue
22+
from asyncio.queues import Queue as AsyncQueue
23+
24+
from djl_python.inputs import Input
25+
from djl_python.outputs import Output
26+
from djl_python.python_sync_engine import PythonSyncEngine
27+
28+
# Property key used to correlate a request with its (possibly out-of-order)
# response; the value is set and managed by the model server frontend.
REQUEST_TRACKING_ID_KEY = "request_tracking_id"
29+
30+
31+
class PythonAsyncEngine(PythonSyncEngine):
    """
    Backend engine to run python code in decoupled/async mode.

    Requests are forwarded from the model server and submitted to the handler.
    The handler returns responses as they become available, and sends them to the frontend.
    Requests are tracked/coordinated via the request_tracking_id property.
    This is an internal property set/managed by the model server.
    """

    def __init__(self, args, service):
        """
        :param args: parsed engine arguments (see PythonSyncEngine)
        :param service: model service whose handlers are invoked
        """
        super().__init__(args, service)
        # Completed responses awaiting transmission; consumed by send_responses.
        self.output_queue = AsyncQueue()
        # Tracebacks of failed worker threads; drained by run_server on exit.
        self.exception_queue = Queue()
        # Event loop owned by run_server; set once main() is running.
        self.loop = None

    def receive_requests(self):
        """Read requests from the socket forever and submit each to the handler.

        Runs on a dedicated thread; each request is scheduled as a coroutine
        on the engine's event loop so socket reads are never blocked by
        in-flight inference.
        """
        logging.info("starting receive requests thread")
        while True:
            inputs, function_name = self._prepare_inputs()
            logging.debug(
                f"received new request with tracking_id {inputs.get_property(REQUEST_TRACKING_ID_KEY)}, submitting to handler"
            )
            asyncio.run_coroutine_threadsafe(
                self.invoke_handler(function_name, inputs), self.loop)

    async def invoke_handler(self, function_name: str, inputs: Input):
        """Invoke the async handler for one request and enqueue its response.

        Any failure is converted into an error Output carrying the request
        tracking id so the frontend can correlate it; this coroutine never
        raises.
        """
        request_tracking_id = inputs.get_property(REQUEST_TRACKING_ID_KEY)
        try:
            outputs = await self.service.invoke_handler_async(
                function_name, inputs, self.cl_socket)
        except Exception as e:
            # 507 distinguishes memory exhaustion from generic handler errors.
            if (type(e).__name__ == "OutOfMemoryError"
                    or type(e).__name__ == "MemoryError"
                    or "No available memory for the cache blocks" in str(e)
                    or "CUDA error: out of memory" in str(e)):
                logging.exception(
                    f"Memory Error encountered when invoking module {self.service.module}, function {function_name}"
                )
                outputs = Output(code=507, message=str(e))
            else:
                logging.exception(
                    f"service.invoke_handler_async() failure. There was an error invoking module {self.service.module}, function {function_name}"
                )
                outputs = Output().error(
                    str(e), message="service.invoke_handler_async() failure")
            outputs.add_property(REQUEST_TRACKING_ID_KEY, request_tracking_id)
        if outputs is None:
            outputs = Output(code=204, message="No content")
            logging.debug(
                "empty response received from service.invoke_handler_async()")
            outputs.add_property(REQUEST_TRACKING_ID_KEY, request_tracking_id)
        elif not isinstance(outputs, Output):
            message = (
                f"Invalid type returned from {self.service.module}.{function_name}. "
                f"Received type {type(outputs)}, does not match expected type djl_python.outputs.Output"
            )
            logging.error(message)
            outputs = Output().error(message)
            outputs.add_property(REQUEST_TRACKING_ID_KEY, request_tracking_id)
        logging.info("putting result of inference to output queue")
        await self.output_queue.put(outputs)

    def send_responses(self):
        """Drain the output queue forever, sending each response on the socket.

        Runs on a dedicated thread; queue reads are bridged onto the event
        loop via run_coroutine_threadsafe because asyncio queues are not
        thread-safe.
        """
        logging.info("starting send responses thread")
        while True:
            future = asyncio.run_coroutine_threadsafe(self.output_queue.get(),
                                                      self.loop)
            logging.debug("waiting for new inference response")
            output = future.result()
            output.send(self.cl_socket)

    def run_server(self):
        """Start the receive/send threads and block until one of them dies.

        The event loop runs on this (main) thread; the worker threads interact
        with it only through run_coroutine_threadsafe.
        """

        async def main():
            self.loop = asyncio.get_running_loop()
            self._create_cl_socket()

            def catch_all(func):
                # Record the traceback of a dying worker thread so run_server
                # can report it after the loop shuts down.
                try:
                    func()
                except Exception as e:
                    logging.error(f"{func} failed. Details {e}")
                    self.exception_queue.put(str(traceback.format_exc()))

            # daemon=True: once one worker dies and main() returns, the other
            # worker (blocked on socket/queue I/O) must not keep the process
            # alive at interpreter exit.
            threads = [
                Thread(target=partial(catch_all, self.receive_requests),
                       daemon=True),
                Thread(target=partial(catch_all, self.send_responses),
                       daemon=True),
            ]

            for thread in threads:
                thread.start()

            def check_threads():
                # Poll worker liveness; returning ends main() and the engine.
                while True:
                    if not all(t.is_alive() for t in threads):
                        return
                    time.sleep(1)

            with ThreadPoolExecutor(1) as executor:
                # Wait for a worker to exit without blocking the event loop.
                await self.loop.run_in_executor(executor, check_threads)

        # asyncio.run() replaces the deprecated
        # get_event_loop().run_until_complete() pattern and guarantees the
        # loop is closed when main() finishes.
        asyncio.run(main())
        if not self.exception_queue.empty():
            logging.error(
                f"djl async engine terminated with error {self.exception_queue.get()}"
            )
        logging.info("djl async engine terminated")
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
6+
# except in compliance with the License. A copy of the License is located at
7+
#
8+
# http://aws.amazon.com/apache2.0/
9+
#
10+
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
11+
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
12+
# the specific language governing permissions and limitations under the License.
13+
14+
import os
15+
import socket
16+
import signal
17+
import logging
18+
from typing import Tuple
19+
20+
from djl_python.service_loader import get_annotated_function, load_model_service, has_function_in_module
21+
from djl_python.inputs import Input
22+
from djl_python.outputs import Output
23+
24+
SOCKET_ACCEPT_TIMEOUT = 30.0
25+
26+
27+
class PythonSyncEngine(object):
    """
    Backend engine to run python code.

    Accepts a single frontend connection over a unix or tcp socket and serves
    requests synchronously, one at a time, in run_server().
    """

    def __init__(self, args, service):
        # Support MPI environment args: mirror OpenMPI's env vars into the
        # torch-style WORLD_SIZE/LOCAL_RANK/RANK names.
        if os.getenv('OMPI_COMM_WORLD_SIZE'):
            os.environ["WORLD_SIZE"] = os.getenv('OMPI_COMM_WORLD_SIZE')
        if os.getenv('OMPI_COMM_WORLD_LOCAL_RANK'):
            os.environ["LOCAL_RANK"] = os.getenv('OMPI_COMM_WORLD_LOCAL_RANK')
        rank = os.environ.get("OMPI_COMM_WORLD_RANK")
        if rank:
            os.environ["RANK"] = rank

        self.model_dir = args.model_dir
        self.sock_type = args.sock_type
        self.sock_name = args.sock_name
        self.port = args.port
        self.service = service
        self.device_id = args.device_id
        self.tensor_parallel_degree = args.tensor_parallel_degree
        self.pipeline_parallel_degree = args.pipeline_parallel_degree
        self.cluster_size = args.cluster_size
        self.entry_point = args.entry_point
        self.recommended_entry_point = args.recommended_entry_point
        # Optional user-supplied formatters discovered in the model dir by
        # their annotation markers.
        self.output_formatter = get_annotated_function(args.model_dir,
                                                       "is_output_formatter")
        self.input_formatter = get_annotated_function(args.model_dir,
                                                      "is_input_formatter")
        # Entry-point fallback check in _prepare_inputs runs only once.
        self.is_entry_point_verified = False

        if self.sock_type == "unix":
            if self.sock_name is None:
                raise ValueError("Missing sock-name argument.")
            # Each MPI rank gets its own socket path.
            self.sock_name = f"{args.sock_name}.{rank}" if rank else args.sock_name

            self.clean_up()
        elif self.sock_type == "tcp":
            if self.sock_name is None:
                self.sock_name = "0.0.0.0"
            if self.port is None:
                raise ValueError("Missing port argument.")
            # Each MPI rank gets its own port, offset from the base port.
            self.port = int(self.port) + int(rank) if rank else self.port
        else:
            raise ValueError(f"Invalid socket-type: {self.sock_type}.")

        socket_family = socket.AF_INET if self.sock_type == "tcp" else socket.AF_UNIX
        self.sock = socket.socket(socket_family, socket.SOCK_STREAM)
        self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        # Bound only on accept(); fail fast if no frontend ever connects.
        self.sock.settimeout(SOCKET_ACCEPT_TIMEOUT)
        self.cl_socket = None

    def clean_up(self):
        """Kill any stale engine process holding our unix socket, then claim
        the pid file and remove the leftover socket path."""
        pid_file = f"{self.sock_name}.pid"
        if os.path.exists(pid_file):
            with open(pid_file, "r") as f:
                pid = f.readline()
                if pid:
                    try:
                        os.kill(int(pid), signal.SIGKILL)
                        logging.warning(
                            f"{self.sock_name} - kill dangling process: {pid}")
                    except ProcessLookupError:
                        # Previous process already gone; nothing to kill.
                        pass

        with open(pid_file, "w") as f:
            f.write(str(os.getpid()))

        if os.path.exists(self.sock_name):
            os.remove(self.sock_name)

    def _prepare_inputs(self) -> Tuple[Input, str]:
        """Read one request off the socket and enrich its properties.

        Returns the populated Input and the handler function name to invoke.
        On first use, may swap self.service for the LMI recommended entry
        point when the configured module lacks the requested function.
        """
        inputs = Input()
        inputs.read(self.cl_socket)
        prop = inputs.get_properties()
        if self.tensor_parallel_degree:
            prop["tensor_parallel_degree"] = self.tensor_parallel_degree
        if self.pipeline_parallel_degree:
            prop["pipeline_parallel_degree"] = self.pipeline_parallel_degree
        if self.cluster_size:
            prop["cluster_size"] = self.cluster_size
        prop["device_id"] = self.device_id

        if "output_formatter" in prop:
            if hasattr(self.service, prop["output_formatter"]):
                # TODO: custom output_formatter in serving.properties is deprecated. Remove once users are migrated.
                prop["output_formatter"] = getattr(self.service,
                                                   prop["output_formatter"])
        elif self.output_formatter:
            prop["output_formatter"] = self.output_formatter

        if self.input_formatter:
            prop["input_formatter"] = self.input_formatter
        function_name = inputs.get_function_name()
        if not self.is_entry_point_verified:
            if self.recommended_entry_point:
                if not has_function_in_module(self.service.module,
                                              function_name):
                    self.service = load_model_service(
                        self.model_dir, self.recommended_entry_point,
                        self.device_id)
                    logging.info(
                        f"{self.entry_point} file has no handler function {function_name}."
                        f"Hence choosing the LMI recommended entry point {self.recommended_entry_point}"
                    )
            self.is_entry_point_verified = True
        return inputs, function_name

    def _create_cl_socket(self):
        """Bind, listen, and block until the single frontend client connects;
        stores the accepted connection in self.cl_socket."""
        if self.sock_type == "unix":
            self.sock.bind(self.sock_name)
        else:
            logging.info(
                f"Socket bind on address: {self.sock_name}:{self.port}")
            self.sock.bind((self.sock_name, int(self.port)))

        self.sock.listen(128)
        logging.info("Python engine started.")

        (cl_socket, _) = self.sock.accept()
        # workaround error(35, 'Resource temporarily unavailable') on OSX
        cl_socket.setblocking(True)
        self.cl_socket = cl_socket

    def run_server(self):
        """
        Run the backend worker process and listen on a socket
        :return:
        """
        self._create_cl_socket()

        # Serve one request at a time; every failure is converted into an
        # error Output so a response is always sent back to the frontend.
        while True:
            inputs, function_name = self._prepare_inputs()
            try:
                outputs = self.service.invoke_handler(function_name, inputs)
                if outputs is None:
                    outputs = Output(code=204, message="No content")
                elif not isinstance(outputs, Output):
                    outputs = Output().error(
                        f"Invalid output type: {type(outputs)}")
            except Exception as e:
                logging.exception("Failed invoke service.invoke_handler()")
                # 507 distinguishes memory exhaustion from generic failures.
                if (type(e).__name__ == "OutOfMemoryError"
                        or type(e).__name__ == "MemoryError"
                        or "No available memory for the cache blocks" in str(e)
                        or "CUDA error: out of memory" in str(e)):
                    outputs = Output(code=507, message=str(e))
                else:
                    outputs = Output().error(str(e))

            outputs.send(self.cl_socket)
            logging.debug("Outputs is sent to DJL engine.")
            try:
                # Post-response cleanup hook; failures must not kill the loop.
                outputs.execute_finalize()
            except Exception as e:
                logging.exception(f"Failed on finalize function: {e}")

0 commit comments

Comments
 (0)