Skip to content

Commit d7b440b

Browse files
authored
SQLA2: fix mypy issue with getting the dialect name (apache#56941)
* SQLA: add util to get the dialect name type-safely
* Replace unsafe dialect retrievals with `get_dialect_name`
1 parent 02fef9d commit d7b440b

File tree

9 files changed

+73
-29
lines changed

9 files changed

+73
-29
lines changed

airflow-core/src/airflow/api_fastapi/core_api/services/ui/calendar.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from airflow.timetables._cron import CronMixin
4040
from airflow.timetables.base import DataInterval, TimeRestriction
4141
from airflow.timetables.simple import ContinuousTimetable
42+
from airflow.utils.sqlalchemy import get_dialect_name
4243

4344
log = structlog.get_logger(logger_name=__name__)
4445

@@ -92,7 +93,7 @@ def _get_historical_dag_runs(
9293
granularity: Literal["hourly", "daily"],
9394
) -> tuple[list[CalendarTimeRangeResponse], Sequence[Row]]:
9495
"""Get historical DAG runs from the database."""
95-
dialect = session.bind.dialect.name
96+
dialect = get_dialect_name(session)
9697

9798
time_expression = self._get_time_truncation_expression(DagRun.logical_date, granularity, dialect)
9899

airflow-core/src/airflow/assets/manager.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
)
3939
from airflow.stats import Stats
4040
from airflow.utils.log.logging_mixin import LoggingMixin
41+
from airflow.utils.sqlalchemy import get_dialect_name
4142

4243
if TYPE_CHECKING:
4344
from sqlalchemy.orm.session import Session
@@ -245,7 +246,7 @@ def _queue_dagruns(cls, asset_id: int, dags_to_queue: set[DagModel], session: Se
245246
if not dags_to_queue:
246247
return
247248

248-
if session.bind.dialect.name == "postgresql":
249+
if get_dialect_name(session) == "postgresql":
249250
return cls._postgres_queue_dagruns(asset_id, dags_to_queue, session)
250251
return cls._slow_path_queue_dagruns(asset_id, dags_to_queue, session)
251252

airflow-core/src/airflow/dag_processing/collection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@
6060
from airflow.serialization.serialized_objects import BaseSerialization, LazyDeserializedDAG, SerializedDAG
6161
from airflow.triggers.base import BaseEventTrigger
6262
from airflow.utils.retries import MAX_DB_RETRIES, run_with_db_retries
63-
from airflow.utils.sqlalchemy import with_row_locks
63+
from airflow.utils.sqlalchemy import get_dialect_name, with_row_locks
6464
from airflow.utils.types import DagRunType
6565

6666
if TYPE_CHECKING:
@@ -756,7 +756,7 @@ def activate_assets_if_possible(self, models: Iterable[AssetModel], *, session:
756756
there's a conflict. The scheduler makes a more comprehensive pass
757757
through all assets in ``_update_asset_orphanage``.
758758
"""
759-
if session.bind is not None and (dialect_name := session.bind.dialect.name) == "postgresql":
759+
if (dialect_name := get_dialect_name(session)) == "postgresql":
760760
from sqlalchemy.dialects.postgresql import insert as postgresql_insert
761761

762762
stmt: Any = postgresql_insert(AssetActive).on_conflict_do_nothing()

airflow-core/src/airflow/models/dagrun.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,14 @@
7979
from airflow.utils.retries import retry_db_transaction
8080
from airflow.utils.session import NEW_SESSION, provide_session
8181
from airflow.utils.span_status import SpanStatus
82-
from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, mapped_column, nulls_first, with_row_locks
82+
from airflow.utils.sqlalchemy import (
83+
ExtendedJSON,
84+
UtcDateTime,
85+
get_dialect_name,
86+
mapped_column,
87+
nulls_first,
88+
with_row_locks,
89+
)
8390
from airflow.utils.state import DagRunState, State, TaskInstanceState
8491
from airflow.utils.strings import get_random_string
8592
from airflow.utils.thread_safe_dict import ThreadSafeDict
@@ -399,7 +406,7 @@ def duration(self) -> float | None:
399406
@duration.expression # type: ignore[no-redef]
400407
@provide_session
401408
def duration(cls, session: Session = NEW_SESSION) -> Case:
402-
dialect_name = session.bind.dialect.name
409+
dialect_name = get_dialect_name(session)
403410
if dialect_name == "mysql":
404411
return func.timestampdiff(text("SECOND"), cls.start_date, cls.end_date)
405412

airflow-core/src/airflow/models/deadline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
from airflow.triggers.deadline import PAYLOAD_BODY_KEY, PAYLOAD_STATUS_KEY, DeadlineCallbackTrigger
4141
from airflow.utils.log.logging_mixin import LoggingMixin
4242
from airflow.utils.session import provide_session
43-
from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
43+
from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name, mapped_column
4444

4545
if TYPE_CHECKING:
4646
from sqlalchemy.orm import Session
@@ -411,7 +411,7 @@ def _evaluate_with(self, *, session: Session, **kwargs: Any) -> datetime | None:
411411
dag_id = kwargs["dag_id"]
412412

413413
# Get database dialect to use appropriate time difference calculation
414-
dialect = getattr(session.bind.dialect, "name", None)
414+
dialect = get_dialect_name(session)
415415

416416
# Create database-specific expression for calculating duration in seconds
417417
if dialect == "postgresql":

airflow-core/src/airflow/models/serialized_dag.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from airflow.settings import COMPRESS_SERIALIZED_DAGS, json
5050
from airflow.utils.hashlib_wrapper import md5
5151
from airflow.utils.session import NEW_SESSION, provide_session
52-
from airflow.utils.sqlalchemy import UtcDateTime, mapped_column
52+
from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name, mapped_column
5353

5454
if TYPE_CHECKING:
5555
from sqlalchemy.orm import Session
@@ -591,12 +591,13 @@ def get_dag_dependencies(cls, session: Session = NEW_SESSION) -> dict[str, list[
591591
"""
592592
load_json: Callable | None
593593
if COMPRESS_SERIALIZED_DAGS is False:
594-
if session.bind.dialect.name in ["sqlite", "mysql"]:
594+
dialect = get_dialect_name(session)
595+
if dialect in ["sqlite", "mysql"]:
595596
data_col_to_select = func.json_extract(cls._data, "$.dag.dag_dependencies")
596597

597598
def load_json(deps_data):
598599
return json.loads(deps_data) if deps_data else []
599-
elif session.bind.dialect.name == "postgresql":
600+
elif dialect == "postgresql":
600601
# Use #> operator which works for both JSON and JSONB types
601602
# Returns the JSON sub-object at the specified path
602603
data_col_to_select = cls._data.op("#>")(literal('{"dag","dag_dependencies"}'))

airflow-core/src/airflow/models/trigger.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from airflow.triggers.base import BaseTaskEndEvent
3838
from airflow.utils.retries import run_with_db_retries
3939
from airflow.utils.session import NEW_SESSION, provide_session
40-
from airflow.utils.sqlalchemy import UtcDateTime, mapped_column, with_row_locks
40+
from airflow.utils.sqlalchemy import UtcDateTime, get_dialect_name, mapped_column, with_row_locks
4141
from airflow.utils.state import TaskInstanceState
4242

4343
if TYPE_CHECKING:
@@ -231,7 +231,7 @@ def clean_unused(cls, session: Session = NEW_SESSION) -> None:
231231
.group_by(cls.id)
232232
.having(func.count(TaskInstance.trigger_id) == 0)
233233
)
234-
if session.bind.dialect.name == "mysql":
234+
if get_dialect_name(session) == "mysql":
235235
# MySQL doesn't support DELETE with JOIN, so we need to do it in two steps
236236
ids = session.scalars(ids).all()
237237
session.execute(

airflow-core/src/airflow/utils/sqlalchemy.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,13 @@ def mapped_column(*args, **kwargs):
5858
return Column(*args, **kwargs)
5959

6060

61+
def get_dialect_name(session: Session) -> str | None:
    """Return the name of the SQL dialect backing the given session.

    :param session: the SQLAlchemy session whose bind/engine is inspected.
    :raises ValueError: if no bind/engine is associated with the session.
    :return: the dialect name (e.g. ``"postgresql"``), or ``None`` if the
        dialect object does not expose a ``name`` attribute.
    """
    bind = session.get_bind()
    if bind is None:
        raise ValueError("No bind/engine is associated with the provided Session")
    # Dialects normally define ``name``; fall back to None defensively.
    return getattr(bind.dialect, "name", None)
66+
67+
6168
class UtcDateTime(TypeDecorator):
6269
"""
6370
Similar to :class:`~sqlalchemy.types.TIMESTAMP` with ``timezone=True`` option, with some differences.
@@ -312,7 +319,7 @@ def nulls_first(col, session: Session) -> dict[str, Any]:
312319
Other databases do not need it since NULL values are considered lower than
313320
any other values, and appear first when the order is ASC (ascending).
314321
"""
315-
if session.bind.dialect.name == "postgresql":
322+
if get_dialect_name(session) == "postgresql":
316323
return nullsfirst(col)
317324
return col
318325

airflow-core/tests/unit/utils/test_sqlalchemy.py

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import pickle
2222
from copy import deepcopy
2323
from unittest import mock
24-
from unittest.mock import MagicMock
2524

2625
import pytest
2726
from kubernetes.client import models as k8s
@@ -37,6 +36,7 @@
3736
from airflow.utils.sqlalchemy import (
3837
ExecutorConfigType,
3938
ensure_pod_is_valid_after_unpickling,
39+
get_dialect_name,
4040
is_sqlalchemy_v1,
4141
prohibit_commit,
4242
with_row_locks,
@@ -52,13 +52,40 @@
5252
TEST_POD = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")]))
5353

5454

55+
class TestGetDialectName:
    """Tests for ``airflow.utils.sqlalchemy.get_dialect_name``."""

    def test_returns_dialect_name_when_present(self, mocker):
        # A bound session whose dialect exposes a name should yield that name.
        mock_session = mocker.Mock()
        mock_bind = mocker.Mock()
        mock_bind.dialect.name = "postgresql"
        mock_session.get_bind.return_value = mock_bind

        assert get_dialect_name(mock_session) == "postgresql"

    def test_raises_when_no_bind(self, mocker):
        # A session without any bind/engine must raise a descriptive error.
        mock_session = mocker.Mock()
        mock_session.get_bind.return_value = None

        with pytest.raises(ValueError, match="No bind/engine is associated"):
            get_dialect_name(mock_session)

    def test_returns_none_when_dialect_has_no_name(self, mocker):
        mock_session = mocker.Mock()
        mock_bind = mocker.Mock()
        # simulate a dialect object without a `name` attribute:
        # spec=[] forbids attribute auto-creation, so getattr(..., "name", None)
        # inside get_dialect_name falls back to None.
        mock_bind.dialect = mocker.Mock(spec=[])
        mock_session.get_bind.return_value = mock_bind

        assert get_dialect_name(mock_session) is None
80+
81+
5582
class TestSqlAlchemyUtils:
5683
def setup_method(self):
5784
session = Session()
5885

5986
# make sure NOT to run in UTC. Only postgres supports storing
6087
# timezone information in the datetime field
61-
if session.bind.dialect.name == "postgresql":
88+
if get_dialect_name(session) == "postgresql":
6289
session.execute(text("SET timezone='Europe/Amsterdam'"))
6390

6491
self.session = session
@@ -124,7 +151,7 @@ def test_process_bind_param_naive(self):
124151
dag.clear()
125152

126153
@pytest.mark.parametrize(
127-
"dialect, supports_for_update_of, use_row_level_lock_conf, expected_use_row_level_lock",
154+
("dialect", "supports_for_update_of", "use_row_level_lock_conf", "expected_use_row_level_lock"),
128155
[
129156
("postgresql", True, True, True),
130157
("postgresql", True, False, False),
@@ -192,7 +219,7 @@ def teardown_method(self):
192219

193220
class TestExecutorConfigType:
194221
@pytest.mark.parametrize(
195-
"input, expected",
222+
("input", "expected"),
196223
[
197224
("anything", "anything"),
198225
(
@@ -206,13 +233,13 @@ class TestExecutorConfigType:
206233
),
207234
],
208235
)
209-
def test_bind_processor(self, input, expected):
236+
def test_bind_processor(self, input, expected, mocker):
210237
"""
211238
The returned bind processor should pickle the object as is, unless it is a dictionary with
212239
a pod_override node, in which case it should run it through BaseSerialization.
213240
"""
214241
config_type = ExecutorConfigType()
215-
mock_dialect = MagicMock()
242+
mock_dialect = mocker.MagicMock()
216243
mock_dialect.dbapi = None
217244
process = config_type.bind_processor(mock_dialect)
218245
assert pickle.loads(process(input)) == expected
@@ -239,13 +266,13 @@ def test_bind_processor(self, input, expected):
239266
),
240267
],
241268
)
242-
def test_result_processor(self, input):
269+
def test_result_processor(self, input, mocker):
243270
"""
244271
The returned bind processor should pickle the object as is, unless it is a dictionary with
245272
a pod_override node whose value was serialized with BaseSerialization.
246273
"""
247274
config_type = ExecutorConfigType()
248-
mock_dialect = MagicMock()
275+
mock_dialect = mocker.MagicMock()
249276
mock_dialect.dbapi = None
250277
process = config_type.result_processor(mock_dialect, None)
251278
result = process(input)
@@ -277,7 +304,7 @@ def __eq__(self, other):
277304
assert instance.compare_values(a, a) is False
278305
assert instance.compare_values("a", "a") is True
279306

280-
def test_result_processor_bad_pickled_obj(self):
307+
def test_result_processor_bad_pickled_obj(self, mocker):
281308
"""
282309
If unpickled obj is missing attrs that curr lib expects
283310
"""
@@ -309,7 +336,7 @@ def test_result_processor_bad_pickled_obj(self):
309336

310337
# get the result processor method
311338
config_type = ExecutorConfigType()
312-
mock_dialect = MagicMock()
339+
mock_dialect = mocker.MagicMock()
313340
mock_dialect.dbapi = None
314341
process = config_type.result_processor(mock_dialect, None)
315342

@@ -322,13 +349,13 @@ def test_result_processor_bad_pickled_obj(self):
322349

323350

324351
@pytest.mark.parametrize(
325-
"mock_version, expected_result",
352+
("mock_version", "expected_result"),
326353
[
327354
("1.0.0", True), # Test 1: v1 identified as v1
328355
("2.3.4", False), # Test 2: v2 not identified as v1
329356
],
330357
)
331-
def test_is_sqlalchemy_v1(mock_version, expected_result):
332-
with mock.patch("airflow.utils.sqlalchemy.metadata") as mock_metadata:
333-
mock_metadata.version.return_value = mock_version
334-
assert is_sqlalchemy_v1() == expected_result
358+
def test_is_sqlalchemy_v1(mock_version, expected_result, mocker):
359+
mock_metadata = mocker.patch("airflow.utils.sqlalchemy.metadata")
360+
mock_metadata.version.return_value = mock_version
361+
assert is_sqlalchemy_v1() == expected_result

0 commit comments

Comments
 (0)