|
18 | 18 | from __future__ import annotations |
19 | 19 |
|
20 | 20 | import re |
21 | | -from collections.abc import Sequence |
| 21 | +from collections.abc import Container, Sequence |
22 | 22 | from datetime import timedelta |
23 | 23 | from functools import cached_property |
24 | 24 | from time import sleep |
|
39 | 39 | from airflow.providers.amazon.aws.utils.identifiers import generate_uuid |
40 | 40 | from airflow.providers.amazon.aws.utils.mixins import aws_template_fields |
41 | 41 | from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher |
42 | | -from airflow.providers.common.compat.sdk import AirflowException, conf |
| 42 | +from airflow.providers.common.compat.sdk import AirflowException, AirflowSkipException, conf |
43 | 43 | from airflow.utils.helpers import prune_dict |
44 | 44 |
|
45 | 45 | if TYPE_CHECKING: |
@@ -394,6 +394,9 @@ class EcsRunTaskOperator(EcsBaseOperator): |
394 | 394 | :param deferrable: If True, the operator will wait asynchronously for the job to complete. |
395 | 395 | This implies waiting for completion. This mode requires aiobotocore module to be installed. |
396 | 396 | (default: False) |
| 397 | + :param skip_on_exit_code: If the task exits with this exit code, leave the task
| 398 | + in ``skipped`` state (default: None). If set to ``None``, any non-zero
| 399 | + exit code will be treated as a failure. Can be an int or a container of ints.
397 | 400 | :param do_xcom_push: If True, the operator will push the ECS task ARN to XCom with key 'ecs_task_arn'. |
398 | 401 | Additionally, if logs are fetched, the last log message will be pushed to XCom with the key 'return_value'. (default: False) |
399 | 402 | :param stop_task_on_failure: If True, attempt to stop the ECS task if the Airflow task fails |
@@ -461,6 +464,7 @@ def __init__( |
461 | 464 | # Set the default waiter duration to 70 days (attempts*delay) |
462 | 465 | # Airflow execution_timeout handles task timeout |
463 | 466 | deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), |
| 467 | + skip_on_exit_code: int | Container[int] | None = None, |
464 | 468 | stop_task_on_failure: bool = True, |
465 | 469 | **kwargs, |
466 | 470 | ): |
@@ -500,6 +504,13 @@ def __init__( |
500 | 504 | self.waiter_delay = waiter_delay |
501 | 505 | self.waiter_max_attempts = waiter_max_attempts |
502 | 506 | self.deferrable = deferrable |
| 507 | + self.skip_on_exit_code = ( |
| 508 | + skip_on_exit_code |
| 509 | + if isinstance(skip_on_exit_code, Container) |
| 510 | + else [skip_on_exit_code] |
| 511 | + if skip_on_exit_code is not None |
| 512 | + else [] |
| 513 | + ) |
503 | 514 | self.stop_task_on_failure = stop_task_on_failure |
504 | 515 |
|
505 | 516 | if self._aws_logs_enabled() and not self.wait_for_completion: |
@@ -763,15 +774,21 @@ def _check_success_task(self) -> None: |
763 | 774 | containers = task["containers"] |
764 | 775 | for container in containers: |
765 | 776 | if container.get("lastStatus") == "STOPPED" and container.get("exitCode", 1) != 0: |
| 777 | + exit_code = container.get("exitCode", 1) |
| 778 | + if exit_code in self.skip_on_exit_code: |
| 779 | + exception_cls: type[AirflowException] = AirflowSkipException |
| 780 | + else: |
| 781 | + exception_cls = AirflowException |
| 782 | + |
766 | 783 | if self.task_log_fetcher: |
767 | 784 | last_logs = "\n".join( |
768 | 785 | self.task_log_fetcher.get_last_log_messages(self.number_logs_exception) |
769 | 786 | ) |
770 | | - raise AirflowException( |
| 787 | + raise exception_cls( |
771 | 788 | f"This task is not in success state - last {self.number_logs_exception} " |
772 | 789 | f"logs from Cloudwatch:\n{last_logs}" |
773 | 790 | ) |
774 | | - raise AirflowException(f"This task is not in success state {task}") |
| 791 | + raise exception_cls(f"This task is not in success state {task}") |
775 | 792 | if container.get("lastStatus") == "PENDING": |
776 | 793 | raise AirflowException(f"This task is still pending {task}") |
777 | 794 | if "error" in container.get("reason", "").lower(): |
|
0 commit comments