Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions tests/integration/llm/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@

from concurrent.futures import ThreadPoolExecutor, as_completed
from json.decoder import JSONDecodeError
from djl_python.session_utils import (HEADER_SAGEMAKER_SESSION_ID,
HEADER_SAGEMAKER_CLOSED_SESSION_ID)

FAILED_DEPENDENCY_CODE = 424
TIMEOUT = 3.0
Expand Down Expand Up @@ -1214,16 +1212,16 @@ def create_session():
res = send_json(req)
if res.status_code >= 300:
return None
session_id = res.headers.get(HEADER_SAGEMAKER_SESSION_ID).split(';')[0]
session_id = res.headers.get("X-Amzn-SageMaker-Session-Id").split(';')[0]
return session_id


def close_session(session_id):
req = {"requestType": "CLOSE"}
res = send_json(req, headers={HEADER_SAGEMAKER_SESSION_ID: session_id})
res = send_json(req, headers={"X-Amzn-SageMaker-Session-Id": session_id})
if res.status_code >= 300:
return None
session_id = res.headers.get(HEADER_SAGEMAKER_CLOSED_SESSION_ID)
session_id = res.headers.get("X-Amzn-SageMaker-Closed-Session-Id")
return session_id


Expand Down Expand Up @@ -2148,7 +2146,7 @@ def test_handler_stateful(model, model_spec):
# Create session
session_id = create_session()
if session_id is None:
raise RuntimeError("Create session failed!")
raise RuntimeError(f"Create session failed: {session_id}")
spec = model_spec[args.model]
if "worker" in spec:
check_worker_number(spec["worker"])
Expand All @@ -2166,7 +2164,8 @@ def test_handler_stateful(model, model_spec):
for stream in stream_values:
req["stream"] = stream
LOGGER.info(f"req {req}")
res = send_json(req, headers={HEADER_SAGEMAKER_SESSION_ID: session_id})
res = send_json(req,
headers={"X-Amzn-SageMaker-Session-Id": session_id})
message = res.content.decode("utf-8")
LOGGER.info(f"res: {message}")
response_checker(res, message)
Expand All @@ -2184,12 +2183,12 @@ def test_handler_stateful(model, model_spec):
req,
spec.get("tokenizer", None),
batch_size,
headers=[f"'{HEADER_SAGEMAKER_SESSION_ID}: {session_id}'"])
headers=[f"'X-Amzn-SageMaker-Session-Id: {session_id}'"])

# Close session
closed_session_id = close_session(session_id)
if closed_session_id is None:
raise RuntimeError("Close session failed!")
if closed_session_id != session_id:
raise RuntimeError(f"Close session failed: {session_id}")


def run(raw_args):
Expand Down
Loading