Skip to content

Commit c4f0bac

Browse files
author
Ubuntu
committed
addWaiterTimeoutHandling
1 parent 9e7f666 commit c4f0bac

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

src/sagemaker/predictor_async.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ def _check_output_and_failure_paths(self, output_path, failure_path, waiter_conf
271271

272272
output_file_found = threading.Event()
273273
failure_file_found = threading.Event()
274+
waiter_error_catched = threading.Event()
274275

275276
def check_output_file():
276277
try:
@@ -282,6 +283,7 @@ def check_output_file():
282283
)
283284
output_file_found.set()
284285
except WaiterError:
286+
waiter_error_catched.set()
285287
pass
286288

287289
def check_failure_file():
@@ -294,6 +296,7 @@ def check_failure_file():
294296
)
295297
failure_file_found.set()
296298
except WaiterError:
299+
waiter_error_catched.set()
297300
pass
298301

299302
output_thread = threading.Thread(target=check_output_file)
@@ -302,26 +305,24 @@ def check_failure_file():
302305
output_thread.start()
303306
failure_thread.start()
304307

305-
while not output_file_found.is_set() and not failure_file_found.is_set():
308+
while not output_file_found.is_set() and not failure_file_found.is_set() and not waiter_error_catched.is_set():
306309
time.sleep(1)
307310

308311
if output_file_found.is_set():
309312
s3_object = self.s3_client.get_object(Bucket=output_bucket, Key=output_key)
310313
result = self.predictor._handle_response(response=s3_object)
311314
return result
312315

313-
failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key)
314-
failure_response = self.predictor._handle_response(response=failure_object)
316+
if failure_file_found.is_set():
317+
failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key)
318+
failure_response = self.predictor._handle_response(response=failure_object)
319+
raise AsyncInferenceModelError(message=failure_response)
315320

316-
raise (
317-
AsyncInferenceModelError(message=failure_response)
318-
if failure_file_found.is_set()
319-
else PollingTimeoutError(
321+
raise PollingTimeoutError(
320322
message="Inference could still be running",
321323
output_path=output_path,
322324
seconds=waiter_config.delay * waiter_config.max_attempts,
323325
)
324-
)
325326

326327
def update_endpoint(
327328
self,

0 commit comments

Comments
 (0)