7171 from collections .abc import Collection , Iterable , Iterator
7272
7373 from sqlalchemy .orm import Session
74- from sqlalchemy .sql import Select , Subquery
74+ from sqlalchemy .sql import Select
7575
7676 from airflow .models .dagwarning import DagWarning
7777 from airflow .typing_compat import Self
@@ -96,69 +96,60 @@ def _create_orm_dags(
9696 yield orm_dag
9797
9898
99- def _get_latest_runs_stmt (dag_ids : Collection [ str ] ) -> Select :
99+ def _get_latest_runs_stmt (dag_id : str ) -> Select :
100100 """Build a select statement to retrieve the last automated run for the given dag."""
101- if len (dag_ids ) == 1 : # Index optimized fast path to avoid more complicated & slower groupby queryplan.
102- (dag_id ,) = dag_ids
103- last_automated_runs_subq_scalar : Any = (
104- select (func .max (DagRun .logical_date ).label ("max_logical_date" ))
105- .where (
106- DagRun .dag_id == dag_id ,
107- DagRun .run_type .in_ ((DagRunType .BACKFILL_JOB , DagRunType .SCHEDULED )),
108- )
109- .scalar_subquery ()
101+ max_logical_date = (
102+ select (func .max (DagRun .logical_date ).label ("max_logical_date" ))
103+ .where (
104+ DagRun .dag_id == dag_id ,
105+ DagRun .run_type .in_ (
106+ (
107+ DagRunType .BACKFILL_JOB ,
108+ DagRunType .SCHEDULED ,
109+ )
110+ ),
110111 )
111- query = select (DagRun ).where (
112+ .scalar_subquery ()
113+ )
114+ return (
115+ select (DagRun )
116+ .where (
112117 DagRun .dag_id == dag_id ,
113- DagRun .logical_date == last_automated_runs_subq_scalar ,
118+ DagRun .logical_date == max_logical_date ,
114119 )
115- else :
116- last_automated_runs_subq_table : Subquery = (
117- select ( DagRun .dag_id , func . max ( DagRun . logical_date ). label ( "max_logical_date" ))
118- . where (
119- DagRun .dag_id . in_ ( dag_ids ) ,
120- DagRun .run_type . in_ (( DagRunType . BACKFILL_JOB , DagRunType . SCHEDULED )) ,
120+ . options (
121+ load_only (
122+ DagRun .dag_id ,
123+ DagRun . logical_date ,
124+ DagRun .data_interval_start ,
125+ DagRun .data_interval_end ,
121126 )
122- .group_by (DagRun .dag_id )
123- .subquery ()
124- )
125- query = select (DagRun ).where (
126- DagRun .dag_id == last_automated_runs_subq_table .c .dag_id ,
127- DagRun .logical_date == last_automated_runs_subq_table .c .max_logical_date ,
128- )
129- return query .options (
130- load_only (
131- DagRun .dag_id ,
132- DagRun .logical_date ,
133- DagRun .data_interval_start ,
134- DagRun .data_interval_end ,
135127 )
136128 )
137129
138130
139131class _RunInfo (NamedTuple ):
140- latest_runs : dict [ str , DagRun ]
141- num_active_runs : dict [ str , int ]
132+ latest_run : DagRun | None
133+ num_active_runs : int
142134
143135 @classmethod
144- def calculate (cls , dags : dict [ str , LazyDeserializedDAG ] , * , session : Session ) -> Self :
136+ def calculate (cls , dag : LazyDeserializedDAG , * , session : Session ) -> Self :
145137 """
146138 Query the latest automated run and the active run count for a dag from the db.
147139
148140 :param dag: the dag to query
149141 """
150142 # Skip these queries entirely if no DAGs can be scheduled to save time.
151- if not any ( dag .timetable .can_be_scheduled for dag in dags . values ()) :
152- return cls ({}, {} )
143+ if not dag .timetable .can_be_scheduled :
144+ return cls (None , 0 )
153145
154- latest_runs = { run . dag_id : run for run in session .scalars (_get_latest_runs_stmt (dag_ids = dags . keys ()))}
146+ latest_run = session .scalar (_get_latest_runs_stmt (dag_id = dag . dag_id ))
155147 active_run_counts = DagRun .active_runs_of_dags (
156- dag_ids = dags . keys () ,
148+ dag_ids = [ dag . dag_id ] ,
157149 exclude_backfill = True ,
158150 session = session ,
159151 )
160-
161- return cls (latest_runs , active_run_counts )
152+ return cls (latest_run , active_run_counts .get (dag .dag_id , 0 ))
162153
163154
164155def _update_dag_tags (tag_names : set [str ], dm : DagModel , * , session : Session ) -> None :
@@ -491,8 +482,8 @@ def update_dags(
491482 session : Session ,
492483 ) -> None :
493484 # we exclude backfill from active run counts since their concurrency is separate
494- run_info = _RunInfo .calculate (dags = self .dags , session = session )
495485 for dag_id , dm in sorted (orm_dags .items ()):
486+ run_info = _RunInfo .calculate (dag = self .dags [dag_id ], session = session )
496487 dag = self .dags [dag_id ]
497488 dm .fileloc = dag .fileloc
498489 dm .relative_fileloc = dag .relative_fileloc
@@ -547,12 +538,12 @@ def update_dags(
547538 dm .bundle_name = self .bundle_name
548539 dm .bundle_version = self .bundle_version
549540
550- last_automated_run : DagRun | None = run_info .latest_runs . get ( dag . dag_id )
541+ last_automated_run : DagRun | None = run_info .latest_run
551542 if last_automated_run is None :
552543 last_automated_data_interval = None
553544 else :
554545 last_automated_data_interval = get_run_data_interval (dag .timetable , last_automated_run )
555- if run_info .num_active_runs . get ( dag . dag_id , 0 ) >= dm .max_active_runs :
546+ if run_info .num_active_runs >= dm .max_active_runs :
556547 dm .next_dagrun_create_after = None
557548 else :
558549 dm .calculate_dagrun_date_fields (dag , last_automated_data_interval ) # type: ignore[arg-type]
0 commit comments