Skip to content

Commit 5fef698

Browse files
authored
Fix type hints (#61317)
1 parent 9f168b4 commit 5fef698

File tree

30 files changed

+152
-78
lines changed

30 files changed

+152
-78
lines changed

airflow-core/src/airflow/api_fastapi/auth/managers/base_auth_manager.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,15 @@
4747
from airflow.models import Connection, DagModel, Pool, Variable
4848
from airflow.models.dagbundle import DagBundleModel
4949
from airflow.models.team import Team, dag_bundle_team_association_table
50+
from airflow.typing_compat import Unpack
5051
from airflow.utils.log.logging_mixin import LoggingMixin
5152
from airflow.utils.session import NEW_SESSION, provide_session
5253

5354
if TYPE_CHECKING:
5455
from collections.abc import Sequence
5556

5657
from fastapi import FastAPI
58+
from sqlalchemy import Row
5759
from sqlalchemy.orm import Session
5860

5961
from airflow.api_fastapi.auth.managers.models.batch_apis import (
@@ -569,8 +571,9 @@ def get_authorized_dag_ids(
569571
isouter=True,
570572
)
571573
)
572-
rows = session.execute(stmt).all()
573-
dags_by_team: dict[str | None, set[str]] = defaultdict(set)
574+
# The below type annotation is acceptable on SQLA2.1, but not on 2.0
575+
rows: Sequence[Row[Unpack[tuple[str, str]]]] = session.execute(stmt).all() # type: ignore[type-arg]
576+
dags_by_team: dict[str, set[str]] = defaultdict(set)
574577
for dag_id, team_name in rows:
575578
dags_by_team[team_name].add(dag_id)
576579

airflow-core/src/airflow/api_fastapi/core_api/routes/public/assets.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,10 +74,12 @@
7474
AssetWatcherModel,
7575
TaskOutletAssetReference,
7676
)
77+
from airflow.typing_compat import Unpack
7778
from airflow.utils.state import DagRunState
7879
from airflow.utils.types import DagRunTriggeredByType, DagRunType
7980

8081
if TYPE_CHECKING:
82+
from sqlalchemy.engine import Result
8183
from sqlalchemy.sql import Select
8284

8385
assets_router = AirflowRouter(tags=["Asset"])
@@ -179,7 +181,8 @@ def get_assets(
179181
session=session,
180182
)
181183

182-
assets_rows = session.execute(
184+
# The below type annotation is acceptable on SQLA2.1, but not on 2.0
185+
assets_rows: Result[Unpack[tuple[AssetModel, int, datetime]]] = session.execute( # type: ignore[type-arg]
183186
assets_select.options(
184187
subqueryload(AssetModel.scheduled_dags),
185188
subqueryload(AssetModel.producing_tasks),

airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_stats.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
from __future__ import annotations
1919

20-
from typing import Annotated
20+
from typing import TYPE_CHECKING, Annotated
2121

2222
from fastapi import Depends, status
2323

@@ -41,8 +41,12 @@
4141
from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc
4242
from airflow.api_fastapi.core_api.security import ReadableDagRunsFilterDep, requires_access_dag
4343
from airflow.models.dagrun import DagRun
44+
from airflow.typing_compat import Unpack
4445
from airflow.utils.state import DagRunState
4546

47+
if TYPE_CHECKING:
48+
from sqlalchemy import Result
49+
4650
dag_stats_router = AirflowRouter(tags=["DagStats"], prefix="/dagStats")
4751

4852

@@ -71,7 +75,8 @@ def get_dag_stats(
7175
session=session,
7276
return_total_entries=False,
7377
)
74-
query_result = session.execute(dagruns_select)
78+
# The below type annotation is acceptable on SQLA2.1, but not on 2.0
79+
query_result: Result[Unpack[tuple[str, str, str, int]]] = session.execute(dagruns_select) # type: ignore[type-arg]
7580

7681
result_dag_ids = []
7782
dag_display_names: dict[str, str] = {}

airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_tags.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from __future__ import annotations
1919

20+
from collections.abc import Sequence
2021
from typing import Annotated
2122

2223
from fastapi import Depends
@@ -67,5 +68,5 @@ def get_dag_tags(
6768
limit=limit,
6869
session=session,
6970
)
70-
dag_tags = session.execute(dag_tags_select).scalars().all()
71+
dag_tags: Sequence = session.execute(dag_tags_select).scalars().all()
7172
return DAGTagCollectionResponse(tags=[x for x in dag_tags], total_entries=total_entries)

airflow-core/src/airflow/api_fastapi/core_api/routes/public/xcom.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,11 @@ def get_xcom_entry(
9393
# We use `BaseXCom.get_many` to fetch XComs directly from the database, bypassing the XCom Backend.
9494
# This avoids deserialization via the backend (e.g., from a remote storage like S3) and instead
9595
# retrieves the raw serialized value from the database.
96-
result = session.scalars(xcom_query).first()
96+
raw_result: tuple[XComModel] | None = session.scalars(xcom_query).first()
9797

98-
if result is None:
98+
if raw_result is None:
9999
raise HTTPException(status.HTTP_404_NOT_FOUND, f"XCom entry with key: `{xcom_key}` not found")
100+
result = raw_result[0] if isinstance(raw_result, tuple) else raw_result
100101

101102
item = copy.copy(result)
102103

airflow-core/src/airflow/api_fastapi/core_api/services/public/dag_run.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,8 @@
3535
if TYPE_CHECKING:
3636
from collections.abc import AsyncGenerator, Iterator
3737

38+
from sqlalchemy import ScalarResult
39+
3840

3941
@attrs.define
4042
class DagRunWaiter:
@@ -57,10 +59,12 @@ def _serialize_xcoms(self) -> dict[str, Any]:
5759
task_ids=self.result_task_ids,
5860
dag_ids=self.dag_id,
5961
)
60-
xcom_results = self.session.scalars(xcom_query.order_by(XComModel.task_id, XComModel.map_index))
62+
xcom_results: ScalarResult[tuple[XComModel]] = self.session.scalars(
63+
xcom_query.order_by(XComModel.task_id, XComModel.map_index)
64+
)
6165

62-
def _group_xcoms(g: Iterator[XComModel]) -> Any:
63-
entries = list(g)
66+
def _group_xcoms(g: Iterator[XComModel | tuple[XComModel]]) -> Any:
67+
entries = [row[0] if isinstance(row, tuple) else row for row in g]
6468
if len(entries) == 1 and entries[0].map_index < 0: # Unpack non-mapped task xcom.
6569
return entries[0].value
6670
return [entry.value for entry in entries] # Task is mapped; return all xcoms in a list.

airflow-core/src/airflow/api_fastapi/execution_api/routes/dag_runs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def get_dagrun_state(
190190
) -> DagRunStateResponse:
191191
"""Get a Dag run State."""
192192
try:
193-
state = session.scalars(
193+
state: DagRunState = session.scalars(
194194
select(DagRunModel.state).where(DagRunModel.dag_id == dag_id, DagRunModel.run_id == run_id)
195195
).one()
196196
except NoResultFound:

airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ def get_mapped_xcom_by_index(
113113
else:
114114
xcom_query = xcom_query.order_by(XComModel.map_index.desc()).offset(-1 - offset)
115115

116+
result: tuple[XComModel] | None
116117
if (result := session.scalars(xcom_query).first()) is None:
117118
message = (
118119
f"XCom with {key=} {offset=} not found for task {task_id!r} in DAG run {run_id!r} of {dag_id!r}"
@@ -121,7 +122,7 @@ def get_mapped_xcom_by_index(
121122
status_code=status.HTTP_404_NOT_FOUND,
122123
detail={"reason": "not_found", "message": message},
123124
)
124-
return XComSequenceIndexResponse(result.value)
125+
return XComSequenceIndexResponse((result[0] if isinstance(result, tuple) else result).value)
125126

126127

127128
class GetXComSliceFilterParams(BaseModel):
@@ -291,8 +292,8 @@ def get_xcom(
291292
# retrieves the raw serialized value from the database. By not relying on `XCom.get_many` or `XCom.get_one`
292293
# (which automatically deserializes using the backend), we avoid potential
293294
# performance hits from retrieving large data files into the API server.
294-
result = session.scalars(xcom_query).first()
295-
if result is None:
295+
result: tuple[XComModel] | None
296+
if (result := session.scalars(xcom_query).first()) is None:
296297
if params.offset is None:
297298
message = (
298299
f"XCom with {key=} map_index={params.map_index} not found for "
@@ -308,7 +309,7 @@ def get_xcom(
308309
detail={"reason": "not_found", "message": message},
309310
)
310311

311-
return XComResponse(key=key, value=result.value)
312+
return XComResponse(key=key, value=(result[0] if isinstance(result, tuple) else result).value)
312313

313314

314315
# TODO: once we have JWT tokens, then remove dag_id/run_id/task_id from the URL and just use the info in

airflow-core/src/airflow/dag_processing/collection.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from __future__ import annotations
2929

3030
import traceback
31-
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, cast
31+
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar
3232

3333
import structlog
3434
from sqlalchemy import delete, func, insert, select, tuple_, update
@@ -76,7 +76,7 @@
7676
from sqlalchemy.sql import Select
7777

7878
from airflow.models.dagwarning import DagWarning
79-
from airflow.typing_compat import Self
79+
from airflow.typing_compat import Self, Unpack
8080

8181
AssetT = TypeVar("AssetT", SerializedAsset, SerializedAssetAlias)
8282

@@ -512,15 +512,18 @@ class DagModelOperation(NamedTuple):
512512

513513
def find_orm_dags(self, *, session: Session) -> dict[str, DagModel]:
514514
"""Find existing DagModel objects from DAG objects."""
515-
stmt = (
516-
select(DagModel)
517-
.options(joinedload(DagModel.tags, innerjoin=False))
518-
.where(DagModel.dag_id.in_(self.dags))
519-
.options(joinedload(DagModel.schedule_asset_references))
520-
.options(joinedload(DagModel.schedule_asset_alias_references))
521-
.options(joinedload(DagModel.task_outlet_asset_references))
515+
stmt: Select[Unpack[tuple[DagModel]]] = with_row_locks(
516+
(
517+
select(DagModel)
518+
.options(joinedload(DagModel.tags, innerjoin=False))
519+
.where(DagModel.dag_id.in_(self.dags))
520+
.options(joinedload(DagModel.schedule_asset_references))
521+
.options(joinedload(DagModel.schedule_asset_alias_references))
522+
.options(joinedload(DagModel.task_outlet_asset_references))
523+
),
524+
of=DagModel,
525+
session=session,
522526
)
523-
stmt = cast("Select[tuple[DagModel]]", with_row_locks(stmt, of=DagModel, session=session))
524527
return {dm.dag_id: dm for dm in session.scalars(stmt).unique()}
525528

526529
def add_dags(self, *, session: Session) -> dict[str, DagModel]:
@@ -711,7 +714,7 @@ def _find_all_asset_aliases(dags: Iterable[LazyDeserializedDAG]) -> Iterator[Ser
711714

712715
def _find_active_assets(name_uri_assets: Iterable[tuple[str, str]], session: Session) -> set[tuple[str, str]]:
713716
return {
714-
tuple(row)
717+
(str(row[0]), str(row[1]))
715718
for row in session.execute(
716719
select(AssetModel.name, AssetModel.uri).where(
717720
tuple_(AssetModel.name, AssetModel.uri).in_(name_uri_assets),
@@ -906,7 +909,7 @@ def _add_dag_asset_references(
906909
if not references:
907910
return
908911
orm_refs = {
909-
tuple(row)
912+
(str(row[0]), str(row[1]))
910913
for row in session.execute(
911914
select(model.dag_id, getattr(model, attr)).where(
912915
model.dag_id.in_(dag_id for dag_id, _ in references)

airflow-core/src/airflow/dag_processing/manager.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@
7575
from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks
7676

7777
if TYPE_CHECKING:
78-
from collections.abc import Callable, Iterable, Iterator
78+
from collections.abc import Callable, Iterable, Iterator, Sequence
7979
from socket import socket
8080

8181
from sqlalchemy.orm import Session
@@ -497,15 +497,17 @@ def _fetch_callbacks(
497497
callback_queue: list[CallbackRequest] = []
498498
with prohibit_commit(session) as guard:
499499
bundle_names = [bundle.name for bundle in self._dag_bundles]
500-
query: Select[tuple[DbCallbackRequest]] = select(DbCallbackRequest)
501-
query = query.order_by(DbCallbackRequest.priority_weight.desc()).limit(
502-
self.max_callbacks_per_loop
503-
)
504-
query = cast(
505-
"Select[tuple[DbCallbackRequest]]",
506-
with_row_locks(query, of=DbCallbackRequest, session=session, skip_locked=True),
500+
query: Select[tuple[DbCallbackRequest]] = with_row_locks(
501+
select(DbCallbackRequest)
502+
.order_by(DbCallbackRequest.priority_weight.desc())
503+
.limit(self.max_callbacks_per_loop),
504+
of=DbCallbackRequest,
505+
session=session,
506+
skip_locked=True,
507507
)
508-
callbacks = session.scalars(query)
508+
callbacks: Sequence[DbCallbackRequest] = [
509+
cb[0] if isinstance(cb, tuple) else cb for cb in session.scalars(query)
510+
]
509511
for callback in callbacks:
510512
req = callback.get_callback_request()
511513
if req.bundle_name not in bundle_names:

0 commit comments

Comments
 (0)