Skip to content

Commit 8686c89

Browse files
authored
Add wait/defer support - MwaaTriggerDagRunOperator (#47528)
* Add wait/defer support - MwaaTriggerDagRunOperator Depends on #47527 because of the MwaaDagRunCompletedTrigger
1 parent d1a44f0 commit 8686c89

File tree

3 files changed

+116
-11
lines changed

3 files changed

+116
-11
lines changed

providers/amazon/src/airflow/providers/amazon/aws/operators/mwaa.py

Lines changed: 73 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,14 @@
1919
from __future__ import annotations
2020

2121
from collections.abc import Sequence
22-
from typing import TYPE_CHECKING
22+
from typing import TYPE_CHECKING, Any
2323

24+
from airflow.configuration import conf
25+
from airflow.exceptions import AirflowException
2426
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
2527
from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
28+
from airflow.providers.amazon.aws.triggers.mwaa import MwaaDagRunCompletedTrigger
29+
from airflow.providers.amazon.aws.utils import validate_execute_complete_event
2630
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
2731

2832
if TYPE_CHECKING:
@@ -48,6 +52,23 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
4852
:param conf: Additional configuration parameters. The value of this field can be set only when creating
4953
the object. (templated)
5054
:param note: Contains manually entered notes by the user about the DagRun. (templated)
55+
56+
:param wait_for_completion: Whether to wait for DAG run to stop. (default: False)
57+
:param waiter_delay: Time in seconds to wait between status checks. (default: 120)
58+
:param waiter_max_attempts: Maximum number of attempts to check for DAG run completion. (default: 720)
59+
:param deferrable: If True, the operator will wait asynchronously for the DAG run to stop.
60+
This implies waiting for completion. This mode requires aiobotocore module to be installed.
61+
(default: False)
62+
:param aws_conn_id: The Airflow connection used for AWS credentials.
63+
If this is ``None`` or empty then the default boto3 behaviour is used. If
64+
running Airflow in a distributed manner and aws_conn_id is None or
65+
empty, then default boto3 configuration would be used (and must be
66+
maintained on each worker node).
67+
:param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
68+
:param verify: Whether or not to verify SSL certificates. See:
69+
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
70+
:param botocore_config: Configuration dictionary (key-values) for botocore client. See:
71+
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
5172
"""
5273

5374
aws_hook_class = MwaaHook
@@ -74,6 +95,10 @@ def __init__(
7495
data_interval_end: str | None = None,
7596
conf: dict | None = None,
7697
note: str | None = None,
98+
wait_for_completion: bool = False,
99+
waiter_delay: int = 60,
100+
waiter_max_attempts: int = 720,
101+
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
77102
**kwargs,
78103
):
79104
super().__init__(**kwargs)
@@ -85,6 +110,21 @@ def __init__(
85110
self.data_interval_end = data_interval_end
86111
self.conf = conf if conf else {}
87112
self.note = note
113+
self.wait_for_completion = wait_for_completion
114+
self.waiter_delay = waiter_delay
115+
self.waiter_max_attempts = waiter_max_attempts
116+
self.deferrable = deferrable
117+
118+
def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict:
119+
validated_event = validate_execute_complete_event(event)
120+
if validated_event["status"] != "success":
121+
raise AirflowException(f"DAG run failed: {validated_event}")
122+
123+
dag_run_id = validated_event["dag_run_id"]
124+
self.log.info("DAG run %s of DAG %s completed", dag_run_id, self.trigger_dag_id)
125+
return self.hook.invoke_rest_api(
126+
env_name=self.env_name, path=f"/dags/{self.trigger_dag_id}/dagRuns/{dag_run_id}", method="GET"
127+
)
88128

89129
def execute(self, context: Context) -> dict:
90130
"""
@@ -94,7 +134,7 @@ def execute(self, context: Context) -> dict:
94134
:return: dict with information about the Dag run
95135
For details of the returned dict, see :py:meth:`botocore.client.MWAA.invoke_rest_api`
96136
"""
97-
return self.hook.invoke_rest_api(
137+
response = self.hook.invoke_rest_api(
98138
env_name=self.env_name,
99139
path=f"/dags/{self.trigger_dag_id}/dagRuns",
100140
method="POST",
@@ -107,3 +147,34 @@ def execute(self, context: Context) -> dict:
107147
"note": self.note,
108148
},
109149
)
150+
151+
dag_run_id = response["RestApiResponse"]["dag_run_id"]
152+
self.log.info("DAG run %s of DAG %s created", dag_run_id, self.trigger_dag_id)
153+
154+
task_description = f"DAG run {dag_run_id} of DAG {self.trigger_dag_id} to complete"
155+
if self.deferrable:
156+
self.log.info("Deferring for %s", task_description)
157+
self.defer(
158+
trigger=MwaaDagRunCompletedTrigger(
159+
external_env_name=self.env_name,
160+
external_dag_id=self.trigger_dag_id,
161+
external_dag_run_id=dag_run_id,
162+
waiter_delay=self.waiter_delay,
163+
waiter_max_attempts=self.waiter_max_attempts,
164+
aws_conn_id=self.aws_conn_id,
165+
),
166+
method_name="execute_complete",
167+
)
168+
elif self.wait_for_completion:
169+
self.log.info("Waiting for %s", task_description)
170+
api_kwargs = {
171+
"Name": self.env_name,
172+
"Path": f"/dags/{self.trigger_dag_id}/dagRuns/{dag_run_id}",
173+
"Method": "GET",
174+
}
175+
self.hook.get_waiter("mwaa_dag_run_complete").wait(
176+
**api_kwargs,
177+
WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts},
178+
)
179+
180+
return response

providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -109,21 +109,20 @@ def hook(self) -> AwsGenericHook:
109109
def _build_waiter_acceptors(
110110
success_states: set[str], failure_states: set[str], in_progress_states: set[str]
111111
) -> list:
112-
def build_acceptor(dag_run_state: str, state_waiter_category: str):
113-
return {
114-
"matcher": "path",
115-
"argument": "RestApiResponse.state",
116-
"expected": dag_run_state,
117-
"state": state_waiter_category,
118-
}
119-
120112
acceptors = []
121113
for state_set, state_waiter_category in (
122114
(success_states, "success"),
123115
(failure_states, "failure"),
124116
(in_progress_states, "retry"),
125117
):
126118
for dag_run_state in state_set:
127-
acceptors.append(build_acceptor(dag_run_state, state_waiter_category))
119+
acceptors.append(
120+
{
121+
"matcher": "path",
122+
"argument": "RestApiResponse.state",
123+
"expected": dag_run_state,
124+
"state": state_waiter_category,
125+
}
126+
)
128127

129128
return acceptors

providers/amazon/tests/unit/amazon/aws/operators/test_mwaa.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
from __future__ import annotations
1818

1919
from unittest import mock
20+
from unittest.mock import MagicMock
2021

22+
import pytest
23+
24+
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
2125
from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator
2226
from unit.amazon.aws.utils.test_template_fields import validate_template_fields
2327

@@ -31,6 +35,10 @@
3135
"data_interval_end": "2025-01-03T00:00:01Z",
3236
"conf": {"key": "value"},
3337
"note": "test note",
38+
"wait_for_completion": False,
39+
"waiter_delay": 5,
40+
"waiter_max_attempts": 20,
41+
"deferrable": False,
3442
}
3543
HOOK_RETURN_VALUE = {
3644
"ResponseMetadata": {},
@@ -53,6 +61,10 @@ def test_init(self):
5361
assert op.data_interval_end == OP_KWARGS["data_interval_end"]
5462
assert op.conf == OP_KWARGS["conf"]
5563
assert op.note == OP_KWARGS["note"]
64+
assert op.wait_for_completion == OP_KWARGS["wait_for_completion"]
65+
assert op.waiter_delay == OP_KWARGS["waiter_delay"]
66+
assert op.waiter_max_attempts == OP_KWARGS["waiter_max_attempts"]
67+
assert op.deferrable == OP_KWARGS["deferrable"]
5668

5769
@mock.patch.object(MwaaTriggerDagRunOperator, "hook")
5870
def test_execute(self, mock_hook):
@@ -78,3 +90,26 @@ def test_execute(self, mock_hook):
7890
def test_template_fields(self):
7991
operator = MwaaTriggerDagRunOperator(**OP_KWARGS)
8092
validate_template_fields(operator)
93+
94+
@pytest.mark.parametrize(
95+
"wait_for_completion, deferrable",
96+
[
97+
pytest.param(False, False, id="no_wait"),
98+
pytest.param(True, False, id="wait"),
99+
pytest.param(False, True, id="defer"),
100+
],
101+
)
102+
@mock.patch.object(MwaaHook, "get_waiter")
103+
@mock.patch.object(MwaaTriggerDagRunOperator, "hook")
104+
def test_execute_wait_combinations(self, mock_hook, _, wait_for_completion, deferrable):
105+
kwargs = OP_KWARGS
106+
kwargs["wait_for_completion"] = wait_for_completion
107+
kwargs["deferrable"] = deferrable
108+
op = MwaaTriggerDagRunOperator(**OP_KWARGS)
109+
mock_hook.invoke_rest_api.return_value = HOOK_RETURN_VALUE
110+
op.defer = MagicMock()
111+
response = op.execute({})
112+
113+
assert response == HOOK_RETURN_VALUE
114+
assert mock_hook.get_waiter.call_count == wait_for_completion
115+
assert op.defer.call_count == deferrable

0 commit comments

Comments
 (0)