Skip to content

Commit 1691bc8

Browse files
committed
Add GetDagState endpoint to execution_api
1 parent 209fb2e commit 1691bc8

File tree

11 files changed

+253
-0
lines changed

11 files changed

+253
-0
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
from airflow.api_fastapi.core_api.base import BaseModel
21+
22+
23+
class DagStateResponse(BaseModel):
24+
"""Schema for DAG State response."""
25+
26+
is_paused: bool

airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
assets,
2626
connections,
2727
dag_runs,
28+
dags,
2829
health,
2930
hitl,
3031
task_instances,
@@ -43,6 +44,7 @@
4344
authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"])
4445
authenticated_router.include_router(connections.router, prefix="/connections", tags=["Connections"])
4546
authenticated_router.include_router(dag_runs.router, prefix="/dag-runs", tags=["Dag Runs"])
47+
authenticated_router.include_router(dags.router, prefix="/dags", tags=["Dags"])
4648
authenticated_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"])
4749
authenticated_router.include_router(
4850
task_reschedules.router, prefix="/task-reschedules", tags=["Task Reschedules"]
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
import logging
21+
22+
from fastapi import APIRouter, HTTPException, status
23+
24+
from airflow.api_fastapi.common.db.common import SessionDep
25+
from airflow.api_fastapi.execution_api.datamodels.dags import DagStateResponse
26+
from airflow.models.dag import DagModel
27+
28+
router = APIRouter()
29+
30+
31+
log = logging.getLogger(__name__)
32+
33+
34+
@router.get(
35+
"/{dag_id}/state",
36+
responses={
37+
status.HTTP_404_NOT_FOUND: {"description": "DAG not found for the given dag_id"},
38+
},
39+
)
40+
def get_dag_state(
41+
dag_id: str,
42+
session: SessionDep,
43+
) -> DagStateResponse:
44+
"""Get a DAG Run State."""
45+
dag_model: DagModel = session.get(DagModel, dag_id)
46+
if not dag_model:
47+
raise HTTPException(
48+
status.HTTP_404_NOT_FOUND,
49+
detail={
50+
"reason": "not_found",
51+
"message": f"The Dag with dag_id: `{dag_id}` was not found",
52+
},
53+
)
54+
55+
is_paused = False if dag_model.is_paused is None else dag_model.is_paused
56+
57+
return DagStateResponse(is_paused=is_paused)
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
from __future__ import annotations
19+
20+
import pytest
21+
22+
from airflow.models import DagModel
23+
from airflow.providers.standard.operators.empty import EmptyOperator
24+
25+
from tests_common.test_utils.db import clear_db_runs
26+
27+
pytestmark = pytest.mark.db_test
28+
29+
30+
class TestDagState:
31+
def setup_method(self):
32+
clear_db_runs()
33+
34+
def teardown_method(self):
35+
clear_db_runs()
36+
37+
@pytest.mark.parametrize(
38+
"state, expected",
39+
[
40+
(True, True),
41+
(False, False),
42+
(None, False),
43+
],
44+
)
45+
def test_dag_is_paused(self, state, expected, client, session, dag_maker):
46+
"""Test DagState is active or paused"""
47+
48+
dag_id = "test_dag_is_paused"
49+
50+
with dag_maker(dag_id=dag_id, session=session, serialized=True):
51+
EmptyOperator(task_id="test_task")
52+
53+
session.query(DagModel).filter(DagModel.dag_id == dag_id).update({"is_paused": state})
54+
55+
session.commit()
56+
57+
response = client.get(
58+
f"/execution/dags/{dag_id}/state",
59+
)
60+
61+
assert response.status_code == 200
62+
assert response.json() == {"is_paused": expected}
63+
64+
def test_dag_not_found(self, client, session, dag_maker):
65+
"""Test Dag not found"""
66+
67+
dag_id = "test_dag_is_paused"
68+
69+
response = client.get(
70+
f"/execution/dags/{dag_id}/state",
71+
)
72+
73+
assert response.status_code == 404
74+
assert response.json() == {
75+
"detail": {
76+
"message": "The Dag with dag_id: `test_dag_is_paused` was not found",
77+
"reason": "not_found",
78+
}
79+
}

task-sdk/src/airflow/sdk/api/client.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
ConnectionResponse,
4444
DagRunStateResponse,
4545
DagRunType,
46+
DagStateResponse,
4647
HITLDetailResponse,
4748
HITLUser,
4849
InactiveAssetsResponse,
@@ -721,6 +722,18 @@ def get_previous(
721722
return PreviousDagRunResult(dag_run=resp.json())
722723

723724

725+
class DagsOperations:
726+
__slots__ = ("client",)
727+
728+
def __init__(self, client: Client):
729+
self.client = client
730+
731+
def get_state(self, dag_id: str) -> DagStateResponse:
732+
"""Get the state of a Dag via the API server."""
733+
resp = self.client.get(f"dags/{dag_id}/state")
734+
return DagStateResponse.model_validate_json(resp.read())
735+
736+
724737
class HITLOperations:
725738
"""
726739
Operations related to Human in the loop. Require Airflow 3.1+.
@@ -931,6 +944,11 @@ def hitl(self):
931944
"""Operations related to HITL Responses."""
932945
return HITLOperations(self)
933946

947+
@lru_cache() # type: ignore[misc]
948+
@property
949+
def dags(self):
950+
return DagsOperations(self)
951+
934952

935953
# This is only used for parsing. ServerResponseError is raised instead
936954
class _ErrorBody(BaseModel):

task-sdk/src/airflow/sdk/api/datamodels/_generated.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,14 @@ class DagRunType(str, Enum):
154154
ASSET_TRIGGERED = "asset_triggered"
155155

156156

157+
class DagStateResponse(BaseModel):
158+
"""
159+
Schema for DAG State response.
160+
"""
161+
162+
is_paused: Annotated[bool, Field(title="Is Paused")]
163+
164+
157165
class HITLUser(BaseModel):
158166
"""
159167
Schema for a Human-in-the-loop users.

task-sdk/src/airflow/sdk/execution_time/comms.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,12 +585,18 @@ class HITLDetailRequestResult(HITLDetailRequest):
585585
type: Literal["HITLDetailRequestResult"] = "HITLDetailRequestResult"
586586

587587

588+
class DagStateResult(BaseModel):
589+
is_paused: bool
590+
type: Literal["DagStateResult"] = "DagStateResult"
591+
592+
588593
ToTask = Annotated[
589594
AssetResult
590595
| AssetEventsResult
591596
| ConnectionResult
592597
| DagRunStateResult
593598
| DRCount
599+
| DagStateResult
594600
| ErrorResponse
595601
| PrevSuccessfulDagRunResult
596602
| SentFDs
@@ -910,6 +916,11 @@ class MaskSecret(BaseModel):
910916
type: Literal["MaskSecret"] = "MaskSecret"
911917

912918

919+
class GetDagState(BaseModel):
920+
dag_id: str
921+
type: Literal["GetDagState"] = "GetDagState"
922+
923+
913924
ToSupervisor = Annotated[
914925
DeferTask
915926
| DeleteXCom
@@ -920,6 +931,7 @@ class MaskSecret(BaseModel):
920931
| GetConnection
921932
| GetDagRunState
922933
| GetDRCount
934+
| GetDagState
923935
| GetPrevSuccessfulDagRun
924936
| GetPreviousDagRun
925937
| GetTaskRescheduleStartDate

task-sdk/src/airflow/sdk/execution_time/supervisor.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@
8383
GetAssetEventByAssetAlias,
8484
GetConnection,
8585
GetDagRunState,
86+
GetDagState,
8687
GetDRCount,
8788
GetPreviousDagRun,
8889
GetPrevSuccessfulDagRun,
@@ -1382,6 +1383,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
13821383
self.send_msg(resp, request_id=req_id, error=None, **dump_opts)
13831384
elif isinstance(msg, MaskSecret):
13841385
mask_secret(msg.value, msg.name)
1386+
elif isinstance(msg, GetDagState):
1387+
resp = self.client.dags.get_state(
1388+
dag_id=msg.dag_id,
1389+
)
13851390
else:
13861391
log.error("Unhandled request", msg=msg)
13871392
self.send_msg(

task-sdk/src/airflow/sdk/execution_time/task_runner.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,10 +64,12 @@
6464
AssetEventDagRunReferenceResult,
6565
CommsDecoder,
6666
DagRunStateResult,
67+
DagStateResult,
6768
DeferTask,
6869
DRCount,
6970
ErrorResponse,
7071
GetDagRunState,
72+
GetDagState,
7173
GetDRCount,
7274
GetPreviousDagRun,
7375
GetTaskRescheduleStartDate,
@@ -550,6 +552,16 @@ def get_dagrun_state(dag_id: str, run_id: str) -> str:
550552

551553
return response.state
552554

555+
@staticmethod
556+
def get_dag_state(dag_id: str) -> DagStateResult:
557+
"""Return the state of the Dag run with the given Run ID."""
558+
response = SUPERVISOR_COMMS.send(msg=GetDagState(dag_id=dag_id))
559+
560+
if TYPE_CHECKING:
561+
assert isinstance(response, DagStateResult)
562+
563+
return response
564+
553565
@property
554566
def log_url(self) -> str:
555567
run_id = quote(self.run_id)

task-sdk/tests/task_sdk/api/test_client.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
ConnectionResponse,
3838
DagRunState,
3939
DagRunStateResponse,
40+
DagStateResponse,
4041
HITLDetailResponse,
4142
HITLUser,
4243
VariableResponse,
@@ -1349,3 +1350,21 @@ def handle_request(request: httpx.Request) -> httpx.Response:
13491350
assert result.params_input == {}
13501351
assert result.responded_by_user == HITLUser(id="admin", name="admin")
13511352
assert result.responded_at == timezone.datetime(2025, 7, 3, 0, 0, 0)
1353+
1354+
1355+
class TestDagsOperations:
1356+
def test_get_state(self):
1357+
"""Test that the client can get the state of a dag run"""
1358+
1359+
def handle_request(request: httpx.Request) -> httpx.Response:
1360+
if request.url.path == "/dags/test_dag/state":
1361+
return httpx.Response(
1362+
status_code=200,
1363+
json={"is_paused": False},
1364+
)
1365+
return httpx.Response(status_code=200)
1366+
1367+
client = make_client(transport=httpx.MockTransport(handle_request))
1368+
result = client.dags.get_state(dag_id="test_dag")
1369+
1370+
assert result == DagStateResponse(is_paused=False)

0 commit comments

Comments
 (0)