@@ -271,6 +271,7 @@ def _check_output_and_failure_paths(self, output_path, failure_path, waiter_conf
271271
272272 output_file_found = threading .Event ()
273273 failure_file_found = threading .Event ()
274+ waiter_error_catched = threading .Event ()
274275
275276 def check_output_file ():
276277 try :
@@ -282,6 +283,7 @@ def check_output_file():
282283 )
283284 output_file_found .set ()
284285 except WaiterError :
286+ waiter_error_catched .set ()
285287 pass
286288
287289 def check_failure_file ():
@@ -294,6 +296,7 @@ def check_failure_file():
294296 )
295297 failure_file_found .set ()
296298 except WaiterError :
299+ waiter_error_catched .set ()
297300 pass
298301
299302 output_thread = threading .Thread (target = check_output_file )
@@ -302,26 +305,24 @@ def check_failure_file():
302305 output_thread .start ()
303306 failure_thread .start ()
304307
305- while not output_file_found .is_set () and not failure_file_found .is_set ():
308+ while not output_file_found .is_set () and not failure_file_found .is_set () and not waiter_error_catched . is_set () :
306309 time .sleep (1 )
307310
308311 if output_file_found .is_set ():
309312 s3_object = self .s3_client .get_object (Bucket = output_bucket , Key = output_key )
310313 result = self .predictor ._handle_response (response = s3_object )
311314 return result
312315
313- failure_object = self .s3_client .get_object (Bucket = failure_bucket , Key = failure_key )
314- failure_response = self .predictor ._handle_response (response = failure_object )
316+ if failure_file_found .is_set ():
317+ failure_object = self .s3_client .get_object (Bucket = failure_bucket , Key = failure_key )
318+ failure_response = self .predictor ._handle_response (response = failure_object )
319+ raise AsyncInferenceModelError (message = failure_response )
315320
316- raise (
317- AsyncInferenceModelError (message = failure_response )
318- if failure_file_found .is_set ()
319- else PollingTimeoutError (
321+ raise PollingTimeoutError (
320322 message = "Inference could still be running" ,
321323 output_path = output_path ,
322324 seconds = waiter_config .delay * waiter_config .max_attempts ,
323325 )
324- )
325326
326327 def update_endpoint (
327328 self ,
0 commit comments