Skip to content

Commit c0632c0

Browse files
committed
[0.34.0-dlc] Fix integration test
1 parent 18ac531 commit c0632c0

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

tests/integration/llm/client.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,6 @@
1616

1717
from concurrent.futures import ThreadPoolExecutor, as_completed
1818
from json.decoder import JSONDecodeError
19-
from djl_python.session_utils import (HEADER_SAGEMAKER_SESSION_ID,
20-
HEADER_SAGEMAKER_CLOSED_SESSION_ID)
2119

2220
FAILED_DEPENDENCY_CODE = 424
2321
TIMEOUT = 3.0
@@ -1214,16 +1212,16 @@ def create_session():
12141212
res = send_json(req)
12151213
if res.status_code >= 300:
12161214
return None
1217-
session_id = res.headers.get(HEADER_SAGEMAKER_SESSION_ID).split(';')[0]
1215+
session_id = res.headers.get("X-Amzn-SageMaker-Session-Id").split(';')[0]
12181216
return session_id
12191217

12201218

12211219
def close_session(session_id):
12221220
req = {"requestType": "CLOSE"}
1223-
res = send_json(req, headers={HEADER_SAGEMAKER_SESSION_ID: session_id})
1221+
res = send_json(req, headers={"X-Amzn-SageMaker-Session-Id": session_id})
12241222
if res.status_code >= 300:
12251223
return None
1226-
session_id = res.headers.get(HEADER_SAGEMAKER_CLOSED_SESSION_ID)
1224+
session_id = res.headers.get("X-Amzn-SageMaker-Closed-Session-Id")
12271225
return session_id
12281226

12291227

@@ -2148,7 +2146,7 @@ def test_handler_stateful(model, model_spec):
21482146
# Create session
21492147
session_id = create_session()
21502148
if session_id is None:
2151-
raise RuntimeError("Create session failed!")
2149+
raise RuntimeError(f"Create session failed: {session_id}")
21522150
spec = model_spec[args.model]
21532151
if "worker" in spec:
21542152
check_worker_number(spec["worker"])
@@ -2166,7 +2164,8 @@ def test_handler_stateful(model, model_spec):
21662164
for stream in stream_values:
21672165
req["stream"] = stream
21682166
LOGGER.info(f"req {req}")
2169-
res = send_json(req, headers={HEADER_SAGEMAKER_SESSION_ID: session_id})
2167+
res = send_json(req,
2168+
headers={"X-Amzn-SageMaker-Session-Id": session_id})
21702169
message = res.content.decode("utf-8")
21712170
LOGGER.info(f"res: {message}")
21722171
response_checker(res, message)
@@ -2184,12 +2183,12 @@ def test_handler_stateful(model, model_spec):
21842183
req,
21852184
spec.get("tokenizer", None),
21862185
batch_size,
2187-
headers=[f"'{HEADER_SAGEMAKER_SESSION_ID}: {session_id}'"])
2186+
headers=[f"'X-Amzn-SageMaker-Session-Id: {session_id}'"])
21882187

21892188
# Close session
21902189
closed_session_id = close_session(session_id)
2191-
if closed_session_id is None:
2192-
raise RuntimeError("Close session failed!")
2190+
if closed_session_id != session_id:
2191+
raise RuntimeError(f"Close session failed: {session_id}")
21932192

21942193

21952194
def run(raw_args):

0 commit comments

Comments
 (0)