Skip to content

Commit 1370428

Browse files
authored
Add MwaaDagRunSensor to Amazon Provider Package (apache#46945)
Includes the doc page, unit tests and system test. Support for deferrable mode will be added soon.
1 parent af4cc3d commit 1370428

File tree

6 files changed

+231
-7
lines changed

6 files changed

+231
-7
lines changed

providers/amazon/docs/operators/mwaa.rst

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ is a managed service for Apache Airflow that lets you use your current, familiar
2424
your workflows. You gain improved scalability, availability, and security without the operational burden of managing
2525
underlying infrastructure.
2626

27+
Note: Unlike Airflow's built-in operators, these operators are meant for interaction with external Airflow environments
28+
hosted on AWS MWAA.
29+
2730
Prerequisite Tasks
2831
------------------
2932

@@ -45,19 +48,34 @@ Trigger a DAG run in an Amazon MWAA environment
4548
To trigger a DAG run in an Amazon MWAA environment you can use the
4649
:class:`~airflow.providers.amazon.aws.operators.mwaa.MwaaTriggerDagRunOperator`
4750

48-
Note: Unlike :class:`~airflow.providers.standard.operators.trigger_dagrun.TriggerDagRunOperator`, this operator is capable of
49-
triggering a DAG in a separate Airflow environment as long as the environment with the DAG being triggered is running on
50-
AWS MWAA.
51-
52-
In the following example, the task ``trigger_dag_run`` triggers a dag run for a DAG with with the ID ``hello_world`` in
53-
the environment ``MyAirflowEnvironment``.
51+
In the following example, the task ``trigger_dag_run`` triggers a DAG run for the DAG ``hello_world`` in the environment
52+
``MyAirflowEnvironment``.
5453

5554
.. exampleinclude:: /../../providers/amazon/tests/system/amazon/aws/example_mwaa.py
5655
:language: python
5756
:dedent: 4
5857
:start-after: [START howto_operator_mwaa_trigger_dag_run]
5958
:end-before: [END howto_operator_mwaa_trigger_dag_run]
6059

60+
Sensors
61+
-------
62+
63+
.. _howto/sensor:MwaaDagRunSensor:
64+
65+
Wait on the state of an AWS MWAA DAG Run
66+
========================================
67+
68+
To wait until a DAG run in an Amazon MWAA environment reaches one of the given states, you can use the
69+
:class:`~airflow.providers.amazon.aws.sensors.mwaa.MwaaDagRunSensor`
70+
71+
In the following example, the task ``wait_for_dag_run`` waits for the DAG run created in the above task to complete.
72+
73+
.. exampleinclude:: /../../providers/amazon/tests/system/amazon/aws/example_mwaa.py
74+
:language: python
75+
:dedent: 4
76+
:start-after: [START howto_sensor_mwaa_dag_run]
77+
:end-before: [END howto_sensor_mwaa_dag_run]
78+
6179
References
6280
----------
6381

providers/amazon/provider.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,9 @@ sensors:
485485
- integration-name: Amazon Managed Service for Apache Flink
486486
python-modules:
487487
- airflow.providers.amazon.aws.sensors.kinesis_analytics
488+
- integration-name: Amazon Managed Workflows for Apache Airflow (MWAA)
489+
python-modules:
490+
- airflow.providers.amazon.aws.sensors.mwaa
488491
- integration-name: Amazon OpenSearch Serverless
489492
python-modules:
490493
- airflow.providers.amazon.aws.sensors.opensearch_serverless
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
#
2+
# Licensed to the Apache Software Foundation (ASF) under one
3+
# or more contributor license agreements. See the NOTICE file
4+
# distributed with this work for additional information
5+
# regarding copyright ownership. The ASF licenses this file
6+
# to you under the Apache License, Version 2.0 (the
7+
# "License"); you may not use this file except in compliance
8+
# with the License. You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing,
13+
# software distributed under the License is distributed on an
14+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
# KIND, either express or implied. See the License for the
16+
# specific language governing permissions and limitations
17+
# under the License.
18+
from __future__ import annotations
19+
20+
from collections.abc import Collection, Sequence
21+
from typing import TYPE_CHECKING
22+
23+
from airflow.exceptions import AirflowException
24+
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
25+
from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor
26+
from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
27+
from airflow.utils.state import State
28+
29+
if TYPE_CHECKING:
30+
from airflow.utils.context import Context
31+
32+
33+
class MwaaDagRunSensor(AwsBaseSensor[MwaaHook]):
    """
    Waits for a DAG Run in an MWAA Environment to complete.

    If the DAG Run fails, an AirflowException is thrown.

    .. seealso::
        For more information on how to use this sensor, take a look at the guide:
        :ref:`howto/sensor:MwaaDagRunSensor`

    :param external_env_name: The external MWAA environment name that contains the DAG Run you want to wait for
        (templated)
    :param external_dag_id: The DAG ID in the external MWAA environment that contains the DAG Run you want to wait for
        (templated)
    :param external_dag_run_id: The DAG Run ID in the external MWAA environment that you want to wait for (templated)
    :param success_states: Collection of DAG Run states that would make this task marked as successful, default is
        ``airflow.utils.state.State.success_states`` (templated). ``None`` or an empty collection selects the
        default.
    :param failure_states: Collection of DAG Run states that would make this task marked as failed and raise an
        AirflowException, default is ``airflow.utils.state.State.failed_states`` (templated). ``None`` or an empty
        collection selects the default.
    """

    aws_hook_class = MwaaHook
    template_fields: Sequence[str] = aws_template_fields(
        "external_env_name",
        "external_dag_id",
        "external_dag_run_id",
        "success_states",
        "failure_states",
    )

    def __init__(
        self,
        *,
        external_env_name: str,
        external_dag_id: str,
        external_dag_run_id: str,
        success_states: Collection[str] | None = None,
        failure_states: Collection[str] | None = None,
        **kwargs,
    ):
        """
        Store the target DAG Run coordinates and the state sets used by ``poke``.

        :raises AirflowException: if ``success_states`` and ``failure_states`` share at least one value.
        """
        super().__init__(**kwargs)

        # Normalise to sets so that membership tests and the overlap check below are cheap.
        self.success_states = set(success_states if success_states else State.success_states)
        self.failure_states = set(failure_states if failure_states else State.failed_states)

        # A state that is both "success" and "failure" would make the outcome of poke() ambiguous.
        if self.success_states & self.failure_states:
            raise AirflowException("success_states and failure_states must not have any values in common")

        self.external_env_name = external_env_name
        self.external_dag_id = external_dag_id
        self.external_dag_run_id = external_dag_run_id

    def poke(self, context: Context) -> bool:
        """
        Query the external MWAA environment once for the DAG Run state.

        :param context: The task execution context (not read by this method).
        :return: ``True`` if the DAG Run state is in ``success_states``; ``False`` if it is in neither
            ``success_states`` nor ``failure_states``.
        :raises AirflowException: if the DAG Run state is in ``failure_states``.
        """
        self.log.info(
            "Poking for DAG run %s of DAG %s in MWAA environment %s",
            self.external_dag_run_id,
            self.external_dag_id,
            self.external_env_name,
        )
        response = self.hook.invoke_rest_api(
            env_name=self.external_env_name,
            path=f"/dags/{self.external_dag_id}/dagRuns/{self.external_dag_run_id}",
            method="GET",
        )

        # If RestApiStatusCode == 200, the RestApiResponse must have the "state" key, otherwise something terrible has
        # happened in the API and KeyError would be raised
        # If RestApiStatusCode >= 300, a botocore exception would've already been raised during the
        # self.hook.invoke_rest_api call
        # The scope of this sensor is going to only be raising AirflowException due to failure of the DAGRun

        state = response["RestApiResponse"]["state"]
        if state in self.success_states:
            return True

        if state in self.failure_states:
            raise AirflowException(
                f"The DAG run {self.external_dag_run_id} of DAG {self.external_dag_id} in MWAA environment {self.external_env_name} "
                f"failed with state {state}."
            )
        return False

providers/amazon/src/airflow/providers/amazon/get_provider_info.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,10 @@ def get_provider_info():
604604
"integration-name": "Amazon Managed Service for Apache Flink",
605605
"python-modules": ["airflow.providers.amazon.aws.sensors.kinesis_analytics"],
606606
},
607+
{
608+
"integration-name": "Amazon Managed Workflows for Apache Airflow (MWAA)",
609+
"python-modules": ["airflow.providers.amazon.aws.sensors.mwaa"],
610+
},
607611
{
608612
"integration-name": "Amazon OpenSearch Serverless",
609613
"python-modules": ["airflow.providers.amazon.aws.sensors.opensearch_serverless"],

providers/amazon/tests/system/amazon/aws/example_mwaa.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from airflow.models.baseoperator import chain
2222
from airflow.models.dag import DAG
2323
from airflow.providers.amazon.aws.operators.mwaa import MwaaTriggerDagRunOperator
24+
from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor
2425
from system.amazon.aws.utils import SystemTestContextBuilder
2526

2627
DAG_ID = "example_mwaa"
@@ -29,7 +30,6 @@
2930
EXISTING_ENVIRONMENT_NAME_KEY = "ENVIRONMENT_NAME"
3031
EXISTING_DAG_ID_KEY = "DAG_ID"
3132

32-
3333
sys_test_context_task = (
3434
SystemTestContextBuilder()
3535
# NOTE: Creating a functional MWAA environment is time-consuming and requires
@@ -67,11 +67,22 @@
6767
)
6868
# [END howto_operator_mwaa_trigger_dag_run]
6969

70+
# [START howto_sensor_mwaa_dag_run]
71+
wait_for_dag_run = MwaaDagRunSensor(
72+
task_id="wait_for_dag_run",
73+
external_env_name=env_name,
74+
external_dag_id=trigger_dag_id,
75+
external_dag_run_id="{{ task_instance.xcom_pull(task_ids='trigger_dag_run')['RestApiResponse']['dag_run_id'] }}",
76+
poke_interval=5,
77+
)
78+
# [END howto_sensor_mwaa_dag_run]
79+
7080
chain(
7181
# TEST SETUP
7282
test_context,
7383
# TEST BODY
7484
trigger_dag_run,
85+
wait_for_dag_run,
7586
)
7687

7788
from tests_common.test_utils.watcher import watcher
Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
from __future__ import annotations
18+
19+
from unittest import mock
20+
21+
import pytest
22+
23+
from airflow.exceptions import AirflowException
24+
from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
25+
from airflow.providers.amazon.aws.sensors.mwaa import MwaaDagRunSensor
26+
from airflow.utils.state import State
27+
28+
SENSOR_KWARGS = {
29+
"task_id": "test_mwaa_sensor",
30+
"external_env_name": "test_env",
31+
"external_dag_id": "test_dag",
32+
"external_dag_run_id": "test_run_id",
33+
}
34+
35+
36+
@pytest.fixture
def mock_invoke_rest_api():
    """Patch ``MwaaHook.invoke_rest_api`` for the duration of a test and yield the mock."""
    patcher = mock.patch.object(MwaaHook, "invoke_rest_api")
    with patcher as patched_invoke:
        yield patched_invoke
40+
41+
42+
class TestMwaaDagRunSuccessSensor:
    """Unit tests covering construction and ``poke`` outcomes of ``MwaaDagRunSensor``."""

    def test_init_success(self):
        """Disjoint custom state collections are accepted and stored alongside the identifiers."""
        expected_success = {"state1", "state2"}
        expected_failure = {"state3", "state4"}

        sensor = MwaaDagRunSensor(
            success_states=expected_success,
            failure_states=expected_failure,
            **SENSOR_KWARGS,
        )

        assert sensor.external_env_name == SENSOR_KWARGS["external_env_name"]
        assert sensor.external_dag_id == SENSOR_KWARGS["external_dag_id"]
        assert sensor.external_dag_run_id == SENSOR_KWARGS["external_dag_run_id"]
        assert set(sensor.success_states) == expected_success
        assert set(sensor.failure_states) == expected_failure

    def test_init_failure(self):
        """Overlapping success and failure states ("state2") are rejected at construction time."""
        overlapping_kwargs = {
            "success_states": {"state1", "state2"},
            "failure_states": {"state2", "state3"},
        }
        with pytest.raises(AirflowException):
            MwaaDagRunSensor(**SENSOR_KWARGS, **overlapping_kwargs)

    @pytest.mark.parametrize("status", sorted(State.success_states))
    def test_poke_completed(self, mock_invoke_rest_api, status):
        """A default success state reported by the API makes ``poke`` return a truthy value."""
        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": status}}
        sensor = MwaaDagRunSensor(**SENSOR_KWARGS)
        assert sensor.poke({})

    @pytest.mark.parametrize("status", ["running", "queued"])
    def test_poke_not_completed(self, mock_invoke_rest_api, status):
        """A non-terminal state makes ``poke`` return a falsy value."""
        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": status}}
        sensor = MwaaDagRunSensor(**SENSOR_KWARGS)
        assert not sensor.poke({})

    @pytest.mark.parametrize("status", sorted(State.failed_states))
    def test_poke_terminated(self, mock_invoke_rest_api, status):
        """A default failure state reported by the API makes ``poke`` raise ``AirflowException``."""
        mock_invoke_rest_api.return_value = {"RestApiResponse": {"state": status}}
        sensor = MwaaDagRunSensor(**SENSOR_KWARGS)
        with pytest.raises(AirflowException):
            sensor.poke({})

0 commit comments

Comments
 (0)