@@ -271,6 +271,7 @@ def _check_output_and_failure_paths(self, output_path, failure_path, waiter_conf
271
271
272
272
output_file_found = threading .Event ()
273
273
failure_file_found = threading .Event ()
274
+ waiter_error_catched = threading .Event ()
274
275
275
276
def check_output_file ():
276
277
try :
@@ -282,6 +283,7 @@ def check_output_file():
282
283
)
283
284
output_file_found .set ()
284
285
except WaiterError :
286
+ waiter_error_catched .set ()
285
287
pass
286
288
287
289
def check_failure_file ():
@@ -294,6 +296,7 @@ def check_failure_file():
294
296
)
295
297
failure_file_found .set ()
296
298
except WaiterError :
299
+ waiter_error_catched .set ()
297
300
pass
298
301
299
302
output_thread = threading .Thread (target = check_output_file )
@@ -302,26 +305,24 @@ def check_failure_file():
302
305
output_thread .start ()
303
306
failure_thread .start ()
304
307
305
- while not output_file_found .is_set () and not failure_file_found .is_set ():
308
+ while not output_file_found .is_set () and not failure_file_found .is_set () and not waiter_error_catched . is_set () :
306
309
time .sleep (1 )
307
310
308
311
if output_file_found .is_set ():
309
312
s3_object = self .s3_client .get_object (Bucket = output_bucket , Key = output_key )
310
313
result = self .predictor ._handle_response (response = s3_object )
311
314
return result
312
315
313
- failure_object = self .s3_client .get_object (Bucket = failure_bucket , Key = failure_key )
314
- failure_response = self .predictor ._handle_response (response = failure_object )
316
+ if failure_file_found .is_set ():
317
+ failure_object = self .s3_client .get_object (Bucket = failure_bucket , Key = failure_key )
318
+ failure_response = self .predictor ._handle_response (response = failure_object )
319
+ raise AsyncInferenceModelError (message = failure_response )
315
320
316
- raise (
317
- AsyncInferenceModelError (message = failure_response )
318
- if failure_file_found .is_set ()
319
- else PollingTimeoutError (
321
+ raise PollingTimeoutError (
320
322
message = "Inference could still be running" ,
321
323
output_path = output_path ,
322
324
seconds = waiter_config .delay * waiter_config .max_attempts ,
323
325
)
324
- )
325
326
326
327
def update_endpoint (
327
328
self ,
0 commit comments