diff --git a/tests/integration/llm/client.py b/tests/integration/llm/client.py index 648fe930e..43ca75496 100644 --- a/tests/integration/llm/client.py +++ b/tests/integration/llm/client.py @@ -16,8 +16,6 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from json.decoder import JSONDecodeError -from djl_python.session_utils import (HEADER_SAGEMAKER_SESSION_ID, - HEADER_SAGEMAKER_CLOSED_SESSION_ID) FAILED_DEPENDENCY_CODE = 424 TIMEOUT = 3.0 @@ -1214,16 +1212,16 @@ def create_session(): res = send_json(req) if res.status_code >= 300: return None - session_id = res.headers.get(HEADER_SAGEMAKER_SESSION_ID).split(';')[0] + session_id = res.headers.get("X-Amzn-SageMaker-Session-Id").split(';')[0] return session_id def close_session(session_id): req = {"requestType": "CLOSE"} - res = send_json(req, headers={HEADER_SAGEMAKER_SESSION_ID: session_id}) + res = send_json(req, headers={"X-Amzn-SageMaker-Session-Id": session_id}) if res.status_code >= 300: return None - session_id = res.headers.get(HEADER_SAGEMAKER_CLOSED_SESSION_ID) + session_id = res.headers.get("X-Amzn-SageMaker-Closed-Session-Id") return session_id @@ -2148,7 +2146,7 @@ def test_handler_stateful(model, model_spec): # Create session session_id = create_session() if session_id is None: - raise RuntimeError("Create session failed!") + raise RuntimeError("Create session failed: no session id returned") spec = model_spec[args.model] if "worker" in spec: check_worker_number(spec["worker"]) @@ -2166,7 +2164,8 @@ def test_handler_stateful(model, model_spec): for stream in stream_values: req["stream"] = stream LOGGER.info(f"req {req}") - res = send_json(req, headers={HEADER_SAGEMAKER_SESSION_ID: session_id}) + res = send_json(req, + headers={"X-Amzn-SageMaker-Session-Id": session_id}) message = res.content.decode("utf-8") LOGGER.info(f"res: {message}") response_checker(res, message) @@ -2184,12 +2183,12 @@ def test_handler_stateful(model, model_spec): req, spec.get("tokenizer", None), batch_size, - 
headers=[f"'{HEADER_SAGEMAKER_SESSION_ID}: {session_id}'"]) + headers=[f"'X-Amzn-SageMaker-Session-Id: {session_id}'"]) # Close session closed_session_id = close_session(session_id) - if closed_session_id is None: - raise RuntimeError("Close session failed!") + if closed_session_id != session_id: + raise RuntimeError(f"Close session failed: {session_id}") def run(raw_args):