Skip to content

Commit 27af41f

Browse files
Refactor models tests to use SQLA2 (#59837)
Migrate deprecated SQLAlchemy `Query` object usage to SQLAlchemy 2.0 style (`select()` / `session.execute()` / `session.scalars()`) in the following test files:
- test_dag.py
- test_dagcode.py
- test_mappedoperator.py
- test_taskinstance.py
- test_variable.py

Closes #59402
1 parent 2a9b4cf commit 27af41f

File tree

6 files changed

+181
-152
lines changed

6 files changed

+181
-152
lines changed

.pre-commit-config.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,11 @@ repos:
433433
^airflow-core/tests/unit/models/test_cleartasks.py$|
434434
^airflow-core/tests/unit/models/test_xcom.py$|
435435
^airflow-core/tests/unit/models/test_dagrun.py$|
436+
^airflow-core/tests/unit/models/test_dag\.py$|
437+
^airflow-core/tests/unit/models/test_dagcode\.py$|
438+
^airflow-core/tests/unit/models/test_mappedoperator\.py$|
439+
^airflow-core/tests/unit/models/test_taskinstance\.py$|
440+
^airflow-core/tests/unit/models/test_variable\.py$|
436441
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_sources.py$|
437442
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py$|
438443
^airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py$|

airflow-core/tests/unit/models/test_dag.py

Lines changed: 60 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import pendulum
3434
import pytest
3535
import time_machine
36-
from sqlalchemy import inspect, select
36+
from sqlalchemy import delete, inspect, select, update
3737

3838
from airflow import settings
3939
from airflow._shared.module_loading import qualname
@@ -386,7 +386,8 @@ def test_dagtag_repr(self, testing_dag_bundle):
386386
sync_dag_to_db(dag)
387387
with create_session() as session:
388388
assert {"tag-1", "tag-2"} == {
389-
repr(t) for t in session.query(DagTag).filter(DagTag.dag_id == "dag-test-dagtag").all()
389+
repr(t)
390+
for t in session.scalars(select(DagTag).where(DagTag.dag_id == "dag-test-dagtag")).all()
390391
}
391392

392393
def test_bulk_write_to_db(self, testing_dag_bundle):
@@ -402,16 +403,16 @@ def test_bulk_write_to_db(self, testing_dag_bundle):
402403
SerializedDAG.bulk_write_to_db("testing", None, dags)
403404
with create_session() as session:
404405
assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
405-
row[0] for row in session.query(DagModel.dag_id).all()
406+
row[0] for row in session.execute(select(DagModel.dag_id)).all()
406407
}
407408
assert {
408409
("dag-bulk-sync-0", "test-dag"),
409410
("dag-bulk-sync-1", "test-dag"),
410411
("dag-bulk-sync-2", "test-dag"),
411412
("dag-bulk-sync-3", "test-dag"),
412-
} == set(session.query(DagTag.dag_id, DagTag.name).all())
413+
} == set(session.execute(select(DagTag.dag_id, DagTag.name)).all())
413414

414-
for row in session.query(DagModel.last_parsed_time).all():
415+
for row in session.execute(select(DagModel.last_parsed_time)).all():
415416
assert row[0] is not None
416417

417418
# Re-sync should do fewer queries
@@ -426,7 +427,7 @@ def test_bulk_write_to_db(self, testing_dag_bundle):
426427
SerializedDAG.bulk_write_to_db("testing", None, dags)
427428
with create_session() as session:
428429
assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
429-
row[0] for row in session.query(DagModel.dag_id).all()
430+
row[0] for row in session.execute(select(DagModel.dag_id)).all()
430431
}
431432
assert {
432433
("dag-bulk-sync-0", "test-dag"),
@@ -437,24 +438,24 @@ def test_bulk_write_to_db(self, testing_dag_bundle):
437438
("dag-bulk-sync-2", "test-dag2"),
438439
("dag-bulk-sync-3", "test-dag"),
439440
("dag-bulk-sync-3", "test-dag2"),
440-
} == set(session.query(DagTag.dag_id, DagTag.name).all())
441+
} == set(session.execute(select(DagTag.dag_id, DagTag.name)).all())
441442
# Removing tags
442443
for dag in dags:
443444
dag.tags.remove("test-dag")
444445
with assert_queries_count(10):
445446
SerializedDAG.bulk_write_to_db("testing", None, dags)
446447
with create_session() as session:
447448
assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
448-
row[0] for row in session.query(DagModel.dag_id).all()
449+
row[0] for row in session.execute(select(DagModel.dag_id)).all()
449450
}
450451
assert {
451452
("dag-bulk-sync-0", "test-dag2"),
452453
("dag-bulk-sync-1", "test-dag2"),
453454
("dag-bulk-sync-2", "test-dag2"),
454455
("dag-bulk-sync-3", "test-dag2"),
455-
} == set(session.query(DagTag.dag_id, DagTag.name).all())
456+
} == set(session.execute(select(DagTag.dag_id, DagTag.name)).all())
456457

457-
for row in session.query(DagModel.last_parsed_time).all():
458+
for row in session.execute(select(DagModel.last_parsed_time)).all():
458459
assert row[0] is not None
459460

460461
# Removing all tags
@@ -464,11 +465,11 @@ def test_bulk_write_to_db(self, testing_dag_bundle):
464465
SerializedDAG.bulk_write_to_db("testing", None, dags)
465466
with create_session() as session:
466467
assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
467-
row[0] for row in session.query(DagModel.dag_id).all()
468+
row[0] for row in session.execute(select(DagModel.dag_id)).all()
468469
}
469-
assert not set(session.query(DagTag.dag_id, DagTag.name).all())
470+
assert not set(session.execute(select(DagTag.dag_id, DagTag.name)).all())
470471

471-
for row in session.query(DagModel.last_parsed_time).all():
472+
for row in session.execute(select(DagModel.last_parsed_time)).all():
472473
assert row[0] is not None
473474

474475
def test_bulk_write_to_db_single_dag(self, testing_dag_bundle):
@@ -486,12 +487,12 @@ def test_bulk_write_to_db_single_dag(self, testing_dag_bundle):
486487
with assert_queries_count(6):
487488
SerializedDAG.bulk_write_to_db("testing", None, dags)
488489
with create_session() as session:
489-
assert {"dag-bulk-sync-0"} == {row[0] for row in session.query(DagModel.dag_id).all()}
490+
assert {"dag-bulk-sync-0"} == {row[0] for row in session.execute(select(DagModel.dag_id)).all()}
490491
assert {
491492
("dag-bulk-sync-0", "test-dag"),
492-
} == set(session.query(DagTag.dag_id, DagTag.name).all())
493+
} == set(session.execute(select(DagTag.dag_id, DagTag.name)).all())
493494

494-
for row in session.query(DagModel.last_parsed_time).all():
495+
for row in session.execute(select(DagModel.last_parsed_time)).all():
495496
assert row[0] is not None
496497

497498
# Re-sync should do fewer queries
@@ -516,16 +517,16 @@ def test_bulk_write_to_db_multiple_dags(self, testing_dag_bundle):
516517
SerializedDAG.bulk_write_to_db("testing", None, dags)
517518
with create_session() as session:
518519
assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
519-
row[0] for row in session.query(DagModel.dag_id).all()
520+
row[0] for row in session.execute(select(DagModel.dag_id)).all()
520521
}
521522
assert {
522523
("dag-bulk-sync-0", "test-dag"),
523524
("dag-bulk-sync-1", "test-dag"),
524525
("dag-bulk-sync-2", "test-dag"),
525526
("dag-bulk-sync-3", "test-dag"),
526-
} == set(session.query(DagTag.dag_id, DagTag.name).all())
527+
} == set(session.execute(select(DagTag.dag_id, DagTag.name)).all())
527528

528-
for row in session.query(DagModel.last_parsed_time).all():
529+
for row in session.execute(select(DagModel.last_parsed_time)).all():
529530
assert row[0] is not None
530531

531532
# Re-sync should do fewer queries
@@ -682,21 +683,21 @@ def test_bulk_write_to_db_assets(self, testing_dag_bundle):
682683
create_scheduler_dag(dag1).clear()
683684
SerializedDAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session)
684685
session.commit()
685-
stored_assets = {x.uri: x for x in session.query(AssetModel).all()}
686+
stored_assets = {x.uri: x for x in session.scalars(select(AssetModel)).all()}
686687
asset1_orm = stored_assets[a1.uri]
687688
asset2_orm = stored_assets[a2.uri]
688689
asset3_orm = stored_assets[a3.uri]
689690
assert stored_assets[uri1].extra == {"should": "be used"}
690691
assert [x.dag_id for x in asset1_orm.scheduled_dags] == [dag_id1]
691692
assert [(x.task_id, x.dag_id) for x in asset1_orm.producing_tasks] == [(task_id, dag_id2)]
692693
assert set(
693-
session.query(
694-
TaskOutletAssetReference.task_id,
695-
TaskOutletAssetReference.dag_id,
696-
TaskOutletAssetReference.asset_id,
697-
)
698-
.filter(TaskOutletAssetReference.dag_id.in_((dag_id1, dag_id2)))
699-
.all()
694+
session.execute(
695+
select(
696+
TaskOutletAssetReference.task_id,
697+
TaskOutletAssetReference.dag_id,
698+
TaskOutletAssetReference.asset_id,
699+
).where(TaskOutletAssetReference.dag_id.in_((dag_id1, dag_id2)))
700+
).all()
700701
) == {
701702
(task_id, dag_id1, asset2_orm.id),
702703
(task_id, dag_id1, asset3_orm.id),
@@ -715,18 +716,18 @@ def test_bulk_write_to_db_assets(self, testing_dag_bundle):
715716
SerializedDAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session)
716717
session.commit()
717718
session.expunge_all()
718-
stored_assets = {x.uri: x for x in session.query(AssetModel).all()}
719+
stored_assets = {x.uri: x for x in session.scalars(select(AssetModel)).all()}
719720
asset1_orm = stored_assets[a1.uri]
720721
asset2_orm = stored_assets[a2.uri]
721722
assert [x.dag_id for x in asset1_orm.scheduled_dags] == []
722723
assert set(
723-
session.query(
724-
TaskOutletAssetReference.task_id,
725-
TaskOutletAssetReference.dag_id,
726-
TaskOutletAssetReference.asset_id,
727-
)
728-
.filter(TaskOutletAssetReference.dag_id.in_((dag_id1, dag_id2)))
729-
.all()
724+
session.execute(
725+
select(
726+
TaskOutletAssetReference.task_id,
727+
TaskOutletAssetReference.dag_id,
728+
TaskOutletAssetReference.asset_id,
729+
).where(TaskOutletAssetReference.dag_id.in_((dag_id1, dag_id2)))
730+
).all()
730731
) == {(task_id, dag_id1, asset2_orm.id)}
731732

732733
def test_bulk_write_to_db_asset_aliases(self, testing_dag_bundle):
@@ -749,7 +750,7 @@ def test_bulk_write_to_db_asset_aliases(self, testing_dag_bundle):
749750
SerializedDAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session)
750751
session.commit()
751752

752-
stored_asset_alias_models = {x.name: x for x in session.query(AssetAliasModel).all()}
753+
stored_asset_alias_models = {x.name: x for x in session.scalars(select(AssetAliasModel)).all()}
753754
asset_alias_1_orm = stored_asset_alias_models[asset_alias_1.name]
754755
asset_alias_2_orm = stored_asset_alias_models[asset_alias_2.name]
755756
asset_alias_3_orm = stored_asset_alias_models[asset_alias_3.name]
@@ -818,7 +819,7 @@ def test_dag_is_deactivated_upon_dagfile_deletion(self, dag_maker):
818819
session = settings.Session()
819820
sync_dag_to_db(dag, session=session)
820821

821-
orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one()
822+
orm_dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id))
822823

823824
assert not orm_dag.is_stale
824825

@@ -827,10 +828,10 @@ def test_dag_is_deactivated_upon_dagfile_deletion(self, dag_maker):
827828
rel_filelocs=list_py_file_paths(settings.DAGS_FOLDER),
828829
)
829830

830-
orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one()
831+
orm_dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id))
831832
assert orm_dag.is_stale
832833

833-
session.execute(DagModel.__table__.delete().where(DagModel.dag_id == dag_id))
834+
session.execute(delete(DagModel).where(DagModel.dag_id == dag_id))
834835
session.close()
835836

836837
def test_dag_naive_default_args_start_date_with_timezone(self):
@@ -1107,7 +1108,7 @@ def test_get_paused_dag_ids(self, testing_dag_bundle):
11071108
assert paused_dag_ids == {dag_id}
11081109

11091110
with create_session() as session:
1110-
session.query(DagModel).filter(DagModel.dag_id == dag_id).delete(synchronize_session=False)
1111+
session.execute(delete(DagModel).where(DagModel.dag_id == dag_id))
11111112

11121113
@pytest.mark.parametrize(
11131114
("schedule_arg", "expected_timetable", "interval_description"),
@@ -1293,7 +1294,7 @@ def consumer(value):
12931294
assert upstream_ti.state is None # cleared
12941295
assert ti.state is None # cleared
12951296
assert ti2.state == State.SUCCESS # not cleared
1296-
dagruns = session.query(DagRun).filter(DagRun.dag_id == dag_id).all()
1297+
dagruns = session.scalars(select(DagRun).where(DagRun.dag_id == dag_id)).all()
12971298

12981299
assert len(dagruns) == 1
12991300
dagrun: DagRun = dagruns[0]
@@ -1465,7 +1466,7 @@ def test_clear_dag(
14651466
session=session,
14661467
)
14671468

1468-
task_instances = session.query(TI).filter(TI.dag_id == dag_id).all()
1469+
task_instances = session.scalars(select(TI).where(TI.dag_id == dag_id)).all()
14691470

14701471
assert len(task_instances) == 1
14711472
task_instance: TI = task_instances[0]
@@ -1813,7 +1814,7 @@ def test_dag_owner_links(self, testing_dag_bundle):
18131814
dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE)
18141815
sync_dag_to_db(dag, session=session)
18151816

1816-
orm_dag_owners = session.query(DagOwnerAttributes).all()
1817+
orm_dag_owners = session.scalars(select(DagOwnerAttributes)).all()
18171818
assert not orm_dag_owners
18181819

18191820
@pytest.mark.need_serialized_dag
@@ -1963,7 +1964,7 @@ def test_dags_needing_dagruns_assets(self, dag_maker, session):
19631964
assert dag_models == []
19641965

19651966
# add queue records so we'll need a run
1966-
dag_model = session.query(DagModel).filter(DagModel.dag_id == dag.dag_id).one()
1967+
dag_model = session.scalar(select(DagModel).where(DagModel.dag_id == dag.dag_id))
19671968
asset_model: AssetModel = dag_model.schedule_assets[0]
19681969
session.add(AssetDagRunQueue(asset_id=asset_model.id, target_dag_id=dag_model.dag_id))
19691970
session.flush()
@@ -2250,7 +2251,7 @@ def test_dags_needing_dagruns_triggered_date_by_dag_queued_times(self, session,
22502251
EmptyOperator(task_id="task", outlets=[asset])
22512252
dr = dag_maker.create_dagrun()
22522253

2253-
asset_id = session.query(AssetModel.id).filter_by(uri=asset.uri).scalar()
2254+
asset_id = session.scalar(select(AssetModel.id).where(AssetModel.uri == asset.uri))
22542255

22552256
session.add(
22562257
AssetEvent(
@@ -2262,8 +2263,8 @@ def test_dags_needing_dagruns_triggered_date_by_dag_queued_times(self, session,
22622263
)
22632264
)
22642265

2265-
asset1_id = session.query(AssetModel.id).filter_by(uri=asset1.uri).scalar()
2266-
asset2_id = session.query(AssetModel.id).filter_by(uri=asset2.uri).scalar()
2266+
asset1_id = session.scalar(select(AssetModel.id).where(AssetModel.uri == asset1.uri))
2267+
asset2_id = session.scalar(select(AssetModel.id).where(AssetModel.uri == asset2.uri))
22672268

22682269
with dag_maker(dag_id="assets-consumer-multiple", schedule=[asset1, asset2]) as dag:
22692270
pass
@@ -2314,7 +2315,9 @@ def test_asset_expression(self, session: Session, testing_dag_bundle) -> None:
23142315
)
23152316
SerializedDAG.bulk_write_to_db("testing", None, [dag], session=session)
23162317

2317-
expression = session.scalars(select(DagModel.asset_expression).filter_by(dag_id=dag.dag_id)).one()
2318+
expression = session.scalars(
2319+
select(DagModel.asset_expression).where(DagModel.dag_id == dag.dag_id)
2320+
).one()
23182321
assert expression == {
23192322
"any": [
23202323
{
@@ -2506,14 +2509,12 @@ def test_set_task_instance_state(run_id, session, dag_maker):
25062509
)
25072510

25082511
def get_ti_from_db(task):
2509-
return (
2510-
session.query(TI)
2511-
.filter(
2512+
return session.scalar(
2513+
select(TI).where(
25122514
TI.dag_id == dag.dag_id,
25132515
TI.task_id == task.task_id,
25142516
TI.run_id == dagrun.run_id,
25152517
)
2516-
.one()
25172518
)
25182519

25192520
get_ti_from_db(task_1).state = State.FAILED
@@ -2588,16 +2589,16 @@ def consumer(value):
25882589
)
25892590
expand_mapped_task(mapped, dr2.run_id, "make_arg_lists", length=2, session=session)
25902591

2591-
session.query(TI).filter_by(dag_id=dag.dag_id).update({"state": TaskInstanceState.FAILED})
2592+
session.execute(update(TI).where(TI.dag_id == dag.dag_id).values(state=TaskInstanceState.FAILED))
25922593

25932594
ti_query = (
2594-
session.query(TI.task_id, TI.map_index, TI.run_id, TI.state)
2595-
.filter(TI.dag_id == dag.dag_id, TI.task_id.in_([task_id, "downstream"]))
2595+
select(TI.task_id, TI.map_index, TI.run_id, TI.state)
2596+
.where(TI.dag_id == dag.dag_id, TI.task_id.in_([task_id, "downstream"]))
25962597
.order_by(TI.run_id, TI.task_id, TI.map_index)
25972598
)
25982599

25992600
# Check pre-conditions
2600-
assert ti_query.all() == [
2601+
assert session.execute(ti_query).all() == [
26012602
("downstream", -1, dr1.run_id, TaskInstanceState.FAILED),
26022603
(task_id, 0, dr1.run_id, TaskInstanceState.FAILED),
26032604
(task_id, 1, dr1.run_id, TaskInstanceState.FAILED),
@@ -2616,7 +2617,7 @@ def consumer(value):
26162617
)
26172618
assert dr1 in session, "Check session is passed down all the way"
26182619

2619-
assert ti_query.all() == [
2620+
assert session.execute(ti_query).all() == [
26202621
("downstream", -1, dr1.run_id, None),
26212622
(task_id, 0, dr1.run_id, TaskInstanceState.FAILED),
26222623
(task_id, 1, dr1.run_id, TaskInstanceState.SUCCESS),
@@ -2867,7 +2868,7 @@ def test_get_asset_triggered_next_run_info(dag_maker, clear_assets):
28672868
dag3 = dag_maker.dag
28682869

28692870
session = dag_maker.session
2870-
asset1_id = session.query(AssetModel.id).filter_by(uri=asset1.uri).scalar()
2871+
asset1_id = session.scalar(select(AssetModel.id).where(AssetModel.uri == asset1.uri))
28712872
session.bulk_save_objects(
28722873
[
28732874
AssetDagRunQueue(asset_id=asset1_id, target_dag_id=dag2.dag_id),
@@ -2876,7 +2877,7 @@ def test_get_asset_triggered_next_run_info(dag_maker, clear_assets):
28762877
)
28772878
session.flush()
28782879

2879-
assets = session.query(AssetModel.uri).order_by(AssetModel.id).all()
2880+
assets = session.execute(select(AssetModel.uri).order_by(AssetModel.id)).all()
28802881

28812882
info = get_asset_triggered_next_run_info([dag1.dag_id], session=session)
28822883
assert info[dag1.dag_id] == {

0 commit comments

Comments (0)