Skip to content

Commit ec4fa1d

Browse files
committed
[python] Support session based sticky routing in async mode
1 parent a4bc0b2 commit ec4fa1d

File tree

13 files changed

+508
-18
lines changed

13 files changed

+508
-18
lines changed

.github/workflows/integration.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,9 @@ jobs:
173173
- test: TestCorrectnessTrtLlm
174174
instance: g6
175175
failure-prefix: trtllm
176-
176+
- test: TestStickyRouting
177+
instance: g6
178+
failure-prefix: lmi
177179
outputs:
178180
failure_cpu: ${{ steps.test-failure.outputs.failure_cpu }}
179181
failure_gpu: ${{ steps.test-failure.outputs.failure_gpu }}
@@ -271,4 +273,4 @@ jobs:
271273
./stop_instance.sh $instance_id
272274
273275
instance_id=${{ needs.create-runners.outputs.cpu_instance_id }}
274-
./stop_instance.sh $instance_id
276+
./stop_instance.sh $instance_id

engines/python/setup/djl_python/custom_formatter_handling.py renamed to engines/python/setup/djl_python/base_handler.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,14 @@
1212
# the specific language governing permissions and limitations under the License.
1313
import logging
1414

15+
from djl_python.inputs import Input
16+
from djl_python.outputs import Output
1517
from djl_python.service_loader import get_annotated_function
1618

1719
logger = logging.getLogger(__name__)
1820

21+
HEADER_SAGEMAKER_SESSION_ID = "X-Amzn-SageMaker-Session-Id"
22+
1923

2024
class CustomFormatterError(Exception):
2125
"""Exception raised when custom formatter code fails"""
@@ -26,11 +30,12 @@ def __init__(self, message: str, original_exception: Exception):
2630
self.__cause__ = original_exception
2731

2832

29-
class CustomFormatterHandler:
33+
class BaseHandler:
3034

3135
def __init__(self):
3236
self.output_formatter = None
3337
self.input_formatter = None
38+
self.session_manager = None
3439

3540
def load_formatters(self, model_dir: str):
3641
"""Load custom formatters from model.py"""
@@ -79,3 +84,34 @@ async def apply_output_formatter_streaming_raw(self, stream_generator):
7984
logger.exception("Streaming formatter failed")
8085
raise CustomFormatterError(
8186
"Custom streaming formatter execution failed", e)
87+
88+
async def create_session(self, inputs: Input):
    """Create a new session and return its id to the caller.

    Args:
        inputs: The incoming request. It is not read by this method; it is
            accepted so the signature matches the other handler entry points.

    Returns:
        On success, an ``Output`` whose ``X-Amzn-SageMaker-Session-Id``
        property holds the new session id, with ``Content-Type`` set to
        ``application/json`` and a ``result`` entry confirming creation.
        On failure, an error ``Output`` tagged ``create_session_error``.
    """
    outputs = Output()
    # session_manager is set to None in __init__ and only assigned later, so
    # report a clear error instead of an AttributeError message on None.
    if self.session_manager is None:
        return Output().error("create_session_error",
                              message="session manager is not initialized")
    try:
        session = self.session_manager.create_session()
        outputs.add_property(HEADER_SAGEMAKER_SESSION_ID,
                             session.session_id)
        outputs.add_property("Content-Type", "application/json")
        outputs.add(Output.binary_encode(
            {"result": f"Session {session.session_id} created"}),
                    key="result")
    except Exception as e:
        # Handler boundary: the client gets an error response rather than an
        # exception, so keep the traceback in the log for diagnosis.
        logger.exception("Failed to create session")
        return Output().error("create_session_error", message=str(e))

    logger.info(f"Session {session.session_id} created")
    return outputs
103+
104+
async def close_session(self, inputs: Input):
    """Close the session named by the request's session-id property.

    Args:
        inputs: The incoming request; its ``X-Amzn-SageMaker-Session-Id``
            property identifies the session to close.

    Returns:
        On success, an ``Output`` with ``Content-Type`` set to
        ``application/json`` and a ``result`` entry confirming the close.
        On failure (including any exception raised by the session manager),
        an error ``Output`` tagged ``close_session_error``.
    """
    outputs = Output()
    session_id = inputs.get_property(HEADER_SAGEMAKER_SESSION_ID)
    # session_manager is set to None in __init__ and only assigned later, so
    # report a clear error instead of an AttributeError message on None.
    if self.session_manager is None:
        return Output().error("close_session_error",
                              message="session manager is not initialized")
    try:
        self.session_manager.close_session(session_id)
        outputs.add_property("Content-Type", "application/json")
        outputs.add(Output.binary_encode(
            {"result": f"Session {session_id} closed"}),
                    key="result")
    except Exception as e:
        # Handler boundary: the client gets an error response rather than an
        # exception, so keep the traceback in the log for diagnosis.
        logger.exception(f"Failed to close session {session_id}")
        return Output().error("close_session_error", message=str(e))

    logger.info(f"Session {session_id} closed")
    return outputs

engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,8 @@
3131
from djl_python.outputs import Output
3232
from djl_python.encode_decode import decode
3333
from djl_python.async_utils import handle_streaming_response, create_non_stream_output
34-
from djl_python.custom_formatter_handling import CustomFormatterHandler, CustomFormatterError
34+
from djl_python.base_handler import BaseHandler, CustomFormatterError
35+
from djl_python.session_manager import SessionManager
3536

3637
from .request_response_utils import (
3738
ProcessedRequest,
@@ -46,8 +47,10 @@
4647

4748
logger = logging.getLogger(__name__)
4849

50+
HEADER_SAGEMAKER_SESSION_ID = "X-Amzn-SageMaker-Session-Id"
4951

50-
class VLLMHandler(CustomFormatterHandler):
52+
53+
class VLLMHandler(BaseHandler):
5154

5255
def __init__(self):
5356
super().__init__()
@@ -119,12 +122,20 @@ async def initialize(self, properties: dict):
119122
tool_parser=self.vllm_properties.tool_call_parser,
120123
reasoning_parser=self.vllm_properties.reasoning_parser,
121124
)
125+
self.session_manager: SessionManager = SessionManager(properties)
122126
self.initialized = True
123127

124128
def preprocess_request(self, inputs: Input) -> ProcessedRequest:
125129
batch = inputs.get_batches()
126130
assert len(batch) == 1, "only one request per batch allowed"
127131
raw_request = batch[0]
132+
133+
# Get session id
134+
session_id = raw_request.get_property(HEADER_SAGEMAKER_SESSION_ID)
135+
session = self.session_manager.get_session(session_id)
136+
if session is None:
137+
raise RuntimeError(f"Requested session {session_id} not found")
138+
128139
content_type = raw_request.get_property("Content-Type")
129140
decoded_payload = decode(raw_request, content_type)
130141

@@ -226,6 +237,24 @@ async def inference(
226237
service = VLLMHandler()
227238

228239

240+
async def create_session(inputs: Input) -> Output:
    """Module-level entry point that creates a session via the shared service.

    The service is initialized from the request's properties on first use,
    then the request is delegated to ``service.create_session``.
    """
    needs_init = not service.initialized
    if needs_init:
        await service.initialize(inputs.get_properties())
        logger.info("vllm service initialized")

    return await service.create_session(inputs)
247+
248+
249+
async def close_session(inputs: Input) -> Output:
    """Module-level entry point that closes a session via the shared service.

    The service is initialized from the request's properties on first use,
    then the request is delegated to ``service.close_session``.
    """
    needs_init = not service.initialized
    if needs_init:
        await service.initialize(inputs.get_properties())
        logger.info("vllm service initialized")

    return await service.close_session(inputs)
256+
257+
229258
async def handle(
230259
inputs: Input
231260
) -> Optional[Union[Output, AsyncGenerator[Output, None]]]:

engines/python/setup/djl_python/session_manager.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def remove(self):
5858
shutil.rmtree(self.files_path)
5959
return True
6060
else:
61-
logging.warning(f"session not found: {self.session_id}")
61+
logging.warning(f"session directory not found: {self.session_id}")
6262
return False
6363

6464
def _path(self, key: str):
@@ -80,6 +80,7 @@ def __init__(self, properties: dict):
8080
self.sessions_s3url = properties.get("sessions_s3url", None)
8181
if not os.path.exists(self.sessions_path):
8282
os.makedirs(self.sessions_path)
83+
self.sessions: dict[str, Session] = {}
8384

8485
atexit.register(self._save_sessions_to_s3)
8586

@@ -95,25 +96,32 @@ def create_session(self) -> Session:
9596
session.put(".creation_time", time.time())
9697

9798
self.cloud_watch.post("create_session")
99+
self.sessions[session_id] = session
98100
return session
99101

100102
def get_session(self, session_id: str) -> Optional[Session]:
101103
if not session_id or not UUID_PATTERN.match(session_id):
102104
raise ValueError(f"invalid session_id: {session_id}")
103105

104-
session = Session(session_id, self.sessions_path)
106+
if session_id not in self.sessions:
107+
raise ValueError(f"Session not found: {session_id}")
108+
session = self.sessions[session_id]
109+
105110
if not os.path.exists(session.files_path):
106111
return self._recover_from_s3(session)
107112

108113
return session
109114

110115
def close_session(self, session_id):
111116
if not session_id or not UUID_PATTERN.match(session_id):
112-
raise ValueError(f"invalid session_id: {session_id}")
117+
raise ValueError(f"Invalid session_id: {session_id}")
113118

114-
session = Session(session_id, self.sessions_path)
119+
if session_id not in self.sessions:
120+
raise ValueError(f"Session not found: {session_id}")
121+
session = self.sessions[session_id]
115122
if session.remove():
116123
self.cloud_watch.post("close_session")
124+
self.sessions.pop(session_id)
117125

118126
def _clean_expired_session(self):
119127
sessions = os.listdir(self.sessions_path)
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#!/usr/bin/env python
2+
#
3+
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
6+
# except in compliance with the License. A copy of the License is located at
7+
#
8+
# http://aws.amazon.com/apache2.0/
9+
#
10+
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
11+
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
12+
# the specific language governing permissions and limitations under the License.
13+
14+
# Request property that carries the session id.
HEADER_SAGEMAKER_SESSION_ID = "X-Amzn-SageMaker-Session-Id"


def get_session(session_manager, request):
    """Resolve the session referenced by a request, if any.

    Args:
        session_manager: Object exposing ``get_session(session_id)``.
        request: Object exposing ``get_property(name)``.

    Returns:
        ``None`` when the request has no session-id property; otherwise
        whatever ``session_manager.get_session`` returns for that id.
    """
    requested_id = request.get_property(HEADER_SAGEMAKER_SESSION_ID)
    return (None if requested_id is None else
            session_manager.get_session(requested_id))

engines/python/src/main/java/ai/djl/python/engine/PyProcess.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,8 @@ Output predict(Input inputs, int timeout, boolean initialLoad) throws TranslateE
158158
// In RollingBatch, we queue adapter loading jobs to occur after the initial load.
159159
// Executing those in RollingBatch context doesn't work, so we need to handle them in the
160160
// 'standard' way.
161-
if (initialLoad || inputs.getProperty("handler", null) != null) {
161+
if (initialLoad
162+
|| (inputs.getProperty("handler", null) != null && asyncRequestManager == null)) {
162163
return predictStandard(inputs, timeout, initialLoad);
163164
}
164165
if (rollingBatch != null) {

0 commit comments

Comments
 (0)