diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py new file mode 100644 index 0000000000000..a00225fea0bbc --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/dags.py @@ -0,0 +1,26 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +from airflow.api_fastapi.core_api.base import BaseModel + + +class DagStateResponse(BaseModel): + """Schema for DAG State response.""" + + is_paused: bool diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index 562b8588fbf2c..2852ba2378fdc 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -25,6 +25,7 @@ assets, connections, dag_runs, + dags, health, hitl, task_instances, @@ -43,6 +44,7 @@ authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"]) authenticated_router.include_router(connections.router, prefix="/connections", tags=["Connections"]) authenticated_router.include_router(dag_runs.router, prefix="/dag-runs", tags=["Dag Runs"]) +authenticated_router.include_router(dags.router, prefix="/dags", tags=["Dags"]) authenticated_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"]) authenticated_router.include_router( task_reschedules.router, prefix="/task-reschedules", tags=["Task Reschedules"] diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py new file mode 100644 index 0000000000000..9b10393217c41 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/dags.py @@ -0,0 +1,57 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import logging + +from fastapi import APIRouter, HTTPException, status + +from airflow.api_fastapi.common.db.common import SessionDep +from airflow.api_fastapi.execution_api.datamodels.dags import DagStateResponse +from airflow.models.dag import DagModel + +router = APIRouter() + + +log = logging.getLogger(__name__) + + +@router.get( + "/{dag_id}/state", + responses={ + status.HTTP_404_NOT_FOUND: {"description": "DAG not found for the given dag_id"}, + }, +) +def get_dag_state( + dag_id: str, + session: SessionDep, +) -> DagStateResponse: + """Get the paused state of a Dag.""" + dag_model: DagModel = session.get(DagModel, dag_id) + if not dag_model: + raise HTTPException( + status.HTTP_404_NOT_FOUND, + detail={ + "reason": "not_found", + "message": f"The Dag with dag_id: `{dag_id}` was not found", + }, + ) + + is_paused = False if dag_model.is_paused is None else dag_model.is_paused + + return DagStateResponse(is_paused=is_paused) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py new file mode 100644 index 0000000000000..e03760697ac21 --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_dags.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import pytest + +from airflow.models import DagModel +from airflow.providers.standard.operators.empty import EmptyOperator + +from tests_common.test_utils.db import clear_db_runs + +pytestmark = pytest.mark.db_test + + +class TestDagState: + def setup_method(self): + clear_db_runs() + + def teardown_method(self): + clear_db_runs() + + @pytest.mark.parametrize( + ("state", "expected"), + [ + pytest.param(True, True), + pytest.param(False, False), + pytest.param(None, False), + ], + ) + def test_dag_is_paused(self, client, session, dag_maker, state, expected): + """Test DagState is active or paused""" + + dag_id = "test_dag_is_paused" + + with dag_maker(dag_id=dag_id, session=session, serialized=True): + EmptyOperator(task_id="test_task") + + session.query(DagModel).filter(DagModel.dag_id == dag_id).update({"is_paused": state}) + + session.commit() + + response = client.get( + f"/execution/dags/{dag_id}/state", + ) + + assert response.status_code == 200 + assert response.json() == {"is_paused": expected} + + def test_dag_not_found(self, client, session, dag_maker): + """Test Dag not found""" + + dag_id = "test_dag_is_paused" + + response = client.get( + f"/execution/dags/{dag_id}/state", + ) + + assert response.status_code == 404 + assert response.json() == { + "detail": { + "message": "The Dag with dag_id: `test_dag_is_paused` was not found", + "reason": 
"not_found", + } + } diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 1348a428d5fcb..063fc0dc99416 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -1847,6 +1847,7 @@ def get_type_names(union_type): "GetAssetEventByAssetAlias", "GetDagRun", "GetDagRunState", + "GetDagState", "GetDRCount", "GetTaskBreadcrumbs", "GetTaskRescheduleStartDate", diff --git a/airflow-core/tests/unit/jobs/test_triggerer_job.py b/airflow-core/tests/unit/jobs/test_triggerer_job.py index df31c9225c272..5483dfebfbd99 100644 --- a/airflow-core/tests/unit/jobs/test_triggerer_job.py +++ b/airflow-core/tests/unit/jobs/test_triggerer_job.py @@ -1219,6 +1219,7 @@ def get_type_names(union_type): "ResendLoggingFD", "CreateHITLDetailPayload", "SetRenderedMapIndex", + "GetDagState", } in_task_but_not_in_trigger_runner = { @@ -1238,6 +1239,7 @@ def get_type_names(union_type): "PreviousDagRunResult", "PreviousTIResult", "HITLDetailRequestResult", + "DagStateResult", } supervisor_diff = ( diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py index 2a38ef2fad7b3..3c308bf959c67 100644 --- a/task-sdk/src/airflow/sdk/api/client.py +++ b/task-sdk/src/airflow/sdk/api/client.py @@ -48,6 +48,7 @@ DagRun, DagRunStateResponse, DagRunType, + DagStateResponse, HITLDetailRequest, HITLDetailResponse, HITLUser, @@ -769,6 +770,18 @@ def get_previous( return PreviousDagRunResult(dag_run=resp.json()) +class DagsOperations: + __slots__ = ("client",) + + def __init__(self, client: Client): + self.client = client + + def get_state(self, dag_id: str) -> DagStateResponse: + """Get the state of a Dag via the API server.""" + resp = self.client.get(f"dags/{dag_id}/state") + return DagStateResponse.model_validate_json(resp.read()) + + class HITLOperations: """ Operations related to Human in the loop. Require Airflow 3.1+. 
@@ -1001,6 +1014,11 @@ def hitl(self): """Operations related to HITL Responses.""" return HITLOperations(self) + @lru_cache() # type: ignore[misc] + @property + def dags(self): + return DagsOperations(self) + # This is only used for parsing. ServerResponseError is raised instead class _ErrorBody(BaseModel): diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index 6a3a07f5e8a92..c9fae7f8f33e8 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -130,6 +130,14 @@ class DagRunType(str, Enum): ASSET_TRIGGERED = "asset_triggered" +class DagStateResponse(BaseModel): + """ + Schema for DAG State response. + """ + + is_paused: Annotated[bool, Field(title="Is Paused")] + + class HITLUser(BaseModel): """ Schema for a Human-in-the-loop users. diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py b/task-sdk/src/airflow/sdk/execution_time/comms.py index 52a96d0b665ea..ac6a0f7e06b9a 100644 --- a/task-sdk/src/airflow/sdk/execution_time/comms.py +++ b/task-sdk/src/airflow/sdk/execution_time/comms.py @@ -654,6 +654,11 @@ def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequest return cls(**hitl_request.model_dump(exclude_defaults=True), type="HITLDetailRequestResult") +class DagStateResult(BaseModel): + is_paused: bool + type: Literal["DagStateResult"] = "DagStateResult" + + ToTask = Annotated[ AssetResult | AssetEventsResult @@ -661,6 +666,7 @@ def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequest | DagRunResult | DagRunStateResult | DRCount + | DagStateResult | ErrorResponse | PrevSuccessfulDagRunResult | PreviousTIResult @@ -978,6 +984,11 @@ class MaskSecret(BaseModel): type: Literal["MaskSecret"] = "MaskSecret" +class GetDagState(BaseModel): + dag_id: str + type: Literal["GetDagState"] = "GetDagState" + + ToSupervisor = Annotated[ DeferTask | DeleteXCom @@ -989,6 +1000,7 @@ 
class MaskSecret(BaseModel): | GetDagRun | GetDagRunState | GetDRCount + | GetDagState | GetPrevSuccessfulDagRun | GetPreviousDagRun | GetPreviousTI diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py b/task-sdk/src/airflow/sdk/execution_time/supervisor.py index b87131aa7336d..1bcdae2ee79a0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py +++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py @@ -79,6 +79,7 @@ GetConnection, GetDagRun, GetDagRunState, + GetDagState, GetDRCount, GetPreviousDagRun, GetPreviousTI, @@ -1459,6 +1460,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id: dump_opts = {"exclude_unset": True} elif isinstance(msg, MaskSecret): mask_secret(msg.value, msg.name) + elif isinstance(msg, GetDagState): + resp = self.client.dags.get_state( + dag_id=msg.dag_id, + ) else: log.error("Unhandled request", msg=msg) self.send_msg( diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py index cb76097b9fefc..2778b6ef9f2d0 100644 --- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py +++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py @@ -70,10 +70,12 @@ AssetEventDagRunReferenceResult, CommsDecoder, DagRunStateResult, + DagStateResult, DeferTask, DRCount, ErrorResponse, GetDagRunState, + GetDagState, GetDRCount, GetPreviousDagRun, GetPreviousTI, @@ -620,6 +622,16 @@ def get_dagrun_state(dag_id: str, run_id: str) -> str: return response.state + @staticmethod + def get_dag_state(dag_id: str) -> DagStateResult: + """Return the state of the Dag with the given Dag ID.""" + response = SUPERVISOR_COMMS.send(msg=GetDagState(dag_id=dag_id)) + + if TYPE_CHECKING: + assert isinstance(response, DagStateResult) + + return response + @property def log_url(self) -> str: run_id = quote(self.run_id) diff --git a/task-sdk/tests/task_sdk/api/test_client.py b/task-sdk/tests/task_sdk/api/test_client.py index 
ea9d403fe21ad..df98fd94b8106 100644 --- a/task-sdk/tests/task_sdk/api/test_client.py +++ b/task-sdk/tests/task_sdk/api/test_client.py @@ -39,6 +39,7 @@ ConnectionResponse, DagRunState, DagRunStateResponse, + DagStateResponse, HITLDetailRequest, HITLDetailResponse, HITLUser, @@ -1536,3 +1537,21 @@ def test_cache_miss_on_different_parameters(self): assert ctx1 is not ctx2 assert info.misses == 2 assert info.currsize == 2 + + +class TestDagsOperations: + def test_get_state(self): + """Test that the client can get the state of a dag""" + + def handle_request(request: httpx.Request) -> httpx.Response: + if request.url.path == "/dags/test_dag/state": + return httpx.Response( + status_code=200, + json={"is_paused": False}, + ) + return httpx.Response(status_code=200) + + client = make_client(transport=httpx.MockTransport(handle_request)) + result = client.dags.get_state(dag_id="test_dag") + + assert result == DagStateResponse(is_paused=False) diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py index b45f57cd4256e..54303327a53eb 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py +++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py @@ -73,6 +73,7 @@ CreateHITLDetailPayload, DagRunResult, DagRunStateResult, + DagStateResult, DeferTask, DeleteVariable, DeleteXCom, @@ -85,6 +86,7 @@ GetConnection, GetDagRun, GetDagRunState, + GetDagState, GetDRCount, GetHITLDetailResponse, GetPreviousDagRun, @@ -2461,6 +2463,18 @@ class RequestTestCase: }, test_id="get_task_breadcrumbs", ), + RequestTestCase( + message=GetDagState(dag_id="test_dag"), + expected_body={"is_paused": False, "type": "DagStateResult"}, + client_mock=ClientMock( + method_path="dags.get_state", + kwargs={ + "dag_id": "test_dag", + }, + response=DagStateResult(is_paused=False), + ), + test_id="get_dag_state", + ), ] diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py 
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py index e6a179f7189c5..bdfa4a6363ecb 100644 --- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py +++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py @@ -78,11 +78,13 @@ BundleInfo, ConnectionResult, DagRunStateResult, + DagStateResult, DeferTask, DRCount, ErrorResponse, GetConnection, GetDagRunState, + GetDagState, GetDRCount, GetPreviousDagRun, GetPreviousTI, @@ -2524,6 +2526,19 @@ def execute(self, context): in mock_supervisor_comms.send.mock_calls ) + def test_get_dag_state(self, mock_supervisor_comms): + """Test that get_dag_state sends the correct request and returns the state.""" + mock_supervisor_comms.send.return_value = DagStateResult(is_paused=False) + + response = RuntimeTaskInstance.get_dag_state( + dag_id="test_dag", + ) + + mock_supervisor_comms.send.assert_called_once_with( + msg=GetDagState(dag_id="test_dag"), + ) + assert response.is_paused is False + class TestXComAfterTaskExecution: @pytest.mark.parametrize(