1616
1717from concurrent .futures import ThreadPoolExecutor , as_completed
1818from json .decoder import JSONDecodeError
19- from djl_python .session_utils import (HEADER_SAGEMAKER_SESSION_ID ,
20- HEADER_SAGEMAKER_CLOSED_SESSION_ID )
2119
2220FAILED_DEPENDENCY_CODE = 424
2321TIMEOUT = 3.0
@@ -1214,16 +1212,16 @@ def create_session():
12141212 res = send_json (req )
12151213 if res .status_code >= 300 :
12161214 return None
1217- session_id = res .headers .get (HEADER_SAGEMAKER_SESSION_ID ).split (';' )[0 ]
1215+ session_id = res .headers .get ("X-Amzn-SageMaker-Session-Id" ).split (';' )[0 ]
12181216 return session_id
12191217
12201218
12211219def close_session (session_id ):
12221220 req = {"requestType" : "CLOSE" }
1223- res = send_json (req , headers = {HEADER_SAGEMAKER_SESSION_ID : session_id })
1221+ res = send_json (req , headers = {"X-Amzn-SageMaker-Session-Id" : session_id })
12241222 if res .status_code >= 300 :
12251223 return None
1226- session_id = res .headers .get (HEADER_SAGEMAKER_CLOSED_SESSION_ID )
1224+ session_id = res .headers .get ("X-Amzn-SageMaker-Closed-Session-Id" )
12271225 return session_id
12281226
12291227
@@ -2166,7 +2164,8 @@ def test_handler_stateful(model, model_spec):
21662164 for stream in stream_values :
21672165 req ["stream" ] = stream
21682166 LOGGER .info (f"req { req } " )
2169- res = send_json (req , headers = {HEADER_SAGEMAKER_SESSION_ID : session_id })
2167+ res = send_json (req ,
2168+ headers = {"X-Amzn-SageMaker-Session-Id" : session_id })
21702169 message = res .content .decode ("utf-8" )
21712170 LOGGER .info (f"res: { message } " )
21722171 response_checker (res , message )
@@ -2184,11 +2183,11 @@ def test_handler_stateful(model, model_spec):
21842183 req ,
21852184 spec .get ("tokenizer" , None ),
21862185 batch_size ,
2187- headers = [f"'{ HEADER_SAGEMAKER_SESSION_ID } : { session_id } '" ])
2186+ headers = [f"'X-Amzn-SageMaker-Session-Id : { session_id } '" ])
21882187
21892188 # Close session
21902189 closed_session_id = close_session (session_id )
2191- if closed_session_id is None :
2190+ if closed_session_id != session_id :
21922191 raise RuntimeError ("Close session failed!" )
21932192
21942193
0 commit comments