Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

from airflow.api_fastapi.core_api.base import BaseModel


class DagStateResponse(BaseModel):
"""Schema for DAG State response."""

is_paused: bool
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
assets,
connections,
dag_runs,
dags,
health,
hitl,
task_instances,
Expand All @@ -43,6 +44,7 @@
authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"])
authenticated_router.include_router(connections.router, prefix="/connections", tags=["Connections"])
authenticated_router.include_router(dag_runs.router, prefix="/dag-runs", tags=["Dag Runs"])
authenticated_router.include_router(dags.router, prefix="/dags", tags=["Dags"])
authenticated_router.include_router(task_instances.router, prefix="/task-instances", tags=["Task Instances"])
authenticated_router.include_router(
task_reschedules.router, prefix="/task-reschedules", tags=["Task Reschedules"]
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging

from fastapi import APIRouter, HTTPException, status

from airflow.api_fastapi.common.db.common import SessionDep
from airflow.api_fastapi.execution_api.datamodels.dags import DagStateResponse
from airflow.models.dag import DagModel

router = APIRouter()


log = logging.getLogger(__name__)


@router.get(
"/{dag_id}/state",
responses={
status.HTTP_404_NOT_FOUND: {"description": "DAG not found for the given dag_id"},
},
)
def get_dag_state(
dag_id: str,
session: SessionDep,
) -> DagStateResponse:
"""Get a DAG Run State."""
dag_model: DagModel = session.get(DagModel, dag_id)
if not dag_model:
raise HTTPException(
status.HTTP_404_NOT_FOUND,
detail={
"reason": "not_found",
"message": f"The Dag with dag_id: `{dag_id}` was not found",
},
)

is_paused = False if dag_model.is_paused is None else dag_model.is_paused

return DagStateResponse(is_paused=is_paused)
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import pytest

from airflow.models import DagModel
from airflow.providers.standard.operators.empty import EmptyOperator

from tests_common.test_utils.db import clear_db_runs

pytestmark = pytest.mark.db_test


class TestDagState:
def setup_method(self):
clear_db_runs()

def teardown_method(self):
clear_db_runs()

@pytest.mark.parametrize(
("state", "expected"),
[
pytest.param(True, True),
pytest.param(False, False),
pytest.param(None, False),
],
)
def test_dag_is_paused(self, client, session, dag_maker, state, expected):
"""Test DagState is active or paused"""

dag_id = "test_dag_is_paused"

with dag_maker(dag_id=dag_id, session=session, serialized=True):
EmptyOperator(task_id="test_task")

session.query(DagModel).filter(DagModel.dag_id == dag_id).update({"is_paused": state})

session.commit()

response = client.get(
f"/execution/dags/{dag_id}/state",
)

assert response.status_code == 200
assert response.json() == {"is_paused": expected}

def test_dag_not_found(self, client, session, dag_maker):
"""Test Dag not found"""

dag_id = "test_dag_is_paused"

response = client.get(
f"/execution/dags/{dag_id}/state",
)

assert response.status_code == 404
assert response.json() == {
"detail": {
"message": "The Dag with dag_id: `test_dag_is_paused` was not found",
"reason": "not_found",
}
}
1 change: 1 addition & 0 deletions airflow-core/tests/unit/dag_processing/test_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1847,6 +1847,7 @@ def get_type_names(union_type):
"GetAssetEventByAssetAlias",
"GetDagRun",
"GetDagRunState",
"GetDagState",
"GetDRCount",
"GetTaskBreadcrumbs",
"GetTaskRescheduleStartDate",
Expand Down
2 changes: 2 additions & 0 deletions airflow-core/tests/unit/jobs/test_triggerer_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -1219,6 +1219,7 @@ def get_type_names(union_type):
"ResendLoggingFD",
"CreateHITLDetailPayload",
"SetRenderedMapIndex",
"GetDagState",
}

in_task_but_not_in_trigger_runner = {
Expand All @@ -1238,6 +1239,7 @@ def get_type_names(union_type):
"PreviousDagRunResult",
"PreviousTIResult",
"HITLDetailRequestResult",
"DagStateResult",
}

supervisor_diff = (
Expand Down
18 changes: 18 additions & 0 deletions task-sdk/src/airflow/sdk/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
DagRun,
DagRunStateResponse,
DagRunType,
DagStateResponse,
HITLDetailRequest,
HITLDetailResponse,
HITLUser,
Expand Down Expand Up @@ -769,6 +770,18 @@ def get_previous(
return PreviousDagRunResult(dag_run=resp.json())


class DagsOperations:
__slots__ = ("client",)

def __init__(self, client: Client):
self.client = client

def get_state(self, dag_id: str) -> DagStateResponse:
"""Get the state of a Dag via the API server."""
resp = self.client.get(f"dags/{dag_id}/state")
return DagStateResponse.model_validate_json(resp.read())


class HITLOperations:
"""
Operations related to Human in the loop. Require Airflow 3.1+.
Expand Down Expand Up @@ -1001,6 +1014,11 @@ def hitl(self):
"""Operations related to HITL Responses."""
return HITLOperations(self)

@lru_cache() # type: ignore[misc]
@property
def dags(self):
return DagsOperations(self)


# This is only used for parsing. ServerResponseError is raised instead
class _ErrorBody(BaseModel):
Expand Down
8 changes: 8 additions & 0 deletions task-sdk/src/airflow/sdk/api/datamodels/_generated.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,14 @@ class DagRunType(str, Enum):
ASSET_TRIGGERED = "asset_triggered"


class DagStateResponse(BaseModel):
"""
Schema for DAG State response.
"""

is_paused: Annotated[bool, Field(title="Is Paused")]


class HITLUser(BaseModel):
"""
Schema for a Human-in-the-loop users.
Expand Down
12 changes: 12 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,13 +654,19 @@ def from_api_response(cls, hitl_request: HITLDetailRequest) -> HITLDetailRequest
return cls(**hitl_request.model_dump(exclude_defaults=True), type="HITLDetailRequestResult")


class DagStateResult(BaseModel):
is_paused: bool
type: Literal["DagStateResult"] = "DagStateResult"


ToTask = Annotated[
AssetResult
| AssetEventsResult
| ConnectionResult
| DagRunResult
| DagRunStateResult
| DRCount
| DagStateResult
| ErrorResponse
| PrevSuccessfulDagRunResult
| PreviousTIResult
Expand Down Expand Up @@ -978,6 +984,11 @@ class MaskSecret(BaseModel):
type: Literal["MaskSecret"] = "MaskSecret"


class GetDagState(BaseModel):
dag_id: str
type: Literal["GetDagState"] = "GetDagState"


ToSupervisor = Annotated[
DeferTask
| DeleteXCom
Expand All @@ -989,6 +1000,7 @@ class MaskSecret(BaseModel):
| GetDagRun
| GetDagRunState
| GetDRCount
| GetDagState
| GetPrevSuccessfulDagRun
| GetPreviousDagRun
| GetPreviousTI
Expand Down
5 changes: 5 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
GetConnection,
GetDagRun,
GetDagRunState,
GetDagState,
GetDRCount,
GetPreviousDagRun,
GetPreviousTI,
Expand Down Expand Up @@ -1459,6 +1460,10 @@ def _handle_request(self, msg: ToSupervisor, log: FilteringBoundLogger, req_id:
dump_opts = {"exclude_unset": True}
elif isinstance(msg, MaskSecret):
mask_secret(msg.value, msg.name)
elif isinstance(msg, GetDagState):
resp = self.client.dags.get_state(
dag_id=msg.dag_id,
)
else:
log.error("Unhandled request", msg=msg)
self.send_msg(
Expand Down
12 changes: 12 additions & 0 deletions task-sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,12 @@
AssetEventDagRunReferenceResult,
CommsDecoder,
DagRunStateResult,
DagStateResult,
DeferTask,
DRCount,
ErrorResponse,
GetDagRunState,
GetDagState,
GetDRCount,
GetPreviousDagRun,
GetPreviousTI,
Expand Down Expand Up @@ -620,6 +622,16 @@ def get_dagrun_state(dag_id: str, run_id: str) -> str:

return response.state

@staticmethod
def get_dag_state(dag_id: str) -> DagStateResult:
"""Return the state of the Dag run with the given Run ID."""
response = SUPERVISOR_COMMS.send(msg=GetDagState(dag_id=dag_id))

if TYPE_CHECKING:
assert isinstance(response, DagStateResult)

return response

@property
def log_url(self) -> str:
run_id = quote(self.run_id)
Expand Down
19 changes: 19 additions & 0 deletions task-sdk/tests/task_sdk/api/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
ConnectionResponse,
DagRunState,
DagRunStateResponse,
DagStateResponse,
HITLDetailRequest,
HITLDetailResponse,
HITLUser,
Expand Down Expand Up @@ -1536,3 +1537,21 @@ def test_cache_miss_on_different_parameters(self):
assert ctx1 is not ctx2
assert info.misses == 2
assert info.currsize == 2


class TestDagsOperations:
def test_get_state(self):
"""Test that the client can get the state of a dag run"""

def handle_request(request: httpx.Request) -> httpx.Response:
if request.url.path == "/dags/test_dag/state":
return httpx.Response(
status_code=200,
json={"is_paused": False},
)
return httpx.Response(status_code=200)

client = make_client(transport=httpx.MockTransport(handle_request))
result = client.dags.get_state(dag_id="test_dag")

assert result == DagStateResponse(is_paused=False)
Loading
Loading