Skip to content

Commit fcc28a5

Browse files
authored
Add use_rest_api parameter for CloudComposerDAGRunSensor for pulling dag_runs using the Airflow REST API (apache#56138)
1 parent 5cb96f6 commit fcc28a5

File tree

6 files changed

+344
-61
lines changed

6 files changed

+344
-61
lines changed

providers/google/src/airflow/providers/google/cloud/hooks/cloud_composer.py

Lines changed: 131 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
from typing import TYPE_CHECKING, Any
2525
from urllib.parse import urljoin
2626

27+
from aiohttp import ClientSession
2728
from google.api_core.client_options import ClientOptions
2829
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
29-
from google.auth.transport.requests import AuthorizedSession
30+
from google.auth.transport.requests import AuthorizedSession, Request
3031
from google.cloud.orchestration.airflow.service_v1 import (
3132
EnvironmentsAsyncClient,
3233
EnvironmentsClient,
@@ -472,6 +473,38 @@ def trigger_dag_run(
472473

473474
return response.json()
474475

476+
def get_dag_runs(
    self,
    composer_airflow_uri: str,
    composer_dag_id: str,
    timeout: float | None = None,
) -> dict:
    """
    Fetch the DAG runs of a DAG from the Composer environment's Airflow REST API.

    Issues ``GET /api/v1/dags/{composer_dag_id}/dagRuns`` against the environment's web server
    and returns the JSON-decoded response body.

    :param composer_airflow_uri: The URI of the Apache Airflow Web UI hosted within Composer environment.
    :param composer_dag_id: The ID of DAG.
    :param timeout: The timeout for this request.
    :return: The JSON-decoded body of the response.
    """
    # NOTE(review): no limit/offset query parameters are sent, so presumably only the
    # server's default page of DAG runs is returned -- confirm against the REST API docs.
    dag_runs_path = f"/api/v1/dags/{composer_dag_id}/dagRuns"
    response = self.make_composer_airflow_api_request(
        method="GET",
        airflow_uri=composer_airflow_uri,
        path=dag_runs_path,
        timeout=timeout,
    )

    status_code = response.status_code
    if status_code != 200:
        # Record the failure details first, then let the response object raise for error statuses.
        self.log.error(
            "Failed to get DAG runs for dag_id=%s from %s (status=%s): %s",
            composer_dag_id,
            composer_airflow_uri,
            status_code,
            response.text,
        )
        response.raise_for_status()

    return response.json()
507+
475508

476509
class CloudComposerAsyncHook(GoogleBaseAsyncHook):
477510
"""Hook for Google Cloud Composer async APIs."""
@@ -489,6 +522,42 @@ async def get_environment_client(self) -> EnvironmentsAsyncClient:
489522
client_options=self.client_options,
490523
)
491524

525+
async def make_composer_airflow_api_request(
    self,
    method: str,
    airflow_uri: str,
    path: str,
    data: Any | None = None,
    timeout: float | None = None,
):
    """
    Make a request to Cloud Composer environment's web server.

    The request is authenticated with a bearer token taken from the credentials of the
    synchronous hook; the credentials are refreshed first when they are not valid.

    :param method: The request method to use ('GET', 'OPTIONS', 'HEAD', 'POST', 'PUT', 'PATCH', 'DELETE').
    :param airflow_uri: The URI of the Apache Airflow Web UI hosted within this environment.
    :param path: The path to send the request.
    :param data: Dictionary, list of tuples, bytes, or file-like object to send in the body of the request.
    :param timeout: The timeout for this request.
    :return: A ``(body, status)`` tuple holding the JSON-decoded response body and the HTTP status code.
    """
    sync_hook = await self.get_sync_hook()
    credentials = sync_hook.get_credentials()

    if not credentials.valid:
        # ``Request`` is the synchronous (requests-based) google-auth transport, so the
        # refresh is a blocking network call. Run it in a worker thread so the event loop
        # (shared with every other coroutine) is not stalled while the token is fetched.
        await asyncio.to_thread(credentials.refresh, Request())

    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {credentials.token}",
    }

    async with ClientSession() as session:
        async with session.request(
            method=method,
            url=urljoin(airflow_uri, path),
            data=data,
            headers=request_headers,
            timeout=timeout,
        ) as response:
            # NOTE(review): the body is JSON-decoded regardless of the status code, so a
            # non-JSON error page would presumably raise here before the status is
            # returned to the caller -- confirm against aiohttp's ``ClientResponse.json``.
            return await response.json(), response.status
560+
492561
def get_environment_name(self, project_id, region, environment_id):
493562
return f"projects/{project_id}/locations/{region}/environments/{environment_id}"
494563

@@ -594,6 +663,35 @@ async def update_environment(
594663
metadata=metadata,
595664
)
596665

666+
@GoogleBaseHook.fallback_to_default_project_id
async def get_environment(
    self,
    project_id: str,
    region: str,
    environment_id: str,
    retry: AsyncRetry | _MethodDefault = DEFAULT,
    timeout: float | None = None,
    metadata: Sequence[tuple[str, str]] = (),
) -> Environment:
    """
    Retrieve an existing environment through the async environments client.

    :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
    :param region: Required. The ID of the Google Cloud region that the service belongs to.
    :param environment_id: Required. The ID of the Google Cloud environment that the service belongs to.
    :param retry: Designation of what errors, if any, should be retried.
    :param timeout: The timeout for this request.
    :param metadata: Strings which should be sent along with the request as metadata.
    :return: The environment returned by the client's ``get_environment`` call.
    """
    environments_client = await self.get_environment_client()

    # The API addresses an environment by its full resource name.
    get_request = {"name": self.get_environment_name(project_id, region, environment_id)}
    environment = await environments_client.get_environment(
        request=get_request,
        retry=retry,
        timeout=timeout,
        metadata=metadata,
    )
    return environment
)
694+
597695
@GoogleBaseHook.fallback_to_default_project_id
598696
async def execute_airflow_command(
599697
self,
@@ -719,3 +817,35 @@ async def wait_command_execution_result(
719817

720818
self.log.info("Sleeping for %s seconds.", poll_interval)
721819
await asyncio.sleep(poll_interval)
820+
821+
async def get_dag_runs(
    self,
    composer_airflow_uri: str,
    composer_dag_id: str,
    timeout: float | None = None,
) -> dict:
    """
    Get the list of dag runs for provided DAG.

    Issues ``GET /api/v1/dags/{composer_dag_id}/dagRuns`` against the environment's web server.

    :param composer_airflow_uri: The URI of the Apache Airflow Web UI hosted within Composer environment.
    :param composer_dag_id: The ID of DAG.
    :param timeout: The timeout for this request.
    :return: The JSON-decoded response body.
    :raises AirflowException: If the web server answers with a status code other than 200.
    """
    # NOTE(review): no limit/offset query parameters are sent, so presumably only the
    # server's default page of DAG runs is returned -- confirm against the REST API docs.
    response_body, response_status_code = await self.make_composer_airflow_api_request(
        method="GET",
        airflow_uri=composer_airflow_uri,
        path=f"/api/v1/dags/{composer_dag_id}/dagRuns",
        timeout=timeout,
    )

    if response_status_code != 200:
        # The error payload is not guaranteed to be a dict carrying "title"; indexing it
        # blindly would replace the real HTTP failure with a KeyError/TypeError.
        if isinstance(response_body, dict) and "title" in response_body:
            error_message = response_body["title"]
        else:
            error_message = (
                f"Unexpected response from the Airflow REST API "
                f"(status={response_status_code}): {response_body!r}"
            )
        self.log.error(
            "Failed to get DAG runs for dag_id=%s from %s (status=%s): %s",
            composer_dag_id,
            composer_airflow_uri,
            response_status_code,
            error_message,
        )
        raise AirflowException(error_message)

    return response_body

providers/google/src/airflow/providers/google/cloud/sensors/cloud_composer.py

Lines changed: 53 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from typing import TYPE_CHECKING
2727

2828
from dateutil import parser
29+
from google.api_core.exceptions import NotFound
2930
from google.cloud.orchestration.airflow.service_v1.types import Environment, ExecuteAirflowCommandResponse
3031

3132
from airflow.configuration import conf
@@ -97,6 +98,7 @@ def __init__(
9798
impersonation_chain: str | Sequence[str] | None = None,
9899
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
99100
poll_interval: int = 10,
101+
use_rest_api: bool = False,
100102
**kwargs,
101103
) -> None:
102104
super().__init__(**kwargs)
@@ -111,6 +113,7 @@ def __init__(
111113
self.impersonation_chain = impersonation_chain
112114
self.deferrable = deferrable
113115
self.poll_interval = poll_interval
116+
self.use_rest_api = use_rest_api
114117

115118
if self.composer_dag_run_id and self.execution_range:
116119
self.log.warning(
@@ -161,26 +164,51 @@ def poke(self, context: Context) -> bool:
161164

162165
def _pull_dag_runs(self) -> list[dict]:
    """
    Pull the list of dag runs.

    With ``use_rest_api`` set, the runs are read from the ``dag_runs`` key of the hook's
    ``get_dag_runs`` response; otherwise they are parsed from the JSON output of the
    ``dags list-runs`` Airflow command executed in the environment.
    """
    if not self.use_rest_api:
        # The DAG id is passed with "-d" for Airflow versions below 3 and positionally otherwise.
        if self._composer_airflow_version < 3:
            cmd_parameters = ["-d", self.composer_dag_id, "-o", "json"]
        else:
            cmd_parameters = [self.composer_dag_id, "-o", "json"]

        dag_runs_cmd = self.hook.execute_airflow_command(
            project_id=self.project_id,
            region=self.region,
            environment_id=self.environment_id,
            command="dags",
            subcommand="list-runs",
            parameters=cmd_parameters,
        )
        cmd_result = self.hook.wait_command_execution_result(
            project_id=self.project_id,
            region=self.region,
            environment_id=self.environment_id,
            execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
        )
        return json.loads(cmd_result["output"][0]["content"])

    # REST API path: the environment is looked up first to learn its web server URI.
    # NOTE(review): self.timeout is forwarded as the per-request timeout; presumably it is
    # the sensor-level timeout attribute -- confirm that this is the intended value.
    try:
        environment = self.hook.get_environment(
            project_id=self.project_id,
            region=self.region,
            environment_id=self.environment_id,
            timeout=self.timeout,
        )
    except NotFound as not_found_err:
        self.log.info("The Composer environment %s does not exist.", self.environment_id)
        raise AirflowException(not_found_err)

    airflow_uri = environment.config.airflow_uri
    self.log.info(
        "Pulling the DAG %s runs from the %s environment...",
        self.composer_dag_id,
        self.environment_id,
    )
    rest_response = self.hook.get_dag_runs(
        composer_airflow_uri=airflow_uri,
        composer_dag_id=self.composer_dag_id,
        timeout=self.timeout,
    )
    return rest_response["dag_runs"]
185213

186214
def _check_dag_runs_states(
@@ -213,7 +241,10 @@ def _get_composer_airflow_version(self) -> int:
213241

214242
def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
    """
    Tell whether the configured DAG run is present in ``dag_runs`` in an allowed state.

    The run identifier is read from the ``dag_run_id`` key when ``use_rest_api`` is set and
    from the ``run_id`` key otherwise.
    """
    return any(
        dag_run["dag_run_id" if self.use_rest_api else "run_id"] == self.composer_dag_run_id
        and dag_run["state"] in self.allowed_states
        for dag_run in dag_runs
    )
219250

@@ -236,6 +267,7 @@ def execute(self, context: Context) -> None:
236267
impersonation_chain=self.impersonation_chain,
237268
poll_interval=self.poll_interval,
238269
composer_airflow_version=self._composer_airflow_version,
270+
use_rest_api=self.use_rest_api,
239271
),
240272
method_name=GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME,
241273
)

providers/google/src/airflow/providers/google/cloud/triggers/cloud_composer.py

Lines changed: 51 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from typing import Any
2626

2727
from dateutil import parser
28+
from google.api_core.exceptions import NotFound
2829
from google.cloud.orchestration.airflow.service_v1.types import ExecuteAirflowCommandResponse
2930

3031
from airflow.exceptions import AirflowException
@@ -188,6 +189,7 @@ def __init__(
188189
impersonation_chain: str | Sequence[str] | None = None,
189190
poll_interval: int = 10,
190191
composer_airflow_version: int = 2,
192+
use_rest_api: bool = False,
191193
):
192194
super().__init__()
193195
self.project_id = project_id
@@ -202,6 +204,7 @@ def __init__(
202204
self.impersonation_chain = impersonation_chain
203205
self.poll_interval = poll_interval
204206
self.composer_airflow_version = composer_airflow_version
207+
self.use_rest_api = use_rest_api
205208

206209
def serialize(self) -> tuple[str, dict[str, Any]]:
207210
return (
@@ -219,31 +222,55 @@ def serialize(self) -> tuple[str, dict[str, Any]]:
219222
"impersonation_chain": self.impersonation_chain,
220223
"poll_interval": self.poll_interval,
221224
"composer_airflow_version": self.composer_airflow_version,
225+
"use_rest_api": self.use_rest_api,
222226
},
223227
)
224228

225229
async def _pull_dag_runs(self) -> list[dict]:
    """
    Pull the list of dag runs.

    With ``use_rest_api`` set, the runs are read from the ``dag_runs`` key of the async
    hook's ``get_dag_runs`` response; otherwise they are parsed from the JSON output of the
    ``dags list-runs`` Airflow command executed in the environment.
    """
    if not self.use_rest_api:
        # The DAG id is passed with "-d" for Airflow versions below 3 and positionally otherwise.
        if self.composer_airflow_version < 3:
            cmd_parameters = ["-d", self.composer_dag_id, "-o", "json"]
        else:
            cmd_parameters = [self.composer_dag_id, "-o", "json"]

        dag_runs_cmd = await self.gcp_hook.execute_airflow_command(
            project_id=self.project_id,
            region=self.region,
            environment_id=self.environment_id,
            command="dags",
            subcommand="list-runs",
            parameters=cmd_parameters,
        )
        cmd_result = await self.gcp_hook.wait_command_execution_result(
            project_id=self.project_id,
            region=self.region,
            environment_id=self.environment_id,
            execution_cmd_info=ExecuteAirflowCommandResponse.to_dict(dag_runs_cmd),
        )
        return json.loads(cmd_result["output"][0]["content"])

    # REST API path: the environment is looked up first to learn its web server URI.
    try:
        environment = await self.gcp_hook.get_environment(
            project_id=self.project_id,
            region=self.region,
            environment_id=self.environment_id,
        )
    except NotFound as not_found_err:
        self.log.info("The Composer environment %s does not exist.", self.environment_id)
        raise AirflowException(not_found_err)

    airflow_uri = environment.config.airflow_uri
    self.log.info(
        "Pulling the DAG %s runs from the %s environment...",
        self.composer_dag_id,
        self.environment_id,
    )
    rest_response = await self.gcp_hook.get_dag_runs(
        composer_airflow_uri=airflow_uri,
        composer_dag_id=self.composer_dag_id,
    )
    return rest_response["dag_runs"]
248275

249276
def _check_dag_runs_states(
@@ -271,7 +298,10 @@ def _get_async_hook(self) -> CloudComposerAsyncHook:
271298

272299
def _check_composer_dag_run_id_states(self, dag_runs: list[dict]) -> bool:
    """
    Tell whether the configured DAG run is present in ``dag_runs`` in an allowed state.

    The run identifier is read from the ``dag_run_id`` key when ``use_rest_api`` is set and
    from the ``run_id`` key otherwise.
    """
    return any(
        dag_run["dag_run_id" if self.use_rest_api else "run_id"] == self.composer_dag_run_id
        and dag_run["state"] in self.allowed_states
        for dag_run in dag_runs
    )
277307

0 commit comments

Comments
 (0)