Skip to content

Commit f798302

Browse files
authored
Get last run info one dag at a time (#59376)
This may make it easier when handling partition-driven dags differently from non-partition-driven.
1 parent d91b6c4 commit f798302

File tree

2 files changed

+35
-63
lines changed

2 files changed

+35
-63
lines changed

airflow-core/src/airflow/dag_processing/collection.py

Lines changed: 35 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@
7171
from collections.abc import Collection, Iterable, Iterator
7272

7373
from sqlalchemy.orm import Session
74-
from sqlalchemy.sql import Select, Subquery
74+
from sqlalchemy.sql import Select
7575

7676
from airflow.models.dagwarning import DagWarning
7777
from airflow.typing_compat import Self
@@ -96,69 +96,60 @@ def _create_orm_dags(
9696
yield orm_dag
9797

9898

99-
def _get_latest_runs_stmt(dag_ids: Collection[str]) -> Select:
99+
def _get_latest_runs_stmt(dag_id: str) -> Select:
100100
"""Build a select statement to retrieve the last automated run for each dag."""
101-
if len(dag_ids) == 1: # Index optimized fast path to avoid more complicated & slower groupby queryplan.
102-
(dag_id,) = dag_ids
103-
last_automated_runs_subq_scalar: Any = (
104-
select(func.max(DagRun.logical_date).label("max_logical_date"))
105-
.where(
106-
DagRun.dag_id == dag_id,
107-
DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
108-
)
109-
.scalar_subquery()
101+
max_logical_date = (
102+
select(func.max(DagRun.logical_date).label("max_logical_date"))
103+
.where(
104+
DagRun.dag_id == dag_id,
105+
DagRun.run_type.in_(
106+
(
107+
DagRunType.BACKFILL_JOB,
108+
DagRunType.SCHEDULED,
109+
)
110+
),
110111
)
111-
query = select(DagRun).where(
112+
.scalar_subquery()
113+
)
114+
return (
115+
select(DagRun)
116+
.where(
112117
DagRun.dag_id == dag_id,
113-
DagRun.logical_date == last_automated_runs_subq_scalar,
118+
DagRun.logical_date == max_logical_date,
114119
)
115-
else:
116-
last_automated_runs_subq_table: Subquery = (
117-
select(DagRun.dag_id, func.max(DagRun.logical_date).label("max_logical_date"))
118-
.where(
119-
DagRun.dag_id.in_(dag_ids),
120-
DagRun.run_type.in_((DagRunType.BACKFILL_JOB, DagRunType.SCHEDULED)),
120+
.options(
121+
load_only(
122+
DagRun.dag_id,
123+
DagRun.logical_date,
124+
DagRun.data_interval_start,
125+
DagRun.data_interval_end,
121126
)
122-
.group_by(DagRun.dag_id)
123-
.subquery()
124-
)
125-
query = select(DagRun).where(
126-
DagRun.dag_id == last_automated_runs_subq_table.c.dag_id,
127-
DagRun.logical_date == last_automated_runs_subq_table.c.max_logical_date,
128-
)
129-
return query.options(
130-
load_only(
131-
DagRun.dag_id,
132-
DagRun.logical_date,
133-
DagRun.data_interval_start,
134-
DagRun.data_interval_end,
135127
)
136128
)
137129

138130

139131
class _RunInfo(NamedTuple):
140-
latest_runs: dict[str, DagRun]
141-
num_active_runs: dict[str, int]
132+
latest_run: DagRun | None
133+
num_active_runs: int
142134

143135
@classmethod
144-
def calculate(cls, dags: dict[str, LazyDeserializedDAG], *, session: Session) -> Self:
136+
def calculate(cls, dag: LazyDeserializedDAG, *, session: Session) -> Self:
145137
"""
146138
Query the run counts from the db.
147139
148140
:param dags: dict of dags to query
149141
"""
150142
# Skip these queries entirely if no DAGs can be scheduled to save time.
151-
if not any(dag.timetable.can_be_scheduled for dag in dags.values()):
152-
return cls({}, {})
143+
if not dag.timetable.can_be_scheduled:
144+
return cls(None, 0)
153145

154-
latest_runs = {run.dag_id: run for run in session.scalars(_get_latest_runs_stmt(dag_ids=dags.keys()))}
146+
latest_run = session.scalar(_get_latest_runs_stmt(dag_id=dag.dag_id))
155147
active_run_counts = DagRun.active_runs_of_dags(
156-
dag_ids=dags.keys(),
148+
dag_ids=[dag.dag_id],
157149
exclude_backfill=True,
158150
session=session,
159151
)
160-
161-
return cls(latest_runs, active_run_counts)
152+
return cls(latest_run, active_run_counts.get(dag.dag_id, 0))
162153

163154

164155
def _update_dag_tags(tag_names: set[str], dm: DagModel, *, session: Session) -> None:
@@ -491,8 +482,8 @@ def update_dags(
491482
session: Session,
492483
) -> None:
493484
# we exclude backfill from active run counts since their concurrency is separate
494-
run_info = _RunInfo.calculate(dags=self.dags, session=session)
495485
for dag_id, dm in sorted(orm_dags.items()):
486+
run_info = _RunInfo.calculate(dag=self.dags[dag_id], session=session)
496487
dag = self.dags[dag_id]
497488
dm.fileloc = dag.fileloc
498489
dm.relative_fileloc = dag.relative_fileloc
@@ -547,12 +538,12 @@ def update_dags(
547538
dm.bundle_name = self.bundle_name
548539
dm.bundle_version = self.bundle_version
549540

550-
last_automated_run: DagRun | None = run_info.latest_runs.get(dag.dag_id)
541+
last_automated_run: DagRun | None = run_info.latest_run
551542
if last_automated_run is None:
552543
last_automated_data_interval = None
553544
else:
554545
last_automated_data_interval = get_run_data_interval(dag.timetable, last_automated_run)
555-
if run_info.num_active_runs.get(dag.dag_id, 0) >= dm.max_active_runs:
546+
if run_info.num_active_runs >= dm.max_active_runs:
556547
dm.next_dagrun_create_after = None
557548
else:
558549
dm.calculate_dagrun_date_fields(dag, last_automated_data_interval) # type: ignore[arg-type]

airflow-core/tests/unit/dag_processing/test_collection.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -96,25 +96,6 @@ def test_statement_latest_runs_one_dag():
9696
assert actual == expected, compiled_stmt
9797

9898

99-
def test_statement_latest_runs_many_dag():
100-
with warnings.catch_warnings():
101-
warnings.simplefilter("error", category=SAWarning)
102-
103-
stmt = _get_latest_runs_stmt(["fake-dag-1", "fake-dag-2"])
104-
compiled_stmt = str(stmt.compile())
105-
actual = [x.strip() for x in compiled_stmt.splitlines()]
106-
expected = [
107-
"SELECT dag_run.id, dag_run.dag_id, dag_run.logical_date, "
108-
"dag_run.data_interval_start, dag_run.data_interval_end",
109-
"FROM dag_run, (SELECT dag_run.dag_id AS dag_id, max(dag_run.logical_date) AS max_logical_date",
110-
"FROM dag_run",
111-
"WHERE dag_run.dag_id IN (__[POSTCOMPILE_dag_id_1]) "
112-
"AND dag_run.run_type IN (__[POSTCOMPILE_run_type_1]) GROUP BY dag_run.dag_id) AS anon_1",
113-
"WHERE dag_run.dag_id = anon_1.dag_id AND dag_run.logical_date = anon_1.max_logical_date",
114-
]
115-
assert actual == expected, compiled_stmt
116-
117-
11899
@pytest.mark.db_test
119100
class TestAssetModelOperation:
120101
@staticmethod

0 commit comments

Comments
 (0)