1919from __future__ import annotations
2020
2121from collections .abc import Sequence
22- from typing import TYPE_CHECKING
22+ from typing import TYPE_CHECKING , Any
2323
24+ from airflow .configuration import conf
25+ from airflow .exceptions import AirflowException
2426from airflow .providers .amazon .aws .hooks .mwaa import MwaaHook
2527from airflow .providers .amazon .aws .operators .base_aws import AwsBaseOperator
28+ from airflow .providers .amazon .aws .triggers .mwaa import MwaaDagRunCompletedTrigger
29+ from airflow .providers .amazon .aws .utils import validate_execute_complete_event
2630from airflow .providers .amazon .aws .utils .mixins import aws_template_fields
2731
2832if TYPE_CHECKING :
@@ -48,6 +52,23 @@ class MwaaTriggerDagRunOperator(AwsBaseOperator[MwaaHook]):
4852 :param conf: Additional configuration parameters. The value of this field can be set only when creating
4953 the object. (templated)
5054 :param note: Contains manually entered notes by the user about the DagRun. (templated)
55+
56+ :param wait_for_completion: Whether to wait for DAG run to stop. (default: False)
57+ :param waiter_delay: Time in seconds to wait between status checks. (default: 60)
58+ :param waiter_max_attempts: Maximum number of attempts to check for DAG run completion. (default: 720)
59+ :param deferrable: If True, the operator will wait asynchronously for the DAG run to stop.
60+ This implies waiting for completion. This mode requires aiobotocore module to be installed.
61+ (default: False)
62+ :param aws_conn_id: The Airflow connection used for AWS credentials.
63+ If this is ``None`` or empty then the default boto3 behaviour is used. If
64+ running Airflow in a distributed manner and aws_conn_id is None or
65+ empty, then default boto3 configuration would be used (and must be
66+ maintained on each worker node).
67+ :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
68+ :param verify: Whether or not to verify SSL certificates. See:
69+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
70+ :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
71+ https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
5172 """
5273
5374 aws_hook_class = MwaaHook
@@ -74,6 +95,10 @@ def __init__(
7495 data_interval_end : str | None = None ,
7596 conf : dict | None = None ,
7697 note : str | None = None ,
98+ wait_for_completion : bool = False ,
99+ waiter_delay : int = 60 ,
100+ waiter_max_attempts : int = 720 ,
101+ deferrable : bool = conf .getboolean ("operators" , "default_deferrable" , fallback = False ),
77102 ** kwargs ,
78103 ):
79104 super ().__init__ (** kwargs )
@@ -85,6 +110,21 @@ def __init__(
85110 self .data_interval_end = data_interval_end
86111 self .conf = conf if conf else {}
87112 self .note = note
113+ self .wait_for_completion = wait_for_completion
114+ self .waiter_delay = waiter_delay
115+ self .waiter_max_attempts = waiter_max_attempts
116+ self .deferrable = deferrable
117+
def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict:
    """
    Resume after the deferred trigger fires and return information about the DAG run.

    :param context: Airflow task context (not read by this method).
    :param event: Payload emitted by the trigger. It is passed through
        ``validate_execute_complete_event`` and then read for its ``status`` and ``dag_run_id`` keys.
    :return: The response of a ``GET`` request for the DAG run, issued through ``self.hook.invoke_rest_api``.
    :raises AirflowException: If the event's ``status`` is anything other than ``"success"``.
    """
    validated_event = validate_execute_complete_event(event)
    if validated_event["status"] != "success":
        raise AirflowException(f"DAG run failed: {validated_event}")

    dag_run_id = validated_event["dag_run_id"]
    self.log.info("DAG run %s of DAG %s completed", dag_run_id, self.trigger_dag_id)
    # Query the DAG run itself so the return value describes the run, not the trigger event.
    return self.hook.invoke_rest_api(
        env_name=self.env_name,
        path=f"/dags/{self.trigger_dag_id}/dagRuns/{dag_run_id}",
        method="GET",
    )
88128
89129 def execute (self , context : Context ) -> dict :
90130 """
@@ -94,7 +134,7 @@ def execute(self, context: Context) -> dict:
94134 :return: dict with information about the Dag run
95135 For details of the returned dict, see :py:meth:`botocore.client.MWAA.invoke_rest_api`
96136 """
97- return self .hook .invoke_rest_api (
137+ response = self .hook .invoke_rest_api (
98138 env_name = self .env_name ,
99139 path = f"/dags/{ self .trigger_dag_id } /dagRuns" ,
100140 method = "POST" ,
@@ -107,3 +147,34 @@ def execute(self, context: Context) -> dict:
107147 "note" : self .note ,
108148 },
109149 )
150+
151+ dag_run_id = response ["RestApiResponse" ]["dag_run_id" ]
152+ self .log .info ("DAG run %s of DAG %s created" , dag_run_id , self .trigger_dag_id )
153+
154+ task_description = f"DAG run { dag_run_id } of DAG { self .trigger_dag_id } to complete"
155+ if self .deferrable :
156+ self .log .info ("Deferring for %s" , task_description )
157+ self .defer (
158+ trigger = MwaaDagRunCompletedTrigger (
159+ external_env_name = self .env_name ,
160+ external_dag_id = self .trigger_dag_id ,
161+ external_dag_run_id = dag_run_id ,
162+ waiter_delay = self .waiter_delay ,
163+ waiter_max_attempts = self .waiter_max_attempts ,
164+ aws_conn_id = self .aws_conn_id ,
165+ ),
166+ method_name = "execute_complete" ,
167+ )
168+ elif self .wait_for_completion :
169+ self .log .info ("Waiting for %s" , task_description )
170+ api_kwargs = {
171+ "Name" : self .env_name ,
172+ "Path" : f"/dags/{ self .trigger_dag_id } /dagRuns/{ dag_run_id } " ,
173+ "Method" : "GET" ,
174+ }
175+ self .hook .get_waiter ("mwaa_dag_run_complete" ).wait (
176+ ** api_kwargs ,
177+ WaiterConfig = {"Delay" : self .waiter_delay , "MaxAttempts" : self .waiter_max_attempts },
178+ )
179+
180+ return response
0 commit comments