Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@ jobs:
- test: TestCorrectnessTrtLlm
instance: g6
failure-prefix: trtllm

- test: TestStatefulModel
instance: g6
failure-prefix: lmi
outputs:
failure_cpu: ${{ steps.test-failure.outputs.failure_cpu }}
failure_gpu: ${{ steps.test-failure.outputs.failure_gpu }}
Expand Down
6 changes: 6 additions & 0 deletions engines/python/setup/djl_python/async_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@


def create_non_stream_output(data: Union[str, dict],
properties: Optional[dict] = None,
error: Optional[str] = None,
code: Optional[int] = None) -> Output:
return _create_output(
data,
True,
"application/json",
properties=properties,
error=error,
code=code,
)
Expand All @@ -46,6 +48,7 @@ def _create_output(
data: Union[str, dict],
last_chunk: bool,
content_type: str,
properties: Optional[dict] = None,
error: Optional[str] = None,
code: Optional[int] = None,
) -> Output:
Expand All @@ -65,6 +68,9 @@ def _create_output(
response_dict["code"] = code
output = Output()
output.add_property("Content-Type", content_type)
if properties:
for k, v in properties.items():
output.add_property(k, v)
output.add(Output.binary_encode(response_dict))
return output

Expand Down
28 changes: 25 additions & 3 deletions engines/python/setup/djl_python/lmi_vllm/vllm_async_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,8 @@
from djl_python.outputs import Output
from djl_python.encode_decode import decode
from djl_python.async_utils import handle_streaming_response, create_non_stream_output
from djl_python.custom_formatter_handling import CustomFormatterHandler, CustomFormatterError

from .request_response_utils import (
from djl_python.custom_formatter_handling import CustomFormatterError, CustomFormatterHandler
from djl_python.lmi_vllm.request_response_utils import (
ProcessedRequest,
vllm_stream_output_formatter,
vllm_non_stream_output_formatter,
Expand All @@ -43,9 +42,15 @@
lmi_with_details_non_stream_output_formatter,
lmi_non_stream_output_formatter,
)
from djl_python.session_manager import SessionManager
from djl_python.session_utils import (create_session, close_session,
get_session,
session_non_stream_output_formatter)

logger = logging.getLogger(__name__)

SESSION_REQUESTS = {"NEW_SESSION": create_session, "CLOSE": close_session}


class VLLMHandler(CustomFormatterHandler):

Expand Down Expand Up @@ -119,12 +124,15 @@ async def initialize(self, properties: dict):
tool_parser=self.vllm_properties.tool_call_parser,
reasoning_parser=self.vllm_properties.reasoning_parser,
)
if properties.get("enable_stateful_sessions", "true") == "true":
self.session_manager: SessionManager = SessionManager(properties)
self.initialized = True

def preprocess_request(self, inputs: Input) -> ProcessedRequest:
batch = inputs.get_batches()
assert len(batch) == 1, "only one request per batch allowed"
raw_request = batch[0]
session = get_session(self.session_manager, raw_request)
content_type = raw_request.get_property("Content-Type")
decoded_payload = decode(raw_request, content_type)

Expand Down Expand Up @@ -160,6 +168,20 @@ def preprocess_request(self, inputs: Input) -> ProcessedRequest:
vllm_invoke_function = self.chat_completion_service.create_chat_completion
non_stream_output_formatter = vllm_non_stream_output_formatter
stream_output_formatter = vllm_stream_output_formatter
elif "requestType" in decoded_payload:
request_type = decoded_payload["requestType"]
if request_type not in SESSION_REQUESTS.keys():
raise RuntimeError(
f"invalid payload. request type must be one of {SESSION_REQUESTS.keys()}"
)
if self.session_manager is None:
raise RuntimeError(
f"invalid payload. stateful sessions not enabled, {request_type} not supported"
)
vllm_request = self.session_manager, inputs
vllm_invoke_function = SESSION_REQUESTS[request_type]
non_stream_output_formatter = session_non_stream_output_formatter
stream_output_formatter = vllm_stream_output_formatter
else:
raise RuntimeError(
"invalid payload. must contain prompt, inputs, or messages")
Expand Down
76 changes: 48 additions & 28 deletions engines/python/setup/djl_python/session_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,15 @@

class Session:

def __init__(self, session_id: str, session_root: str):
def __init__(self,
session_id: str,
session_root: str,
expiration_ts: float = None):
self.session_id = session_id
self.files_path = os.path.join(session_root, session_id)
self.expiration_ts = expiration_ts
if self.expiration_ts is None:
self.expiration_ts = self.get(".expiration_ts")

def put(self, key: str, value):
with open(self._path(key), "wb") as f:
Expand All @@ -53,13 +59,12 @@ def get(self, key: str, d=None):
return pickle.load(f)

def remove(self):
if os.path.exists(self.files_path):
logging.info(f"closing session: {self.session_id}")
shutil.rmtree(self.files_path)
return True
else:
logging.warning(f"session not found: {self.session_id}")
return False
if not os.path.exists(self.files_path):
raise ValueError(
f"session directory does not exist: {self.session_id}")
logging.info(f"closing session: {self.session_id}")
shutil.rmtree(self.files_path)
return True

def _path(self, key: str):
return os.path.join(self.files_path, key.replace("/", "-"))
Expand All @@ -78,9 +83,14 @@ def __init__(self, properties: dict):

self.sessions_path = properties.get("sessions_path", session_dir)
self.sessions_s3url = properties.get("sessions_s3url", None)
self.sessions: dict[str, Session] = {}
if not os.path.exists(self.sessions_path):
os.makedirs(self.sessions_path)

# Load previous saved sessions
for session_id in os.listdir(self.sessions_path):
self.sessions[session_id] = Session(session_id, self.sessions_path)

atexit.register(self._save_sessions_to_s3)

def create_session(self) -> Session:
Expand All @@ -90,55 +100,65 @@ def create_session(self) -> Session:
"""
self._clean_expired_session()
session_id = str(uuid.uuid4())
session = Session(session_id, self.sessions_path)
expiration_ts = time.time() + self.expiration
session = Session(session_id, self.sessions_path, expiration_ts)
self.sessions[session_id] = session
os.makedirs(session.files_path)
session.put(".creation_time", time.time())
session.put(".expiration_ts", expiration_ts)

self.cloud_watch.post("create_session")
return session

def get_session(self, session_id: str) -> Optional[Session]:
if not session_id or not UUID_PATTERN.match(session_id):
raise ValueError(f"invalid session_id: {session_id}")
if session_id == "NEW_SESSION" or not session_id:
return None

session = Session(session_id, self.sessions_path)
if not os.path.exists(session.files_path):
return self._recover_from_s3(session)
if session_id not in self.sessions:
raise ValueError(f"session not found: {session_id}")
session = self.sessions[session_id]

# Session expired
if session.expiration_ts is not None \
and time.time() > session.expiration_ts:
logging.info(f"Session expired: {session_id}")
return None

return session

def close_session(self, session_id):
if not session_id or not UUID_PATTERN.match(session_id):
if not session_id:
raise ValueError(f"invalid session_id: {session_id}")

session = Session(session_id, self.sessions_path)
if session_id not in self.sessions:
raise ValueError(f"session not found: {session_id}")
session = self.sessions[session_id]

if session.remove():
self.cloud_watch.post("close_session")
del self.sessions[session_id]

def _clean_expired_session(self):
sessions = os.listdir(self.sessions_path)
for session_id in sessions:
session = Session(session_id, self.sessions_path)
if time.time() - session.get(".creation_time") > self.expiration:
for session_id, session in list(self.sessions.items()):
if session.expiration_ts is None \
or time.time() > session.expiration_ts:
self.close_session(session_id)

def _recover_from_s3(self, session: Session) -> Optional[Session]:
def _recover_from_s3(self, session_id) -> Optional[Session]:
if not self.sessions_s3url:
return None

logging.info(f"Restoring session {session.session_id} from s3...")
os.makedirs(session.files_path)
logging.info(f"Restoring session {session_id} from s3...")
os.makedirs(self.sessions_path)
command = [
"/opt/djl/bin/s5cmd", "--retry-count", "1", "sync",
f"{self.sessions_s3url}/{session.session_id}/*",
f"{session.files_path}"
f"{self.sessions_s3url}/{session_id}/*", f"{self.sessions_path}"
]
result = sp.run(command)
if result.returncode == 0:
return session
return Session(session_id, self.sessions_path)

logging.warning(f"s5cmd download failed: {result.stderr}")
shutil.rmtree(session.files_path)
shutil.rmtree(self.sessions_path)
return None

def _save_sessions_to_s3(self):
Expand Down
85 changes: 85 additions & 0 deletions engines/python/setup/djl_python/session_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
#!/usr/bin/env python
#
# Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file
# except in compliance with the License. A copy of the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "LICENSE.txt" file accompanying this file. This file is distributed on an "AS IS"
# BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, express or implied. See the License for
# the specific language governing permissions and limitations under the License.
import datetime
import logging

from djl_python.async_utils import create_non_stream_output
from djl_python.outputs import Output

logger = logging.getLogger(__name__)

HEADER_SAGEMAKER_SESSION_ID = "X-Amzn-SageMaker-Session-Id"
HEADER_SAGEMAKER_CLOSED_SESSION_ID = "X-Amzn-SageMaker-Closed-Session-Id"


async def create_session(request):
session_manager, inputs = request
try:
session = session_manager.create_session()
expiration_ts = datetime.datetime.fromtimestamp(
session.expiration_ts).strftime("%Y-%m-%dT%H:%M:%SZ")
logger.info(f"Session {session.session_id} created")
return {
"data": {
"result": f"Session {session.session_id} created"
},
"properties": {
HEADER_SAGEMAKER_SESSION_ID:
f"{session.session_id}; Expires={expiration_ts}"
}
}
except Exception as e:
return {"error": f"Failed to create session: {str(e)}", "code": 424}


async def close_session(request):
session_manager, inputs = request
session_id = inputs.get_property(HEADER_SAGEMAKER_SESSION_ID)
try:
session_manager.close_session(session_id)
logger.info(f"Session {session_id} closed")
return {
"data": {
"result": f"Session {session_id} closed"
},
"properties": {
HEADER_SAGEMAKER_CLOSED_SESSION_ID: f"{session_id}"
}
}
except Exception as e:
return {"error": f"Failed to close session: {str(e)}", "code": 424}


def get_session(session_manager, request):
session_id = request.get_property(HEADER_SAGEMAKER_SESSION_ID)
if session_manager is None:
if session_id is not None:
raise RuntimeError(
f"invalid payload. stateful sessions not enabled, {HEADER_SAGEMAKER_SESSION_ID} header not supported"
)
return None
session = session_manager.get_session(session_id)
return session


def session_non_stream_output_formatter(
response: dict,
**_,
) -> Output:
if "error" in response:
return create_non_stream_output("",
error=response["error"],
code=response["code"])

return create_non_stream_output(response["data"],
properties=response.get("properties"))
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,23 @@ public void testLocalLoadSave()

// test session timeout
Thread.sleep(1000);
regular = new Input();
regular.addProperty("Content-Type", "application/json");
regular.addProperty("X-Amzn-SageMaker-Session-Id", sessionId);
regular.add(BytesSupplier.wrapAsJson(Map.of("action", "regular")));
ret = predictor.predict(regular);
Assert.assertEquals(ret.getProperty("Content-Type", null), "application/json");
Assert.assertTrue(ret.getAsString(0).contains("session not found"));
long count;
try (Stream<Path> files = Files.list(path)) {
count = files.count();
}
Assert.assertEquals(count, 1);

// create a new session
ret = predictor.predict(createSession);
sessionId = ret.getProperty("X-Amzn-SageMaker-Session-Id", null);
Assert.assertNotNull(sessionId);
long count;
try (Stream<Path> files = Files.list(path)) {
count = files.count();
}
Expand Down
Loading
Loading