Commit b6f33d4

ramitkatarianailo2c authored and committed

Add deferrable support for MwaaDagRunSensor (apache#47527)
* Add MwaaDagRunCompletedTrigger and deferrable support for MwaaDagRunSensor

Also includes:
- Unit tests
- Support in `AwsGenericHook` and `AwsBaseWaiterTrigger` for overriding the boto waiter config, so the sensor can use custom success and failure states
- Changes to `aws.utils.waiter_with_logging.async_wait` to include info about the latest response in the raised exception
1 parent 0e13d00 commit b6f33d4

10 files changed: +394, -38 lines
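
For illustration, a minimal sketch of the new deferrable mode (the environment, DAG, and run IDs below are placeholders, not values from this commit):

from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor

# Wait for a DAG run in another MWAA environment without holding a worker slot:
# with deferrable=True the sensor defers to MwaaDagRunCompletedTrigger on the triggerer.
wait_for_external_run = MwaaDagRunSensor(
    task_id="wait_for_external_run",
    external_env_name="my-mwaa-env",            # placeholder environment name
    external_dag_id="target_dag",               # placeholder DAG id
    external_dag_run_id="manual__2024-01-01",   # placeholder run id
    deferrable=True,    # new in this commit
    poke_interval=60,   # maps to the trigger's waiter_delay
    max_retries=720,    # maps to the trigger's waiter_max_attempts
)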

providers/amazon/provider.yaml

Lines changed: 3 additions & 0 deletions

@@ -702,6 +702,9 @@ triggers:
   - integration-name: AWS Lambda
     python-modules:
       - airflow.providers.amazon.aws.triggers.lambda_function
+  - integration-name: Amazon Managed Workflows for Apache Airflow (MWAA)
+    python-modules:
+      - airflow.providers.amazon.aws.triggers.mwaa
   - integration-name: Amazon Managed Service for Apache Flink
     python-modules:
       - airflow.providers.amazon.aws.triggers.kinesis_analytics

providers/amazon/src/airflow/providers/amazon/aws/hooks/base_aws.py

Lines changed: 9 additions & 1 deletion

@@ -943,6 +943,7 @@ def get_waiter(
         self,
         waiter_name: str,
         parameters: dict[str, str] | None = None,
+        config_overrides: dict[str, Any] | None = None,
         deferrable: bool = False,
         client=None,
     ) -> Waiter:
@@ -962,6 +963,9 @@ def get_waiter(
         :param parameters: will scan the waiter config for the keys of that dict,
             and replace them with the corresponding value. If a custom waiter has
             such keys to be expanded, they need to be provided here.
+            Note: cannot be used if parameters are included in config_overrides
+        :param config_overrides: will update values of provided keys in the waiter's
+            config. Only specified keys will be updated.
         :param deferrable: If True, the waiter is going to be an async custom waiter.
             An async client must be provided in that case.
         :param client: The client to use for the waiter's operations
@@ -970,14 +974,18 @@ def get_waiter(
 
         if deferrable and not client:
             raise ValueError("client must be provided for a deferrable waiter.")
+        if parameters is not None and config_overrides is not None and "acceptors" in config_overrides:
+            raise ValueError('parameters must be None when "acceptors" is included in config_overrides')
         # Currently, the custom waiter doesn't work with resource_type, only client_type is supported.
         client = client or self._client
         if self.waiter_path and (waiter_name in self._list_custom_waiters()):
             # Technically if waiter_name is in custom_waiters then self.waiter_path must
             # exist but MyPy doesn't like the fact that self.waiter_path could be None.
             with open(self.waiter_path) as config_file:
-                config = json.loads(config_file.read())
+                config: dict = json.loads(config_file.read())
 
+            if config_overrides is not None:
+                config["waiters"][waiter_name].update(config_overrides)
             config = self._apply_parameters_value(config, waiter_name, parameters)
             return BaseBotoWaiter(client=client, model_config=config, deferrable=deferrable).waiter(
                 waiter_name
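
A rough sketch of the new config_overrides argument (the hook instance, waiter name, and override values are illustrative assumptions, not taken from this commit):

# Hypothetical: fetch a custom waiter from an AwsGenericHook subclass but shorten
# its polling schedule. Only the keys listed in config_overrides are replaced in the
# waiter's JSON definition; everything else is kept as-is.
waiter = hook.get_waiter(
    "my_custom_waiter",                                # assumed custom waiter name
    config_overrides={"delay": 15, "maxAttempts": 20},
)
waiter.wait(SomeRequestParameter="value")              # kwargs depend on the waiter's operation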

providers/amazon/src/airflow/providers/amazon/aws/sensors/mwaa.py

Lines changed: 58 additions & 11 deletions

@@ -18,13 +18,16 @@
 from __future__ import annotations
 
 from collections.abc import Collection, Sequence
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
 from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
+from airflow.providers.amazon.aws.triggers.mwaa import MwaaDagRunCompletedTrigger
+from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
-from airflow.utils.state import State
+from airflow.utils.state import DagRunState
 
 if TYPE_CHECKING:
     from airflow.utils.context import Context
@@ -46,9 +49,24 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         (templated)
     :param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for (templated)
     :param success_states: Collection of DAG Run states that would make this task marked as successful, default is
-        ``airflow.utils.state.State.success_states`` (templated)
+        ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
     :param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
-        AirflowException, default is ``airflow.utils.state.State.failed_states`` (templated)
+        AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}`` (templated)
+    :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore
+        module to be installed.
+        (default: False, but can be overridden in config file by setting default_deferrable to True)
+    :param poke_interval: Polling period in seconds to check for the status of the job. (default: 60)
+    :param max_retries: Number of times before returning the current state. (default: 720)
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
+        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
     """
 
     aws_hook_class = MwaaHook
@@ -58,6 +76,9 @@ class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
         "external_dag_run_id",
         "success_states",
         "failure_states",
+        "deferrable",
+        "max_retries",
+        "poke_interval",
     )
 
     def __init__(
@@ -68,19 +89,25 @@ def __init__(
         external_dag_run_id: str,
         success_states: Collection[str] | None = None,
         failure_states: Collection[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        poke_interval: int = 60,
+        max_retries: int = 720,
         **kwargs,
     ):
         super().__init__(**kwargs)
 
-        self.success_states = set(success_states if success_states else State.success_states)
-        self.failure_states = set(failure_states if failure_states else State.failed_states)
+        self.success_states = set(success_states) if success_states else {DagRunState.SUCCESS.value}
+        self.failure_states = set(failure_states) if failure_states else {DagRunState.FAILED.value}
 
         if len(self.success_states & self.failure_states):
-            raise AirflowException("allowed_states and failed_states must not have any values in common")
+            raise ValueError("success_states and failure_states must not have any values in common")
 
         self.external_env_name = external_env_name
         self.external_dag_id = external_dag_id
         self.external_dag_run_id = external_dag_run_id
+        self.deferrable = deferrable
+        self.poke_interval = poke_interval
+        self.max_retries = max_retries
 
     def poke(self, context: Context) -> bool:
         self.log.info(
@@ -102,12 +129,32 @@ def poke(self, context: Context) -> bool:
         # The scope of this sensor is going to only be raising AirflowException due to failure of the DAGRun
 
         state = response["RestApiResponse"]["state"]
-        if state in self.success_states:
-            return True
 
         if state in self.failure_states:
             raise AirflowException(
                 f"The DAG run {self.external_dag_run_id} of DAG {self.external_dag_id} in MWAA environment {self.external_env_name} "
-                f"failed with state {state}."
+                f"failed with state: {state}"
             )
-        return False
+
+        return state in self.success_states
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None:
+        validate_execute_complete_event(event)
+
+    def execute(self, context: Context):
+        if self.deferrable:
+            self.defer(
+                trigger=MwaaDagRunCompletedTrigger(
+                    external_env_name=self.external_env_name,
+                    external_dag_id=self.external_dag_id,
+                    external_dag_run_id=self.external_dag_run_id,
+                    success_states=self.success_states,
+                    failure_states=self.failure_states,
+                    waiter_delay=self.poke_interval,
+                    waiter_max_attempts=self.max_retries,
+                    aws_conn_id=self.aws_conn_id,
+                ),
+                method_name="execute_complete",
+            )
+        else:
+            super().execute(context=context)

providers/amazon/src/airflow/providers/amazon/aws/triggers/base.py

Lines changed: 10 additions & 1 deletion

@@ -55,6 +55,8 @@ class AwsBaseWaiterTrigger(BaseTrigger):
 
     :param waiter_delay: The amount of time in seconds to wait between attempts.
     :param waiter_max_attempts: The maximum number of attempts to be made.
+    :param waiter_config_overrides: A dict to update waiter's default configuration. Only specified keys will
+        be updated.
     :param aws_conn_id: The Airflow connection used for AWS credentials. To be used to build the hook.
     :param region_name: The AWS region where the resources to watch are. To be used to build the hook.
     :param verify: Whether or not to verify SSL certificates. To be used to build the hook.
@@ -77,6 +79,7 @@ def __init__(
         return_value: Any,
         waiter_delay: int,
         waiter_max_attempts: int,
+        waiter_config_overrides: dict[str, Any] | None = None,
         aws_conn_id: str | None,
         region_name: str | None = None,
         verify: bool | str | None = None,
@@ -91,6 +94,7 @@ def __init__(
         self.failure_message = failure_message
         self.status_message = status_message
         self.status_queries = status_queries
+        self.waiter_config_overrides = waiter_config_overrides
 
         self.return_key = return_key
         self.return_value = return_value
@@ -140,7 +144,12 @@ def hook(self) -> AwsGenericHook:
     async def run(self) -> AsyncIterator[TriggerEvent]:
         hook = self.hook()
         async with await hook.get_async_conn() as client:
-            waiter = hook.get_waiter(self.waiter_name, deferrable=True, client=client)
+            waiter = hook.get_waiter(
+                self.waiter_name,
+                deferrable=True,
+                client=client,
+                config_overrides=self.waiter_config_overrides,
+            )
             await async_wait(
                 waiter,
                 self.waiter_delay,
providers/amazon/src/airflow/providers/amazon/aws/triggers/mwaa.py

Lines changed: 129 additions & 0 deletions

@@ -0,0 +1,129 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from collections.abc import Collection
from typing import TYPE_CHECKING

from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
from airflow.utils.state import DagRunState

if TYPE_CHECKING:
    from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook


class MwaaDagRunCompletedTrigger(AwsBaseWaiterTrigger):
    """
    Trigger when an MWAA Dag Run is complete.

    :param external_env_name: The external MWAA environment name that contains the DAG Run you want to wait for
        (templated)
    :param external_dag_id: The DAG ID in the external MWAA environment that contains the DAG Run you want to wait for
        (templated)
    :param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for (templated)
    :param success_states: Collection of DAG Run states that would make this task marked as successful, default is
        ``{airflow.utils.state.DagRunState.SUCCESS}`` (templated)
    :param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
        AirflowException, default is ``{airflow.utils.state.DagRunState.FAILED}`` (templated)
    :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 60)
    :param waiter_max_attempts: The maximum number of attempts to be made. (default: 720)
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    """

    def __init__(
        self,
        *,
        external_env_name: str,
        external_dag_id: str,
        external_dag_run_id: str,
        success_states: Collection[str] | None = None,
        failure_states: Collection[str] | None = None,
        waiter_delay: int = 60,
        waiter_max_attempts: int = 720,
        aws_conn_id: str | None = None,
    ) -> None:
        self.success_states = set(success_states) if success_states else {DagRunState.SUCCESS.value}
        self.failure_states = set(failure_states) if failure_states else {DagRunState.FAILED.value}

        if len(self.success_states & self.failure_states):
            raise ValueError("success_states and failure_states must not have any values in common")

        in_progress_states = {s.value for s in DagRunState} - self.success_states - self.failure_states

        super().__init__(
            serialized_fields={
                "external_env_name": external_env_name,
                "external_dag_id": external_dag_id,
                "external_dag_run_id": external_dag_run_id,
                "success_states": success_states,
                "failure_states": failure_states,
            },
            waiter_name="mwaa_dag_run_complete",
            waiter_args={
                "Name": external_env_name,
                "Path": f"/dags/{external_dag_id}/dagRuns/{external_dag_run_id}",
                "Method": "GET",
            },
            failure_message=f"The DAG run {external_dag_run_id} of DAG {external_dag_id} in MWAA environment {external_env_name} failed with state",
            status_message="State of DAG run",
            status_queries=["RestApiResponse.state"],
            return_key="dag_run_id",
            return_value=external_dag_run_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
            waiter_config_overrides={
                "acceptors": _build_waiter_acceptors(
                    success_states=self.success_states,
                    failure_states=self.failure_states,
                    in_progress_states=in_progress_states,
                )
            },
        )

    def hook(self) -> AwsGenericHook:
        return MwaaHook(
            aws_conn_id=self.aws_conn_id,
            region_name=self.region_name,
            verify=self.verify,
            config=self.botocore_config,
        )


def _build_waiter_acceptors(
    success_states: set[str], failure_states: set[str], in_progress_states: set[str]
) -> list:
    def build_acceptor(dag_run_state: str, state_waiter_category: str):
        return {
            "matcher": "path",
            "argument": "RestApiResponse.state",
            "expected": dag_run_state,
            "state": state_waiter_category,
        }

    acceptors = []
    for state_set, state_waiter_category in (
        (success_states, "success"),
        (failure_states, "failure"),
        (in_progress_states, "retry"),
    ):
        for dag_run_state in state_set:
            acceptors.append(build_acceptor(dag_run_state, state_waiter_category))

    return acceptors
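
To make the override concrete, a sketch of what _build_waiter_acceptors yields for the default state sets (this follows from the function above; ordering within each group depends on set iteration and is not guaranteed):

acceptors = _build_waiter_acceptors(
    success_states={"success"},
    failure_states={"failed"},
    in_progress_states={"queued", "running"},
)
# Roughly, one acceptor per state:
# [
#     {"matcher": "path", "argument": "RestApiResponse.state", "expected": "success", "state": "success"},
#     {"matcher": "path", "argument": "RestApiResponse.state", "expected": "failed", "state": "failure"},
#     {"matcher": "path", "argument": "RestApiResponse.state", "expected": "queued", "state": "retry"},
#     {"matcher": "path", "argument": "RestApiResponse.state", "expected": "running", "state": "retry"},
# ]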

providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py

Lines changed: 4 additions & 3 deletions

@@ -136,15 +136,16 @@ async def async_wait(
             last_response = error.last_response
 
             if "terminal failure" in error_reason:
-                log.error("%s: %s", failure_message, _LazyStatusFormatter(status_args, last_response))
-                raise AirflowException(f"{failure_message}: {error}")
+                raise AirflowException(
+                    f"{failure_message}: {_LazyStatusFormatter(status_args, last_response)}\n{error}"
+                )
 
             if (
                 "An error occurred" in error_reason
                 and isinstance(last_response.get("Error"), dict)
                 and "Code" in last_response.get("Error")
             ):
-                raise AirflowException(f"{failure_message}: {error}")
+                raise AirflowException(f"{failure_message}\n{last_response}\n{error}")
 
             log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response))
         else:
providers/amazon/src/airflow/providers/amazon/aws/waiters/mwaa.json

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
{
    "version": 2,
    "waiters": {
        "mwaa_dag_run_complete": {
            "delay": 60,
            "maxAttempts": 720,
            "operation": "InvokeRestApi",
            "acceptors": [
                {
                    "matcher": "path",
                    "argument": "RestApiResponse.state",
                    "expected": "queued",
                    "state": "retry"
                },
                {
                    "matcher": "path",
                    "argument": "RestApiResponse.state",
                    "expected": "running",
                    "state": "retry"
                },
                {
                    "matcher": "path",
                    "argument": "RestApiResponse.state",
                    "expected": "success",
                    "state": "success"
                },
                {
                    "matcher": "path",
                    "argument": "RestApiResponse.state",
                    "expected": "failed",
                    "state": "failure"
                }
            ]
        }
    }
}
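
A rough usage sketch of this default waiter definition (a synchronous call shown only for illustration; the deferrable sensor path instead goes through MwaaDagRunCompletedTrigger, which swaps in acceptors built from the configured states; the environment, DAG, and run IDs are placeholders):

from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook

hook = MwaaHook()  # default AWS connection
waiter = hook.get_waiter("mwaa_dag_run_complete")
# Polls InvokeRestApi until RestApiResponse.state leaves the retry states.
waiter.wait(
    Name="my-mwaa-env",
    Path="/dags/target_dag/dagRuns/manual__2024-01-01",
    Method="GET",
)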
