Skip to content

Commit 5a28c44

Browse files
authored
Fix n+1 query to fetch tags in the dags list page (apache#57270)
1 parent 63a5cfe commit 5a28c44

File tree

3 files changed

+135
-1
lines changed

3 files changed

+135
-1
lines changed

airflow-core/src/airflow/api_fastapi/common/db/dags.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from typing import TYPE_CHECKING
2121

2222
from sqlalchemy import func, select
23+
from sqlalchemy.orm import selectinload
2324

2425
from airflow.api_fastapi.common.db.common import (
2526
apply_filters_to_select,
@@ -33,7 +34,7 @@
3334

3435

3536
def generate_dag_with_latest_run_query(max_run_filters: list[BaseParam], order_by: SortParam) -> Select:
36-
query = select(DagModel)
37+
query = select(DagModel).options(selectinload(DagModel.tags))
3738

3839
max_run_id_query = ( # ordering by id will not always be "latest run", but it's a simplifying assumption
3940
select(DagRun.dag_id, func.max(DagRun.id).label("max_dag_run_id"))

airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dags.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from airflow.utils.state import DagRunState, TaskInstanceState
3232
from airflow.utils.types import DagRunTriggeredByType, DagRunType
3333

34+
from tests_common.test_utils.asserts import count_queries
3435
from tests_common.test_utils.db import (
3536
clear_db_assets,
3637
clear_db_connections,
@@ -524,6 +525,71 @@ def test_get_dags_filter_has_import_errors(self, session, test_client, filter_va
524525
assert body["total_entries"] == 1
525526
assert [dag["dag_id"] for dag in body["dags"]] == expected_ids
526527

528+
def test_get_dags_no_n_plus_one_queries(self, session, test_client):
    """Test that fetching DAGs with tags doesn't trigger n+1 queries.

    The DAG list endpoint is called twice: once with ``num_dags`` tagged DAGs
    and once after ``extra_dags`` more have been added. The total number of
    queries recorded by ``count_queries`` for each call is compared; the
    increase must stay below the number of DAGs added.
    """
    num_dags = 5
    extra_dags = 3
    tags_per_dag = 3
    dag_id_prefix = "test_dag_queries_"

    def add_dags_with_tags(indices):
        """Persist one DagModel and ``tags_per_dag`` DagTag rows per index."""
        for i in indices:
            dag_id = f"{dag_id_prefix}{i}"
            session.add(
                DagModel(
                    dag_id=dag_id,
                    bundle_name="dag_maker",
                    fileloc=f"/tmp/{dag_id}.py",
                    is_stale=False,
                )
            )
            # NOTE(review): presumably flushed here so the DAG row exists
            # before the tags that reference its dag_id are added.
            session.flush()
            for j in range(tags_per_dag):
                session.add(DagTag(name=f"tag_{i}_{j}", dag_id=dag_id))

        session.commit()
        # Expire loaded state so the request below re-reads from the database.
        session.expire_all()

    add_dags_with_tags(range(num_dags))

    with count_queries() as result:
        response = test_client.get("/dags", params={"limit": 10})

    assert response.status_code == 200
    body = response.json()
    # Other DAGs may already exist in the database; only inspect the ones created here.
    dags_with_our_prefix = [d for d in body["dags"] if d["dag_id"].startswith(dag_id_prefix)]
    assert len(dags_with_our_prefix) == num_dags
    for dag in dags_with_our_prefix:
        assert len(dag["tags"]) == tags_per_dag

    first_query_count = sum(result.values())

    # Add more DAGs and verify query count doesn't scale linearly
    add_dags_with_tags(range(num_dags, num_dags + extra_dags))

    with count_queries() as result2:
        response = test_client.get("/dags", params={"limit": 15})

    assert response.status_code == 200
    second_query_count = sum(result2.values())
    query_increase = second_query_count - first_query_count

    # With n+1, adding 3 DAGs would add ~3 tag queries
    # Without n+1, query count should remain nearly identical
    assert query_increase < extra_dags, (
        f"Added {extra_dags} DAGs but query count increased by {query_increase} "
        f"({first_query_count} -> {second_query_count}), suggesting n+1 queries for tags"
    )
592+
527593

528594
class TestPatchDag(TestDagEndpoint):
529595
"""Unit tests for Patch DAG."""

airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_dags.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,15 @@
2525
from sqlalchemy.orm import Session
2626

2727
from airflow.models import DagRun
28+
from airflow.models.dag import DagModel, DagTag
2829
from airflow.models.dag_favorite import DagFavorite
2930
from airflow.models.hitl import HITLDetail
3031
from airflow.sdk.timezone import utcnow
3132
from airflow.utils.session import provide_session
3233
from airflow.utils.state import DagRunState, TaskInstanceState
3334
from airflow.utils.types import DagRunTriggeredByType, DagRunType
3435

36+
from tests_common.test_utils.asserts import count_queries
3537
from unit.api_fastapi.core_api.routes.public.test_dags import (
3638
DAG1_ID,
3739
DAG2_ID,
@@ -231,6 +233,71 @@ def test_should_response_403(self, unauthorized_test_client):
231233
response = unauthorized_test_client.get("/dags", params={})
232234
assert response.status_code == 403
233235

236+
def test_get_dags_no_n_plus_one_queries(self, session, test_client):
    """Test that fetching DAGs with tags doesn't trigger n+1 queries.

    The DAG list endpoint is called twice: once with ``num_dags`` tagged DAGs
    and once after ``extra_dags`` more have been added. The total number of
    queries recorded by ``count_queries`` for each call is compared; the
    increase must stay below the number of DAGs added.
    """
    num_dags = 5
    extra_dags = 3
    tags_per_dag = 3
    dag_id_prefix = "test_dag_queries_ui_"

    def add_dags_with_tags(indices):
        """Persist one DagModel and ``tags_per_dag`` DagTag rows per index."""
        for i in indices:
            dag_id = f"{dag_id_prefix}{i}"
            session.add(
                DagModel(
                    dag_id=dag_id,
                    bundle_name="dag_maker",
                    fileloc=f"/tmp/{dag_id}.py",
                    is_stale=False,
                )
            )
            # NOTE(review): presumably flushed here so the DAG row exists
            # before the tags that reference its dag_id are added.
            session.flush()
            for j in range(tags_per_dag):
                session.add(DagTag(name=f"tag_ui_{i}_{j}", dag_id=dag_id))

        session.commit()
        # Expire loaded state so the request below re-reads from the database.
        session.expire_all()

    add_dags_with_tags(range(num_dags))

    with count_queries() as result:
        response = test_client.get("/dags", params={"limit": 10})

    assert response.status_code == 200
    body = response.json()
    # Other DAGs may already exist in the database; only inspect the ones created here.
    dags_with_our_prefix = [d for d in body["dags"] if d["dag_id"].startswith(dag_id_prefix)]
    assert len(dags_with_our_prefix) == num_dags
    for dag in dags_with_our_prefix:
        assert len(dag["tags"]) == tags_per_dag

    first_query_count = sum(result.values())

    # Add more DAGs and verify query count doesn't scale linearly
    add_dags_with_tags(range(num_dags, num_dags + extra_dags))

    with count_queries() as result2:
        response = test_client.get("/dags", params={"limit": 15})

    assert response.status_code == 200
    second_query_count = sum(result2.values())
    query_increase = second_query_count - first_query_count

    # With n+1, adding 3 DAGs would add ~3 tag queries
    # Without n+1, query count should remain nearly identical
    assert query_increase < extra_dags, (
        f"Added {extra_dags} DAGs but query count increased by {query_increase} "
        f"({first_query_count} -> {second_query_count}), suggesting n+1 queries for tags"
    )
300+
234301
@pytest.mark.usefixtures("configure_git_connection_for_dag_bundle")
235302
def test_latest_run_should_return_200(self, test_client):
236303
response = test_client.get(f"/dags/{DAG1_ID}/latest_run")

0 commit comments

Comments
 (0)