@@ -33,7 +33,7 @@
 import pendulum
 import pytest
 import time_machine
-from sqlalchemy import inspect, select
+from sqlalchemy import delete, inspect, select, update
 
 from airflow import settings
 from airflow._shared.module_loading import qualname
@@ -386,7 +386,8 @@ def test_dagtag_repr(self, testing_dag_bundle):
         sync_dag_to_db(dag)
         with create_session() as session:
             assert {"tag-1", "tag-2"} == {
-                repr(t) for t in session.query(DagTag).filter(DagTag.dag_id == "dag-test-dagtag").all()
+                repr(t)
+                for t in session.scalars(select(DagTag).where(DagTag.dag_id == "dag-test-dagtag")).all()
             }
 
     def test_bulk_write_to_db(self, testing_dag_bundle):
@@ -402,16 +403,16 @@ def test_bulk_write_to_db(self, testing_dag_bundle):
             SerializedDAG.bulk_write_to_db("testing", None, dags)
         with create_session() as session:
             assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
-                row[0] for row in session.query(DagModel.dag_id).all()
+                row[0] for row in session.execute(select(DagModel.dag_id)).all()
             }
             assert {
                 ("dag-bulk-sync-0", "test-dag"),
                 ("dag-bulk-sync-1", "test-dag"),
                 ("dag-bulk-sync-2", "test-dag"),
                 ("dag-bulk-sync-3", "test-dag"),
-            } == set(session.query(DagTag.dag_id, DagTag.name).all())
+            } == set(session.execute(select(DagTag.dag_id, DagTag.name)).all())
 
-            for row in session.query(DagModel.last_parsed_time).all():
+            for row in session.execute(select(DagModel.last_parsed_time)).all():
                 assert row[0] is not None
 
         # Re-sync should do fewer queries
@@ -426,7 +427,7 @@ def test_bulk_write_to_db(self, testing_dag_bundle):
             SerializedDAG.bulk_write_to_db("testing", None, dags)
         with create_session() as session:
             assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
-                row[0] for row in session.query(DagModel.dag_id).all()
+                row[0] for row in session.execute(select(DagModel.dag_id)).all()
             }
             assert {
                 ("dag-bulk-sync-0", "test-dag"),
@@ -437,24 +438,24 @@ def test_bulk_write_to_db(self, testing_dag_bundle):
                 ("dag-bulk-sync-2", "test-dag2"),
                 ("dag-bulk-sync-3", "test-dag"),
                 ("dag-bulk-sync-3", "test-dag2"),
-            } == set(session.query(DagTag.dag_id, DagTag.name).all())
+            } == set(session.execute(select(DagTag.dag_id, DagTag.name)).all())
         # Removing tags
         for dag in dags:
             dag.tags.remove("test-dag")
         with assert_queries_count(10):
             SerializedDAG.bulk_write_to_db("testing", None, dags)
         with create_session() as session:
             assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
-                row[0] for row in session.query(DagModel.dag_id).all()
+                row[0] for row in session.execute(select(DagModel.dag_id)).all()
             }
             assert {
                 ("dag-bulk-sync-0", "test-dag2"),
                 ("dag-bulk-sync-1", "test-dag2"),
                 ("dag-bulk-sync-2", "test-dag2"),
                 ("dag-bulk-sync-3", "test-dag2"),
-            } == set(session.query(DagTag.dag_id, DagTag.name).all())
+            } == set(session.execute(select(DagTag.dag_id, DagTag.name)).all())
 
-            for row in session.query(DagModel.last_parsed_time).all():
+            for row in session.execute(select(DagModel.last_parsed_time)).all():
                 assert row[0] is not None
 
         # Removing all tags
@@ -464,11 +465,11 @@ def test_bulk_write_to_db(self, testing_dag_bundle):
             SerializedDAG.bulk_write_to_db("testing", None, dags)
         with create_session() as session:
             assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
-                row[0] for row in session.query(DagModel.dag_id).all()
+                row[0] for row in session.execute(select(DagModel.dag_id)).all()
             }
-            assert not set(session.query(DagTag.dag_id, DagTag.name).all())
+            assert not set(session.execute(select(DagTag.dag_id, DagTag.name)).all())
 
-            for row in session.query(DagModel.last_parsed_time).all():
+            for row in session.execute(select(DagModel.last_parsed_time)).all():
                 assert row[0] is not None
 
     def test_bulk_write_to_db_single_dag(self, testing_dag_bundle):
@@ -486,12 +487,12 @@ def test_bulk_write_to_db_single_dag(self, testing_dag_bundle):
         with assert_queries_count(6):
             SerializedDAG.bulk_write_to_db("testing", None, dags)
         with create_session() as session:
-            assert {"dag-bulk-sync-0"} == {row[0] for row in session.query(DagModel.dag_id).all()}
+            assert {"dag-bulk-sync-0"} == {row[0] for row in session.execute(select(DagModel.dag_id)).all()}
             assert {
                 ("dag-bulk-sync-0", "test-dag"),
-            } == set(session.query(DagTag.dag_id, DagTag.name).all())
+            } == set(session.execute(select(DagTag.dag_id, DagTag.name)).all())
 
-            for row in session.query(DagModel.last_parsed_time).all():
+            for row in session.execute(select(DagModel.last_parsed_time)).all():
                 assert row[0] is not None
 
         # Re-sync should do fewer queries
@@ -516,16 +517,16 @@ def test_bulk_write_to_db_multiple_dags(self, testing_dag_bundle):
             SerializedDAG.bulk_write_to_db("testing", None, dags)
         with create_session() as session:
             assert {"dag-bulk-sync-0", "dag-bulk-sync-1", "dag-bulk-sync-2", "dag-bulk-sync-3"} == {
-                row[0] for row in session.query(DagModel.dag_id).all()
+                row[0] for row in session.execute(select(DagModel.dag_id)).all()
             }
             assert {
                 ("dag-bulk-sync-0", "test-dag"),
                 ("dag-bulk-sync-1", "test-dag"),
                 ("dag-bulk-sync-2", "test-dag"),
                 ("dag-bulk-sync-3", "test-dag"),
-            } == set(session.query(DagTag.dag_id, DagTag.name).all())
+            } == set(session.execute(select(DagTag.dag_id, DagTag.name)).all())
 
-            for row in session.query(DagModel.last_parsed_time).all():
+            for row in session.execute(select(DagModel.last_parsed_time)).all():
                 assert row[0] is not None
 
         # Re-sync should do fewer queries
@@ -682,21 +683,21 @@ def test_bulk_write_to_db_assets(self, testing_dag_bundle):
         create_scheduler_dag(dag1).clear()
         SerializedDAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session)
         session.commit()
-        stored_assets = {x.uri: x for x in session.query(AssetModel).all()}
+        stored_assets = {x.uri: x for x in session.scalars(select(AssetModel)).all()}
         asset1_orm = stored_assets[a1.uri]
         asset2_orm = stored_assets[a2.uri]
         asset3_orm = stored_assets[a3.uri]
         assert stored_assets[uri1].extra == {"should": "be used"}
         assert [x.dag_id for x in asset1_orm.scheduled_dags] == [dag_id1]
         assert [(x.task_id, x.dag_id) for x in asset1_orm.producing_tasks] == [(task_id, dag_id2)]
         assert set(
-            session.query(
-                TaskOutletAssetReference.task_id,
-                TaskOutletAssetReference.dag_id,
-                TaskOutletAssetReference.asset_id,
-            )
-            .filter(TaskOutletAssetReference.dag_id.in_((dag_id1, dag_id2)))
-            .all()
+            session.execute(
+                select(
+                    TaskOutletAssetReference.task_id,
+                    TaskOutletAssetReference.dag_id,
+                    TaskOutletAssetReference.asset_id,
+                ).where(TaskOutletAssetReference.dag_id.in_((dag_id1, dag_id2)))
+            ).all()
         ) == {
             (task_id, dag_id1, asset2_orm.id),
             (task_id, dag_id1, asset3_orm.id),
@@ -715,18 +716,18 @@ def test_bulk_write_to_db_assets(self, testing_dag_bundle):
         SerializedDAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session)
         session.commit()
         session.expunge_all()
-        stored_assets = {x.uri: x for x in session.query(AssetModel).all()}
+        stored_assets = {x.uri: x for x in session.scalars(select(AssetModel)).all()}
         asset1_orm = stored_assets[a1.uri]
         asset2_orm = stored_assets[a2.uri]
         assert [x.dag_id for x in asset1_orm.scheduled_dags] == []
         assert set(
-            session.query(
-                TaskOutletAssetReference.task_id,
-                TaskOutletAssetReference.dag_id,
-                TaskOutletAssetReference.asset_id,
-            )
-            .filter(TaskOutletAssetReference.dag_id.in_((dag_id1, dag_id2)))
-            .all()
+            session.execute(
+                select(
+                    TaskOutletAssetReference.task_id,
+                    TaskOutletAssetReference.dag_id,
+                    TaskOutletAssetReference.asset_id,
+                ).where(TaskOutletAssetReference.dag_id.in_((dag_id1, dag_id2)))
+            ).all()
         ) == {(task_id, dag_id1, asset2_orm.id)}
 
     def test_bulk_write_to_db_asset_aliases(self, testing_dag_bundle):
@@ -749,7 +750,7 @@ def test_bulk_write_to_db_asset_aliases(self, testing_dag_bundle):
         SerializedDAG.bulk_write_to_db("testing", None, [dag1, dag2], session=session)
         session.commit()
 
-        stored_asset_alias_models = {x.name: x for x in session.query(AssetAliasModel).all()}
+        stored_asset_alias_models = {x.name: x for x in session.scalars(select(AssetAliasModel)).all()}
         asset_alias_1_orm = stored_asset_alias_models[asset_alias_1.name]
         asset_alias_2_orm = stored_asset_alias_models[asset_alias_2.name]
         asset_alias_3_orm = stored_asset_alias_models[asset_alias_3.name]
@@ -818,7 +819,7 @@ def test_dag_is_deactivated_upon_dagfile_deletion(self, dag_maker):
         session = settings.Session()
         sync_dag_to_db(dag, session=session)
 
-        orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one()
+        orm_dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id))
 
         assert not orm_dag.is_stale
 
@@ -827,10 +828,10 @@ def test_dag_is_deactivated_upon_dagfile_deletion(self, dag_maker):
             rel_filelocs=list_py_file_paths(settings.DAGS_FOLDER),
         )
 
-        orm_dag = session.query(DagModel).filter(DagModel.dag_id == dag_id).one()
+        orm_dag = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id))
         assert orm_dag.is_stale
 
-        session.execute(DagModel.__table__.delete().where(DagModel.dag_id == dag_id))
+        session.execute(delete(DagModel).where(DagModel.dag_id == dag_id))
         session.close()
 
     def test_dag_naive_default_args_start_date_with_timezone(self):
@@ -1107,7 +1108,7 @@ def test_get_paused_dag_ids(self, testing_dag_bundle):
         assert paused_dag_ids == {dag_id}
 
         with create_session() as session:
-            session.query(DagModel).filter(DagModel.dag_id == dag_id).delete(synchronize_session=False)
+            session.execute(delete(DagModel).where(DagModel.dag_id == dag_id))
 
     @pytest.mark.parametrize(
         ("schedule_arg", "expected_timetable", "interval_description"),
@@ -1293,7 +1294,7 @@ def consumer(value):
         assert upstream_ti.state is None  # cleared
         assert ti.state is None  # cleared
         assert ti2.state == State.SUCCESS  # not cleared
-        dagruns = session.query(DagRun).filter(DagRun.dag_id == dag_id).all()
+        dagruns = session.scalars(select(DagRun).where(DagRun.dag_id == dag_id)).all()
 
         assert len(dagruns) == 1
         dagrun: DagRun = dagruns[0]
@@ -1465,7 +1466,7 @@ def test_clear_dag(
             session=session,
         )
 
-        task_instances = session.query(TI).filter(TI.dag_id == dag_id).all()
+        task_instances = session.scalars(select(TI).where(TI.dag_id == dag_id)).all()
 
         assert len(task_instances) == 1
         task_instance: TI = task_instances[0]
@@ -1813,7 +1814,7 @@ def test_dag_owner_links(self, testing_dag_bundle):
         dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE)
         sync_dag_to_db(dag, session=session)
 
-        orm_dag_owners = session.query(DagOwnerAttributes).all()
+        orm_dag_owners = session.scalars(select(DagOwnerAttributes)).all()
         assert not orm_dag_owners
 
     @pytest.mark.need_serialized_dag
@@ -1963,7 +1964,7 @@ def test_dags_needing_dagruns_assets(self, dag_maker, session):
         assert dag_models == []
 
         # add queue records so we'll need a run
-        dag_model = session.query(DagModel).filter(DagModel.dag_id == dag.dag_id).one()
+        dag_model = session.scalar(select(DagModel).where(DagModel.dag_id == dag.dag_id))
         asset_model: AssetModel = dag_model.schedule_assets[0]
         session.add(AssetDagRunQueue(asset_id=asset_model.id, target_dag_id=dag_model.dag_id))
         session.flush()
@@ -2250,7 +2251,7 @@ def test_dags_needing_dagruns_triggered_date_by_dag_queued_times(self, session,
             EmptyOperator(task_id="task", outlets=[asset])
         dr = dag_maker.create_dagrun()
 
-        asset_id = session.query(AssetModel.id).filter_by(uri=asset.uri).scalar()
+        asset_id = session.scalar(select(AssetModel.id).where(AssetModel.uri == asset.uri))
 
         session.add(
             AssetEvent(
@@ -2262,8 +2263,8 @@ def test_dags_needing_dagruns_triggered_date_by_dag_queued_times(self, session,
             )
         )
 
-        asset1_id = session.query(AssetModel.id).filter_by(uri=asset1.uri).scalar()
-        asset2_id = session.query(AssetModel.id).filter_by(uri=asset2.uri).scalar()
+        asset1_id = session.scalar(select(AssetModel.id).where(AssetModel.uri == asset1.uri))
+        asset2_id = session.scalar(select(AssetModel.id).where(AssetModel.uri == asset2.uri))
 
         with dag_maker(dag_id="assets-consumer-multiple", schedule=[asset1, asset2]) as dag:
             pass
@@ -2314,7 +2315,9 @@ def test_asset_expression(self, session: Session, testing_dag_bundle) -> None:
         )
         SerializedDAG.bulk_write_to_db("testing", None, [dag], session=session)
 
-        expression = session.scalars(select(DagModel.asset_expression).filter_by(dag_id=dag.dag_id)).one()
+        expression = session.scalars(
+            select(DagModel.asset_expression).where(DagModel.dag_id == dag.dag_id)
+        ).one()
         assert expression == {
             "any": [
                 {
@@ -2506,14 +2509,12 @@ def test_set_task_instance_state(run_id, session, dag_maker):
     )
 
     def get_ti_from_db(task):
-        return (
-            session.query(TI)
-            .filter(
+        return session.scalar(
+            select(TI).where(
                 TI.dag_id == dag.dag_id,
                 TI.task_id == task.task_id,
                 TI.run_id == dagrun.run_id,
             )
-            .one()
         )
 
     get_ti_from_db(task_1).state = State.FAILED
@@ -2588,16 +2589,16 @@ def consumer(value):
     )
     expand_mapped_task(mapped, dr2.run_id, "make_arg_lists", length=2, session=session)
 
-    session.query(TI).filter_by(dag_id=dag.dag_id).update({"state": TaskInstanceState.FAILED})
+    session.execute(update(TI).where(TI.dag_id == dag.dag_id).values(state=TaskInstanceState.FAILED))
 
     ti_query = (
-        session.query(TI.task_id, TI.map_index, TI.run_id, TI.state)
-        .filter(TI.dag_id == dag.dag_id, TI.task_id.in_([task_id, "downstream"]))
+        select(TI.task_id, TI.map_index, TI.run_id, TI.state)
+        .where(TI.dag_id == dag.dag_id, TI.task_id.in_([task_id, "downstream"]))
         .order_by(TI.run_id, TI.task_id, TI.map_index)
     )
 
     # Check pre-conditions
-    assert ti_query.all() == [
+    assert session.execute(ti_query).all() == [
         ("downstream", -1, dr1.run_id, TaskInstanceState.FAILED),
         (task_id, 0, dr1.run_id, TaskInstanceState.FAILED),
         (task_id, 1, dr1.run_id, TaskInstanceState.FAILED),
@@ -2616,7 +2617,7 @@ def consumer(value):
     )
     assert dr1 in session, "Check session is passed down all the way"
 
-    assert ti_query.all() == [
+    assert session.execute(ti_query).all() == [
         ("downstream", -1, dr1.run_id, None),
         (task_id, 0, dr1.run_id, TaskInstanceState.FAILED),
         (task_id, 1, dr1.run_id, TaskInstanceState.SUCCESS),
@@ -2867,7 +2868,7 @@ def test_get_asset_triggered_next_run_info(dag_maker, clear_assets):
     dag3 = dag_maker.dag
 
     session = dag_maker.session
-    asset1_id = session.query(AssetModel.id).filter_by(uri=asset1.uri).scalar()
+    asset1_id = session.scalar(select(AssetModel.id).where(AssetModel.uri == asset1.uri))
     session.bulk_save_objects(
         [
             AssetDagRunQueue(asset_id=asset1_id, target_dag_id=dag2.dag_id),
@@ -2876,7 +2877,7 @@ def test_get_asset_triggered_next_run_info(dag_maker, clear_assets):
     )
     session.flush()
 
-    assets = session.query(AssetModel.uri).order_by(AssetModel.id).all()
+    assets = session.execute(select(AssetModel.uri).order_by(AssetModel.id)).all()
 
     info = get_asset_triggered_next_run_info([dag1.dag_id], session=session)
     assert info[dag1.dag_id] == {
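
Reviewer note: every hunk above is a mechanical translation from the legacy `Query` API to 2.0-style `select()`/`delete()`/`update()` constructs executed through the session. The sketch below is a minimal, self-contained illustration of each pattern against a hypothetical `Widget` model (not part of Airflow); it uses only standard SQLAlchemy 1.4+/2.0 API.

```python
from sqlalchemy import create_engine, delete, select, update
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Widget(Base):  # hypothetical model, for illustration only
    __tablename__ = "widget"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]
    state: Mapped[str] = mapped_column(default="new")


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Widget(name="a"), Widget(name="b")])
    session.flush()

    # session.query(Widget).all()  ->  ORM entities via session.scalars()
    widgets = session.scalars(select(Widget).order_by(Widget.id)).all()
    assert [w.name for w in widgets] == ["a", "b"]

    # session.query(Widget.id).filter_by(name="a").scalar() / .one()
    #   ->  a single value (or None) via session.scalar()
    a_id = session.scalar(select(Widget.id).where(Widget.name == "a"))

    # session.query(Widget.id, Widget.name).all()  ->  Row tuples via session.execute()
    rows = session.execute(select(Widget.id, Widget.name).order_by(Widget.id)).all()
    assert a_id == 1 and rows == [(1, "a"), (2, "b")]

    # Query.update({"state": ...})  ->  update().where().values()
    session.execute(update(Widget).where(Widget.name == "a").values(state="done"))

    # Query.delete(synchronize_session=False)  ->  delete().where()
    session.execute(delete(Widget).where(Widget.name == "b"))

    assert session.scalar(select(Widget.state).where(Widget.id == a_id)) == "done"
```

The practical distinction to keep in mind when reviewing: `session.scalars(...)` returns ORM objects or single-column values, `session.scalar(...)` returns the first such value (or `None`), and `session.execute(...)` returns `Row` tuples, which is why the multi-column assertions in the tests compare against sets or lists of tuples.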