Skip to content

Commit 17fbc9d

Browse files
authored
SQLA2 public tests (test_backfills, test_connections, test_pools) (#59733)
* sqla2 public tests & pre-commit * scalars to scalar for one
1 parent eb60174 commit 17fbc9d

File tree

4 files changed

+31
-26
lines changed

4 files changed

+31
-26
lines changed

.pre-commit-config.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,9 @@ repos:
426426
^airflow-core/src/airflow/models/.*\.py$|
427427
^airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py$|
428428
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$|
429+
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_backfills.py$|
430+
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py$|
431+
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py$|
429432
^airflow-core/tests/unit/models/test_serialized_dag.py$|
430433
^airflow-core/tests/unit/models/test_pool.py$|
431434
^airflow-core/tests/unit/models/test_trigger.py$|

airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_backfills.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ class TestCreateBackfill(TestBackfillEndpoint):
205205
def test_create_backfill(self, repro_act, repro_exp, session, dag_maker, test_client):
206206
with dag_maker(session=session, dag_id="TEST_DAG_1", schedule="0 * * * *") as dag:
207207
EmptyOperator(task_id="mytask")
208-
session.query(DagModel).all()
208+
session.scalars(select(DagModel)).all()
209209
session.commit()
210210
from_date = pendulum.parse("2024-01-01")
211211
from_date_iso = to_iso(from_date)
@@ -244,7 +244,7 @@ def test_create_backfill(self, repro_act, repro_exp, session, dag_maker, test_cl
244244
check_last_log(session, dag_id="TEST_DAG_1", event="create_backfill", logical_date=None)
245245

246246
def test_dag_not_exist(self, session, test_client):
247-
session.query(DagModel).all()
247+
session.scalars(select(DagModel)).all()
248248
session.commit()
249249
from_date = pendulum.parse("2024-01-01")
250250
from_date_iso = to_iso(from_date)
@@ -270,7 +270,7 @@ def test_dag_not_exist(self, session, test_client):
270270
def test_no_schedule_dag(self, session, dag_maker, test_client):
271271
with dag_maker(session=session, dag_id="TEST_DAG_1", schedule="None") as dag:
272272
EmptyOperator(task_id="mytask")
273-
session.query(DagModel).all()
273+
session.scalars(select(DagModel)).all()
274274
session.commit()
275275
from_date = pendulum.parse("2024-01-01")
276276
from_date_iso = to_iso(from_date)
@@ -306,7 +306,7 @@ def test_create_backfill_with_depends_on_past(
306306
):
307307
with dag_maker(session=session, dag_id="TEST_DAG_1", schedule="0 * * * *") as dag:
308308
EmptyOperator(task_id="mytask", depends_on_past=True)
309-
session.query(DagModel).all()
309+
session.scalars(select(DagModel)).all()
310310
session.commit()
311311
from_date = pendulum.parse("2024-01-01")
312312
from_date_iso = to_iso(from_date)
@@ -350,7 +350,7 @@ def test_create_backfill_with_depends_on_past(
350350
def test_create_backfill_future_dates(self, session, dag_maker, test_client, run_backwards):
351351
with dag_maker(session=session, dag_id="TEST_DAG_1", schedule="0 * * * *") as dag:
352352
EmptyOperator(task_id="mytask")
353-
session.query(DagModel).all()
353+
session.scalars(select(DagModel)).all()
354354
session.commit()
355355
from_date = timezone.utcnow() + timedelta(days=1)
356356
to_date = timezone.utcnow() + timedelta(days=1)
@@ -381,7 +381,7 @@ def test_create_backfill_future_dates(self, session, dag_maker, test_client, run
381381
def test_create_backfill_past_future_dates(self, session, dag_maker, test_client, run_backwards):
382382
with dag_maker(session=session, dag_id="TEST_DAG_1", schedule="@daily") as dag:
383383
EmptyOperator(task_id="mytask")
384-
session.query(DagModel).all()
384+
session.scalars(select(DagModel)).all()
385385
session.commit()
386386
from_date = timezone.utcnow() - timedelta(days=2)
387387
to_date = timezone.utcnow() + timedelta(days=1)
@@ -543,7 +543,7 @@ def test_create_backfill_with_existing_runs(
543543
def test_should_respond_401(self, unauthenticated_test_client, dag_maker, session):
544544
with dag_maker(session=session, dag_id="TEST_DAG_1", schedule="0 * * * *") as dag:
545545
EmptyOperator(task_id="mytask")
546-
session.query(DagModel).all()
546+
session.scalars(select(DagModel)).all()
547547
session.commit()
548548
from_date = pendulum.parse("2024-01-01")
549549
from_date_iso = to_iso(from_date)
@@ -564,7 +564,7 @@ def test_should_respond_401(self, unauthenticated_test_client, dag_maker, sessio
564564
def test_should_respond_403(self, unauthorized_test_client, dag_maker, session):
565565
with dag_maker(session=session, dag_id="TEST_DAG_1", schedule="0 * * * *") as dag:
566566
EmptyOperator(task_id="mytask")
567-
session.query(DagModel).all()
567+
session.scalars(select(DagModel)).all()
568568
session.commit()
569569
from_date = pendulum.parse("2024-01-01")
570570
from_date_iso = to_iso(from_date)
@@ -682,7 +682,7 @@ def test_create_backfill_dry_run_with_depends_on_past(
682682
):
683683
with dag_maker(session=session, dag_id="TEST_DAG_1", schedule="0 * * * *") as dag:
684684
EmptyOperator(task_id="mytask", depends_on_past=True)
685-
session.query(DagModel).all()
685+
session.scalars(select(DagModel)).all()
686686
session.commit()
687687
from_date = pendulum.parse("2024-01-01")
688688
from_date_iso = to_iso(from_date)

airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_connections.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from unittest import mock
2222

2323
import pytest
24+
from sqlalchemy import select
2425
from sqlalchemy.orm import Session
2526

2627
from airflow.models import Connection
@@ -104,11 +105,11 @@ def create_connections(self):
104105
class TestDeleteConnection(TestConnectionEndpoint):
105106
def test_delete_should_respond_204(self, test_client, session):
106107
self.create_connection()
107-
conns = session.query(Connection).all()
108+
conns = session.scalars(select(Connection)).all()
108109
assert len(conns) == 1
109110
response = test_client.delete(f"/connections/{TEST_CONN_ID}")
110111
assert response.status_code == 204
111-
connection = session.query(Connection).all()
112+
connection = session.scalars(select(Connection)).all()
112113
assert len(connection) == 0
113114
_check_last_log(session, dag_id=None, event="delete_connection", logical_date=None)
114115

@@ -153,7 +154,7 @@ def test_get_should_respond_404(self, test_client):
153154

154155
def test_get_should_respond_200_with_extra(self, test_client, session):
155156
self.create_connection()
156-
connection = session.query(Connection).first()
157+
connection = session.scalars(select(Connection)).first()
157158
connection.extra = '{"extra_key": "extra_value"}'
158159
session.commit()
159160
response = test_client.get(f"/connections/{TEST_CONN_ID}")
@@ -166,7 +167,7 @@ def test_get_should_respond_200_with_extra(self, test_client, session):
166167
@pytest.mark.enable_redact
167168
def test_get_should_respond_200_with_extra_redacted(self, test_client, session):
168169
self.create_connection()
169-
connection = session.query(Connection).first()
170+
connection = session.scalars(select(Connection)).first()
170171
connection.extra = '{"password": "test-password"}'
171172
session.commit()
172173
response = test_client.get(f"/connections/{TEST_CONN_ID}")
@@ -277,7 +278,7 @@ class TestPostConnection(TestConnectionEndpoint):
277278
def test_post_should_respond_201(self, test_client, session, body):
278279
response = test_client.post("/connections", json=body)
279280
assert response.status_code == 201
280-
connection = session.query(Connection).all()
281+
connection = session.scalars(select(Connection)).all()
281282
assert len(connection) == 1
282283
_check_last_log(session, dag_id=None, event="post_connection", logical_date=None)
283284

@@ -780,7 +781,7 @@ def test_patch_should_respond_200_with_update_mask(
780781
self.create_connection()
781782
response = test_client.patch(f"/connections/{TEST_CONN_ID}", json=body, params=update_mask)
782783
assert response.status_code == 200
783-
connection = session.query(Connection).filter_by(conn_id=TEST_CONN_ID).first()
784+
connection = session.scalars(select(Connection).where(Connection.conn_id == TEST_CONN_ID)).first()
784785
assert connection.password is None
785786
assert response.json() == updated_connection
786787

@@ -1399,7 +1400,7 @@ def test_post_should_accept_empty_string_as_extra(self, test_client, session):
13991400
response = test_client.post("/connections", json=body)
14001401
assert response.status_code == 201
14011402

1402-
connection = session.query(Connection).filter_by(conn_id=TEST_CONN_ID).first()
1403+
connection = session.scalars(select(Connection).where(Connection.conn_id == TEST_CONN_ID)).first()
14031404
assert connection is not None
14041405
assert connection.extra == "{}" # Backward compatibility: treat "" as empty JSON object
14051406

airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_pools.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from unittest import mock
2020

2121
import pytest
22+
from sqlalchemy import func, select
2223

2324
from airflow.models.pool import Pool
2425
from airflow.utils.session import provide_session
@@ -73,11 +74,11 @@ def create_pools(self):
7374
class TestDeletePool(TestPoolsEndpoint):
7475
def test_delete_should_respond_204(self, test_client, session):
7576
self.create_pools()
76-
pools = session.query(Pool).all()
77+
pools = session.scalars(select(Pool)).all()
7778
assert len(pools) == 4
7879
response = test_client.delete(f"/pools/{POOL1_NAME}")
7980
assert response.status_code == 204
80-
pools = session.query(Pool).all()
81+
pools = session.scalars(select(Pool)).all()
8182
assert len(pools) == 3
8283
check_last_log(session, dag_id=None, event="delete_pool", logical_date=None)
8384

@@ -104,11 +105,11 @@ def test_delete_should_respond_404(self, test_client):
104105
def test_delete_pool3_should_respond_204(self, test_client, session):
105106
"""Test deleting POOL3 with forward slash in name"""
106107
self.create_pools()
107-
pools = session.query(Pool).all()
108+
pools = session.scalars(select(Pool)).all()
108109
assert len(pools) == 4
109110
response = test_client.delete(f"/pools/{POOL3_NAME}")
110111
assert response.status_code == 204
111-
pools = session.query(Pool).all()
112+
pools = session.scalars(select(Pool)).all()
112113
assert len(pools) == 3
113114
check_last_log(session, dag_id=None, event="delete_pool", logical_date=None)
114115

@@ -430,12 +431,12 @@ class TestPostPool(TestPoolsEndpoint):
430431
)
431432
def test_should_respond_200(self, test_client, session, body, expected_status_code, expected_response):
432433
self.create_pools()
433-
n_pools = session.query(Pool).count()
434+
n_pools = session.scalar(select(func.count()).select_from(Pool))
434435
response = test_client.post("/pools", json=body)
435436
assert response.status_code == expected_status_code
436437

437438
assert response.json() == expected_response
438-
assert session.query(Pool).count() == n_pools + 1
439+
assert session.scalar(select(func.count()).select_from(Pool)) == n_pools + 1
439440
check_last_log(session, dag_id=None, event="post_pool", logical_date=None)
440441

441442
def test_should_respond_401(self, unauthenticated_test_client):
@@ -486,11 +487,11 @@ def test_should_response_409(
486487
second_expected_response,
487488
):
488489
self.create_pools()
489-
n_pools = session.query(Pool).count()
490+
n_pools = session.scalar(select(func.count()).select_from(Pool))
490491
response = test_client.post("/pools", json=body)
491492
assert response.status_code == first_expected_status_code
492493
assert response.json() == first_expected_response
493-
assert session.query(Pool).count() == n_pools + 1
494+
assert session.scalar(select(func.count()).select_from(Pool)) == n_pools + 1
494495
response = test_client.post("/pools", json=body)
495496
assert response.status_code == second_expected_status_code
496497
if second_expected_status_code == 201:
@@ -500,7 +501,7 @@ def test_should_response_409(
500501
assert "detail" in response_json
501502
assert list(response_json["detail"].keys()) == ["reason", "statement", "orig_error", "message"]
502503

503-
assert session.query(Pool).count() == n_pools + 1
504+
assert session.scalar(select(func.count()).select_from(Pool)) == n_pools + 1
504505

505506

506507
class TestBulkPools(TestPoolsEndpoint):
@@ -990,7 +991,7 @@ def test_update_mask_preserves_other_fields(self, test_client, session):
990991
assert response_data["update"]["success"] == ["pool1"]
991992

992993
# Assert: fetch from DB and check only masked field changed
993-
updated_pool = session.query(Pool).filter_by(pool="pool1").one()
994+
updated_pool = session.execute(select(Pool).where(Pool.pool == "pool1")).scalar_one()
994995
assert updated_pool.slots == 50 # updated
995996
assert updated_pool.description is None # unchanged
996997
assert updated_pool.include_deferred is True # unchanged

0 commit comments

Comments
 (0)