Skip to content

Commit 22e75b5

Browse files
Refactor devel-common/tests_common to use SQLA2 (#59849)
1 parent ab37151 commit 22e75b5

File tree

7 files changed

+129
-115
lines changed

7 files changed

+129
-115
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,10 @@ repos:
472472
^providers/edge3/.*\.py$|
473473
^providers/mysql/.*\.py$|
474474
^providers/openlineage/.*\.py$|
475-
^task_sdk.*\.py$
475+
^task_sdk.*\.py$|
476+
^devel-common/src/tests_common/pytest_plugin\.py$|
477+
^devel-common/src/tests_common/test_utils/.*\.py$
478+
476479
pass_filenames: true
477480
- id: update-supported-versions
478481
name: Updates supported versions in documentation

devel-common/src/tests_common/pytest_plugin.py

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1252,10 +1252,16 @@ def __call__(
12521252
self.bundle_name = bundle_name or "dag_maker"
12531253
self.bundle_version = bundle_version
12541254
if AIRFLOW_V_3_0_PLUS:
1255+
from sqlalchemy import func, select
1256+
12551257
from airflow.models.dagbundle import DagBundleModel
12561258

12571259
if (
1258-
self.session.query(DagBundleModel).filter(DagBundleModel.name == self.bundle_name).count()
1260+
self.session.scalar(
1261+
select(func.count())
1262+
.select_from(DagBundleModel)
1263+
.where(DagBundleModel.name == self.bundle_name)
1264+
)
12591265
== 0
12601266
):
12611267
self.session.add(DagBundleModel(name=self.bundle_name))
@@ -1285,39 +1291,25 @@ def cleanup(self):
12851291
self.session.rollback()
12861292

12871293
if AIRFLOW_V_3_0_PLUS:
1294+
from sqlalchemy import delete
1295+
12881296
from airflow.models.dag_version import DagVersion
12891297

1290-
self.session.query(DagRun).filter(DagRun.dag_id.in_(dag_ids)).delete(
1291-
synchronize_session=False,
1292-
)
1293-
self.session.query(TaskInstance).filter(TaskInstance.dag_id.in_(dag_ids)).delete(
1294-
synchronize_session=False,
1295-
)
1296-
self.session.query(DagVersion).filter(DagVersion.dag_id.in_(dag_ids)).delete(
1297-
synchronize_session=False
1298-
)
1298+
self.session.execute(delete(DagRun).where(DagRun.dag_id.in_(dag_ids)))
1299+
self.session.execute(delete(TaskInstance).where(TaskInstance.dag_id.in_(dag_ids)))
1300+
self.session.execute(delete(DagVersion).where(DagVersion.dag_id.in_(dag_ids)))
12991301
else:
1300-
self.session.query(SerializedDagModel).filter(
1301-
SerializedDagModel.dag_id.in_(dag_ids)
1302-
).delete(synchronize_session=False)
1303-
self.session.query(DagRun).filter(DagRun.dag_id.in_(dag_ids)).delete(
1304-
synchronize_session=False,
1305-
)
1306-
self.session.query(TaskInstance).filter(TaskInstance.dag_id.in_(dag_ids)).delete(
1307-
synchronize_session=False,
1302+
from sqlalchemy import delete
1303+
1304+
self.session.execute(
1305+
delete(SerializedDagModel).where(SerializedDagModel.dag_id.in_(dag_ids))
13081306
)
1309-
self.session.query(XCom).filter(XCom.dag_id.in_(dag_ids)).delete(
1310-
synchronize_session=False,
1311-
)
1312-
self.session.query(DagModel).filter(DagModel.dag_id.in_(dag_ids)).delete(
1313-
synchronize_session=False,
1314-
)
1315-
self.session.query(TaskMap).filter(TaskMap.dag_id.in_(dag_ids)).delete(
1316-
synchronize_session=False,
1317-
)
1318-
self.session.query(AssetEvent).filter(AssetEvent.source_dag_id.in_(dag_ids)).delete(
1319-
synchronize_session=False,
1320-
)
1307+
self.session.execute(delete(DagRun).where(DagRun.dag_id.in_(dag_ids)))
1308+
self.session.execute(delete(TaskInstance).where(TaskInstance.dag_id.in_(dag_ids)))
1309+
self.session.execute(delete(XCom).where(XCom.dag_id.in_(dag_ids)))
1310+
self.session.execute(delete(DagModel).where(DagModel.dag_id.in_(dag_ids)))
1311+
self.session.execute(delete(TaskMap).where(TaskMap.dag_id.in_(dag_ids)))
1312+
self.session.execute(delete(AssetEvent).where(AssetEvent.source_dag_id.in_(dag_ids)))
13211313
self.session.commit()
13221314
if self._own_session:
13231315
self.session.expunge_all()
@@ -1714,10 +1706,12 @@ def _create_log_template(filename_template, elasticsearch_id=""):
17141706
session.commit()
17151707

17161708
def _delete_log_template():
1709+
from sqlalchemy import delete
1710+
17171711
from airflow.models import DagRun, TaskInstance
17181712

1719-
session.query(TaskInstance).delete()
1720-
session.query(DagRun).delete()
1713+
session.execute(delete(TaskInstance))
1714+
session.execute(delete(DagRun))
17211715
session.delete(log_template)
17221716
session.commit()
17231717

@@ -2753,11 +2747,18 @@ def testing_dag_bundle():
27532747
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
27542748

27552749
if AIRFLOW_V_3_0_PLUS:
2750+
from sqlalchemy import func, select
2751+
27562752
from airflow.models.dagbundle import DagBundleModel
27572753
from airflow.utils.session import create_session
27582754

27592755
with create_session() as session:
2760-
if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0:
2756+
if (
2757+
session.scalar(
2758+
select(func.count()).select_from(DagBundleModel).where(DagBundleModel.name == "testing")
2759+
)
2760+
== 0
2761+
):
27612762
testing = DagBundleModel(name="testing")
27622763
session.add(testing)
27632764

@@ -2767,11 +2768,13 @@ def testing_team():
27672768
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
27682769

27692770
if AIRFLOW_V_3_0_PLUS:
2771+
from sqlalchemy import select
2772+
27702773
from airflow.models.team import Team
27712774
from airflow.utils.session import create_session
27722775

27732776
with create_session() as session:
2774-
team = session.query(Team).filter_by(name="testing").one_or_none()
2777+
team = session.scalar(select(Team).where(Team.name == "testing"))
27752778
if not team:
27762779
team = Team(name="testing")
27772780
session.add(team)

devel-common/src/tests_common/test_utils/api_client_helpers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,15 @@ def create_airflow_connection(connection_id: str, connection_conf: dict[str, Any
114114
print(f"Connection '{connection_id}' does not exist. A new one will be created")
115115
create_connection_request(connection_id=connection_id, connection=connection_conf)
116116
else:
117+
from sqlalchemy import delete
118+
117119
from airflow.models import Connection
118120
from airflow.settings import Session
119121

120122
if Session is None:
121123
raise RuntimeError("Session not configured. Call configure_orm() first.")
122124
session = Session()
123-
query = session.query(Connection).filter(Connection.conn_id == connection_id)
124-
query.delete()
125+
session.execute(delete(Connection).where(Connection.conn_id == connection_id))
125126
connection = Connection(conn_id=connection_id, **connection_conf)
126127
session.add(connection)
127128
session.commit()
@@ -135,12 +136,13 @@ def delete_airflow_connection(connection_id: str) -> None:
135136
if AIRFLOW_V_3_0_PLUS:
136137
delete_connection_request(connection_id=connection_id)
137138
else:
139+
from sqlalchemy import delete
140+
138141
from airflow.models import Connection
139142
from airflow.settings import Session
140143

141144
if Session is None:
142145
raise RuntimeError("Session not configured. Call configure_orm() first.")
143146
session = Session()
144-
query = session.query(Connection).filter(Connection.conn_id == connection_id)
145-
query.delete()
147+
session.execute(delete(Connection).where(Connection.conn_id == connection_id))
146148
session.commit()

devel-common/src/tests_common/test_utils/api_fastapi.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,24 +49,25 @@ def _masked_value_check(data, sensitive_fields):
4949

5050

5151
def _check_last_log(session, dag_id, event, logical_date, expected_extra=None, check_masked=False):
52-
logs = (
53-
session.query(
52+
from sqlalchemy import select
53+
54+
logs = session.execute(
55+
select(
5456
Log.dag_id,
5557
Log.task_id,
5658
Log.event,
5759
Log.logical_date,
5860
Log.owner,
5961
Log.extra,
6062
)
61-
.filter(
63+
.where(
6264
Log.dag_id == dag_id,
6365
Log.event == event,
6466
Log.logical_date == logical_date,
6567
)
6668
.order_by(Log.dttm.desc())
6769
.limit(1)
68-
.all()
69-
)
70+
).all()
7071
assert len(logs) == 1
7172
assert logs[0].extra
7273
if expected_extra:
@@ -77,11 +78,10 @@ def _check_last_log(session, dag_id, event, logical_date, expected_extra=None, c
7778

7879

7980
def _check_dag_run_note(session, dr_id, note_data):
80-
dr_note = (
81-
session.query(DagRunNote)
82-
.join(DagRun, DagRunNote.dag_run_id == DagRun.id)
83-
.filter(DagRun.run_id == dr_id)
84-
.one_or_none()
81+
from sqlalchemy import select
82+
83+
dr_note = session.scalar(
84+
select(DagRunNote).join(DagRun, DagRunNote.dag_run_id == DagRun.id).where(DagRun.run_id == dr_id)
8585
)
8686
if note_data is None:
8787
assert dr_note is None
@@ -91,7 +91,9 @@ def _check_dag_run_note(session, dr_id, note_data):
9191

9292

9393
def _check_task_instance_note(session, ti_id, note_data):
94-
ti_note = session.query(TaskInstanceNote).filter_by(ti_id=ti_id).one_or_none()
94+
from sqlalchemy import select
95+
96+
ti_note = session.scalar(select(TaskInstanceNote).where(TaskInstanceNote.ti_id == ti_id))
9597
if note_data is None:
9698
assert ti_note is None
9799
else:

0 commit comments

Comments
 (0)