@@ -271,6 +271,7 @@ def _check_output_and_failure_paths(self, output_path, failure_path, waiter_conf
271271
272272 output_file_found = threading .Event ()
273273 failure_file_found = threading .Event ()
274+ waiter_error_catched = threading .Event ()
274275
275276 def check_output_file ():
276277 try :
@@ -282,7 +283,7 @@ def check_output_file():
282283 )
283284 output_file_found .set ()
284285 except WaiterError :
285- pass
286+ waiter_error_catched . set ()
286287
287288 def check_failure_file ():
288289 try :
@@ -294,33 +295,35 @@ def check_failure_file():
294295 )
295296 failure_file_found .set ()
296297 except WaiterError :
297- pass
298+ waiter_error_catched . set ()
298299
299300 output_thread = threading .Thread (target = check_output_file )
300301 failure_thread = threading .Thread (target = check_failure_file )
301302
302303 output_thread .start ()
303304 failure_thread .start ()
304305
305- while not output_file_found .is_set () and not failure_file_found .is_set ():
306+ while (
307+ not output_file_found .is_set ()
308+ and not failure_file_found .is_set ()
309+ and not waiter_error_catched .is_set ()
310+ ):
306311 time .sleep (1 )
307312
308313 if output_file_found .is_set ():
309314 s3_object = self .s3_client .get_object (Bucket = output_bucket , Key = output_key )
310315 result = self .predictor ._handle_response (response = s3_object )
311316 return result
312317
313- failure_object = self .s3_client .get_object (Bucket = failure_bucket , Key = failure_key )
314- failure_response = self .predictor ._handle_response (response = failure_object )
318+ if failure_file_found .is_set ():
319+ failure_object = self .s3_client .get_object (Bucket = failure_bucket , Key = failure_key )
320+ failure_response = self .predictor ._handle_response (response = failure_object )
321+ raise AsyncInferenceModelError (message = failure_response )
315322
316- raise (
317- AsyncInferenceModelError (message = failure_response )
318- if failure_file_found .is_set ()
319- else PollingTimeoutError (
320- message = "Inference could still be running" ,
321- output_path = output_path ,
322- seconds = waiter_config .delay * waiter_config .max_attempts ,
323- )
323+ raise PollingTimeoutError (
324+ message = "Inference could still be running" ,
325+ output_path = output_path ,
326+ seconds = waiter_config .delay * waiter_config .max_attempts ,
324327 )
325328
326329 def update_endpoint (
0 commit comments