Skip to content

Commit a998e28

Browse files
authored
Add skip_on_exit_code support to EcsRunTaskOperator (#63274)
Allow users to specify exit codes that should raise an AirflowSkipException (marking the task as skipped) via the new `skip_on_exit_code` parameter. This is consistent with the existing behavior in DockerOperator and KubernetesPodOperator.
1 parent 7606f82 commit a998e28

File tree

2 files changed

+57
-5
lines changed
  • providers/amazon
    • src/airflow/providers/amazon/aws/operators
    • tests/unit/amazon/aws/operators

2 files changed

+57
-5
lines changed

providers/amazon/src/airflow/providers/amazon/aws/operators/ecs.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from __future__ import annotations
1919

2020
import re
21-
from collections.abc import Sequence
21+
from collections.abc import Container, Sequence
2222
from datetime import timedelta
2323
from functools import cached_property
2424
from time import sleep
@@ -39,7 +39,7 @@
3939
from airflow.providers.amazon.aws.utils.identifiers import generate_uuid
4040
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
4141
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
42-
from airflow.providers.common.compat.sdk import AirflowException, conf
42+
from airflow.providers.common.compat.sdk import AirflowException, AirflowSkipException, conf
4343
from airflow.utils.helpers import prune_dict
4444

4545
if TYPE_CHECKING:
@@ -394,6 +394,9 @@ class EcsRunTaskOperator(EcsBaseOperator):
394394
:param deferrable: If True, the operator will wait asynchronously for the job to complete.
395395
This implies waiting for completion. This mode requires aiobotocore module to be installed.
396396
(default: False)
397+
:param skip_on_exit_code: If task exits with this exit code, leave the task
398+
in ``skipped`` state (default: None). If set to ``None``, any non-zero
399+
exit code will be treated as a failure. Can be an int or a container of ints.
397400
:param do_xcom_push: If True, the operator will push the ECS task ARN to XCom with key 'ecs_task_arn'.
398401
Additionally, if logs are fetched, the last log message will be pushed to XCom with the key 'return_value'. (default: False)
399402
:param stop_task_on_failure: If True, attempt to stop the ECS task if the Airflow task fails
@@ -461,6 +464,7 @@ def __init__(
461464
# Set the default waiter duration to 70 days (attempts*delay)
462465
# Airflow execution_timeout handles task timeout
463466
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
467+
skip_on_exit_code: int | Container[int] | None = None,
464468
stop_task_on_failure: bool = True,
465469
**kwargs,
466470
):
@@ -500,6 +504,13 @@ def __init__(
500504
self.waiter_delay = waiter_delay
501505
self.waiter_max_attempts = waiter_max_attempts
502506
self.deferrable = deferrable
507+
self.skip_on_exit_code = (
508+
skip_on_exit_code
509+
if isinstance(skip_on_exit_code, Container)
510+
else [skip_on_exit_code]
511+
if skip_on_exit_code is not None
512+
else []
513+
)
503514
self.stop_task_on_failure = stop_task_on_failure
504515

505516
if self._aws_logs_enabled() and not self.wait_for_completion:
@@ -763,15 +774,21 @@ def _check_success_task(self) -> None:
763774
containers = task["containers"]
764775
for container in containers:
765776
if container.get("lastStatus") == "STOPPED" and container.get("exitCode", 1) != 0:
777+
exit_code = container.get("exitCode", 1)
778+
if exit_code in self.skip_on_exit_code:
779+
exception_cls: type[AirflowException] = AirflowSkipException
780+
else:
781+
exception_cls = AirflowException
782+
766783
if self.task_log_fetcher:
767784
last_logs = "\n".join(
768785
self.task_log_fetcher.get_last_log_messages(self.number_logs_exception)
769786
)
770-
raise AirflowException(
787+
raise exception_cls(
771788
f"This task is not in success state - last {self.number_logs_exception} "
772789
f"logs from Cloudwatch:\n{last_logs}"
773790
)
774-
raise AirflowException(f"This task is not in success state {task}")
791+
raise exception_cls(f"This task is not in success state {task}")
775792
if container.get("lastStatus") == "PENDING":
776793
raise AirflowException(f"This task is still pending {task}")
777794
if "error" in container.get("reason", "").lower():

providers/amazon/tests/unit/amazon/aws/operators/test_ecs.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@
3838
from airflow.providers.amazon.aws.triggers.ecs import TaskDoneTrigger
3939
from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
4040
from airflow.providers.amazon.version_compat import NOTSET
41-
from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred
41+
from airflow.providers.common.compat.sdk import AirflowException, AirflowSkipException, TaskDeferred
4242

4343
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
4444

@@ -603,6 +603,41 @@ def test_check_success_task_not_raises(self, client_mock):
603603
self.ecs._check_success_task()
604604
client_mock.describe_tasks.assert_called_once_with(cluster="c", tasks=["arn"])
605605

606+
@mock.patch.object(EcsBaseOperator, "client")
607+
def test_check_success_task_raises_skip_exception(self, client_mock):
608+
self.ecs.arn = "arn"
609+
self.ecs.skip_on_exit_code = [2]
610+
client_mock.describe_tasks.return_value = {
611+
"tasks": [{"containers": [{"name": "container-name", "lastStatus": "STOPPED", "exitCode": 2}]}]
612+
}
613+
with pytest.raises(AirflowSkipException):
614+
self.ecs._check_success_task()
615+
616+
@mock.patch.object(EcsBaseOperator, "client")
617+
@mock.patch("airflow.providers.amazon.aws.utils.task_log_fetcher.AwsTaskLogFetcher")
618+
def test_check_success_task_skip_exception_with_logs(self, log_fetcher_mock, client_mock):
619+
self.ecs.arn = "arn"
620+
self.ecs.skip_on_exit_code = [2]
621+
self.ecs.task_log_fetcher = log_fetcher_mock
622+
log_fetcher_mock.get_last_log_messages.return_value = ["log1", "log2"]
623+
client_mock.describe_tasks.return_value = {
624+
"tasks": [{"containers": [{"name": "container-name", "lastStatus": "STOPPED", "exitCode": 2}]}]
625+
}
626+
with pytest.raises(AirflowSkipException, match="This task is not in success state"):
627+
self.ecs._check_success_task()
628+
629+
@mock.patch.object(EcsBaseOperator, "client")
630+
def test_check_success_task_unmatched_exit_code_raises_airflow_exception(self, client_mock):
631+
"""Exit codes not in skip_on_exit_code raise AirflowException."""
632+
self.ecs.arn = "arn"
633+
self.ecs.skip_on_exit_code = [2]
634+
client_mock.describe_tasks.return_value = {
635+
"tasks": [{"containers": [{"name": "container-name", "lastStatus": "STOPPED", "exitCode": 1}]}]
636+
}
637+
with pytest.raises(AirflowException) as ctx:
638+
self.ecs._check_success_task()
639+
assert type(ctx.value) is AirflowException
640+
606641
@pytest.mark.parametrize(
607642
("launch_type", "tags"),
608643
[

0 commit comments

Comments
 (0)