12 | 12 | # language governing permissions and limitations under the License. |
13 | 13 | """Placeholder docstring""" |
14 | 14 | from __future__ import absolute_import |
15 | | - |
| 15 | +import threading |
| 16 | +import time |
16 | 17 | import uuid |
17 | 18 | from botocore.exceptions import WaiterError |
18 | | -from sagemaker.exceptions import PollingTimeoutError |
| 19 | +from sagemaker.exceptions import PollingTimeoutError, AsyncInferenceModelError |
19 | 20 | from sagemaker.async_inference import WaiterConfig, AsyncInferenceResponse |
20 | 21 | from sagemaker.s3 import parse_s3_url |
21 | 22 | from sagemaker.session import Session |
@@ -98,7 +99,10 @@ def predict( |
98 | 99 | self._input_path = input_path |
99 | 100 | response = self._submit_async_request(input_path, initial_args, inference_id) |
100 | 101 | output_location = response["OutputLocation"] |
101 | | - result = self._wait_for_output(output_path=output_location, waiter_config=waiter_config) |
| 102 | + failure_location = response["FailureLocation"] |
| 103 | + result = self._wait_for_output( |
| 104 | + output_path=output_location, failure_path=failure_location, waiter_config=waiter_config |
| 105 | + ) |
102 | 106 |
|
103 | 107 | return result |
104 | 108 |
|
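For context on how this hunk behaves from the caller's side, here is a minimal, hypothetical sketch of the blocking path: with this change, `predict` surfaces a model failure as an exception instead of only timing out. The endpoint name, S3 input path, and waiter settings below are illustrative, not part of this diff.

```python
from sagemaker.predictor import Predictor
from sagemaker.predictor_async import AsyncPredictor
from sagemaker.async_inference import WaiterConfig
from sagemaker.exceptions import AsyncInferenceModelError, PollingTimeoutError

# Hypothetical async endpoint and an input payload already uploaded to S3.
predictor = AsyncPredictor(Predictor(endpoint_name="my-async-endpoint"))
waiter_config = WaiterConfig(delay=10, max_attempts=30)  # polls for up to 300 seconds

try:
    result = predictor.predict(
        input_path="s3://my-bucket/async-input/payload.csv",
        waiter_config=waiter_config,
    )
except AsyncInferenceModelError as err:
    # The failure file appeared first; err carries the model's failure response.
    print(f"Inference failed: {err}")
except PollingTimeoutError as err:
    # Neither file appeared within delay * max_attempts seconds.
    print(f"Inference may still be running: {err}")
```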
@@ -141,9 +145,11 @@ def predict_async( |
141 | 145 | self._input_path = input_path |
142 | 146 | response = self._submit_async_request(input_path, initial_args, inference_id) |
143 | 147 | output_location = response["OutputLocation"] |
| 148 | + failure_location = response["FailureLocation"] |
144 | 149 | response_async = AsyncInferenceResponse( |
145 | 150 | predictor_async=self, |
146 | 151 | output_path=output_location, |
| 152 | + failure_path=failure_location, |
147 | 153 | ) |
148 | 154 |
|
149 | 155 | return response_async |
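Likewise, a hypothetical sketch of the non-blocking path: `predict_async` now passes the failure location into the `AsyncInferenceResponse`, so a later `get_result` call can distinguish a model failure from a job that is still running (this assumes the companion change to `AsyncInferenceResponse` that accepts `failure_path`).

```python
from sagemaker.async_inference import WaiterConfig

# Continuing the hypothetical predictor from the sketch above.
async_response = predictor.predict_async(
    input_path="s3://my-bucket/async-input/payload.csv"
)

# Do other work, then poll; with this change, get_result raises
# AsyncInferenceModelError if the failure file is found before the output file.
result = async_response.get_result(WaiterConfig(delay=5, max_attempts=24))
```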
@@ -209,30 +215,81 @@ def _submit_async_request( |
209 | 215 |
|
210 | 216 | return response |
211 | 217 |
|
212 | | - def _wait_for_output( |
213 | | - self, |
214 | | - output_path, |
215 | | - waiter_config, |
216 | | - ): |
| 218 | + def _wait_for_output(self, output_path, failure_path, waiter_config): |
217 | 219 | """Check the Amazon S3 output path for the output. |
218 | 220 |
|
219 | | - Periodically check Amazon S3 output path for async inference result. |
220 | | - Timeout automatically after max attempts reached |
221 | | - """ |
222 | | - bucket, key = parse_s3_url(output_path) |
223 | | - s3_waiter = self.s3_client.get_waiter("object_exists") |
224 | | - try: |
225 | | - s3_waiter.wait(Bucket=bucket, Key=key, WaiterConfig=waiter_config._to_request_dict()) |
226 | | - except WaiterError: |
227 | | - raise PollingTimeoutError( |
228 | | - message="Inference could still be running", |
229 | | - output_path=output_path, |
230 | | - seconds=waiter_config.delay * waiter_config.max_attempts, |
231 | | - ) |
| 221 | + This method waits for either the output file or the failure file to appear under the
| 222 | + given S3 paths. Whichever file appears first sets its corresponding event, and the
| 223 | + method then returns the deserialized result or raises the matching exception.
232 | 224 |
|
233 | | - s3_object = self.s3_client.get_object(Bucket=bucket, Key=key) |
234 | | - result = self.predictor._handle_response(response=s3_object) |
235 | | - return result |
| 225 | + Args: |
| 226 | + output_path (str): The S3 path where the output file is expected to be found. |
| 227 | + failure_path (str): The S3 path where the failure file is expected to be found. |
| 228 | + waiter_config (sagemaker.async_inference.WaiterConfig): The configuration for the S3 waiters.
| 229 | +
|
| 230 | + Returns: |
| 231 | + Any: The deserialized result from the output file, if the output file is found first. |
| 232 | + Otherwise, raises an exception. |
| 233 | +
|
| 234 | + Raises: |
| 235 | + AsyncInferenceModelError: If the failure file is found before the output file. |
| 236 | + PollingTimeoutError: If neither file is found before the S3 waiters
| 237 | + time out and raise a WaiterError.
| 238 | + """ |
| 239 | + output_bucket, output_key = parse_s3_url(output_path) |
| 240 | + failure_bucket, failure_key = parse_s3_url(failure_path) |
| 241 | + |
| 242 | + output_file_found = threading.Event() |
| 243 | + failure_file_found = threading.Event() |
| 244 | + |
| 245 | + def check_output_file(): |
| 246 | + try: |
| 247 | + output_file_waiter = self.s3_client.get_waiter("object_exists") |
| 248 | + output_file_waiter.wait( |
| 249 | + Bucket=output_bucket, |
| 250 | + Key=output_key, |
| 251 | + WaiterConfig=waiter_config._to_request_dict(), |
| 252 | + ) |
| 253 | + output_file_found.set() |
| 254 | + except WaiterError: |
| 255 | + pass |
| 256 | + |
| 257 | + def check_failure_file(): |
| 258 | + try: |
| 259 | + failure_file_waiter = self.s3_client.get_waiter("object_exists") |
| 260 | + failure_file_waiter.wait( |
| 261 | + Bucket=failure_bucket, |
| 262 | + Key=failure_key, |
| 263 | + WaiterConfig=waiter_config._to_request_dict(), |
| 264 | + ) |
| 265 | + failure_file_found.set() |
| 266 | + except WaiterError: |
| 267 | + pass |
| 268 | + |
| 269 | + output_thread = threading.Thread(target=check_output_file) |
| 270 | + failure_thread = threading.Thread(target=check_failure_file) |
| 271 | + |
| 272 | + output_thread.start() |
| 273 | + failure_thread.start() |
| 274 | + |
| 275 | + while not output_file_found.is_set() and not failure_file_found.is_set():
| 276 | + if not output_thread.is_alive() and not failure_thread.is_alive():
| 277 | + break
| 278 | + time.sleep(1)
| 279 | +
| 280 | + if output_file_found.is_set():
| 281 | + s3_object = self.s3_client.get_object(Bucket=output_bucket, Key=output_key)
| 282 | + return self.predictor._handle_response(response=s3_object)
| 283 | +
| 284 | + if failure_file_found.is_set():
| 285 | + failure_object = self.s3_client.get_object(Bucket=failure_bucket, Key=failure_key)
| 286 | + failure_response = self.predictor._handle_response(response=failure_object)
| 287 | + raise AsyncInferenceModelError(message=failure_response)
| 288 | + raise PollingTimeoutError(
| 289 | + message="Inference could still be running",
| 290 | + output_path=output_path,
| 291 | + seconds=waiter_config.delay * waiter_config.max_attempts,
| 292 | + )
236 | 293 |
|
237 | 294 | def update_endpoint( |
238 | 295 | self, |
|
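As a standalone illustration of the two-waiter pattern `_wait_for_output` relies on, the sketch below races two S3 `object_exists` waiters on background threads and acts on whichever file appears. Bucket and key names are hypothetical, and unlike the SDK code it simply joins both threads instead of polling the events in a loop.

```python
import threading

import boto3
from botocore.exceptions import WaiterError

s3 = boto3.client("s3")
found = {"output": threading.Event(), "failure": threading.Event()}

def wait_for(name, bucket, key):
    """Poll S3 until the object exists, then record which file appeared."""
    try:
        s3.get_waiter("object_exists").wait(
            Bucket=bucket, Key=key, WaiterConfig={"Delay": 5, "MaxAttempts": 24}
        )
        found[name].set()
    except WaiterError:
        pass  # Timed out without seeing the object.

# Hypothetical locations an async endpoint would write to.
threads = [
    threading.Thread(target=wait_for, args=("output", "my-bucket", "async-output/response.out")),
    threading.Thread(target=wait_for, args=("failure", "my-bucket", "async-failures/error.out")),
]
for thread in threads:
    thread.start()
for thread in threads:
    thread.join()

if found["output"].is_set():
    body = s3.get_object(Bucket="my-bucket", Key="async-output/response.out")["Body"].read()
elif found["failure"].is_set():
    raise RuntimeError("inference failed; inspect the failure file for details")
else:
    raise TimeoutError("neither output nor failure appeared before the waiters gave up")
```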