Skip to content

Commit 50e10d6

Browse files
authored
Make Edge provider SQLA2 compatible (#59414)
* Make Edge provider SQLA2 compatible
* Add pre-commit check
* Change filter_by to where
* Uuups, fix
* Review feedback: filter() -> where()
* Review feedback: scalars()....first() -> scalar()
* Fix mypy with SQLA2
1 parent 1c70306 commit 50e10d6

File tree

7 files changed

+51
-44
lines changed

7 files changed

+51
-44
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ repos:
440440
^providers/celery/.*\.py$|
441441
^providers/cncf/kubernetes/.*\.py$|
442442
^providers/databricks/.*\.py$|
443+
^providers/edge3/.*\.py$|
443444
^providers/mysql/.*\.py$|
444445
^providers/openlineage/.*\.py$|
445446
^task_sdk.*\.py$

providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from datetime import datetime, timedelta
2424
from typing import TYPE_CHECKING, Any
2525

26-
from sqlalchemy import delete, inspect, text
26+
from sqlalchemy import delete, inspect, select, text
2727
from sqlalchemy.exc import NoSuchTableError
2828
from sqlalchemy.orm import Session
2929

@@ -140,17 +140,15 @@ def queue_workload(
140140
key = task_instance.key
141141

142142
# Check if job already exists with same dag_id, task_id, run_id, map_index, try_number
143-
existing_job = (
144-
session.query(EdgeJobModel)
145-
.filter_by(
146-
dag_id=key.dag_id,
147-
task_id=key.task_id,
148-
run_id=key.run_id,
149-
map_index=key.map_index,
150-
try_number=key.try_number,
143+
existing_job = session.scalars(
144+
select(EdgeJobModel).where(
145+
EdgeJobModel.dag_id == key.dag_id,
146+
EdgeJobModel.task_id == key.task_id,
147+
EdgeJobModel.run_id == key.run_id,
148+
EdgeJobModel.map_index == key.map_index,
149+
EdgeJobModel.try_number == key.try_number,
151150
)
152-
.first()
153-
)
151+
).first()
154152

155153
if existing_job:
156154
existing_job.state = TaskInstanceState.QUEUED
@@ -176,10 +174,10 @@ def _check_worker_liveness(self, session: Session) -> bool:
176174
"""Reset worker state if heartbeat timed out."""
177175
changed = False
178176
heartbeat_interval: int = conf.getint("edge", "heartbeat_interval")
179-
lifeless_workers: list[EdgeWorkerModel] = (
180-
session.query(EdgeWorkerModel)
177+
lifeless_workers: Sequence[EdgeWorkerModel] = session.scalars(
178+
select(EdgeWorkerModel)
181179
.with_for_update(skip_locked=True)
182-
.filter(
180+
.where(
183181
EdgeWorkerModel.state.not_in(
184182
[
185183
EdgeWorkerState.UNKNOWN,
@@ -189,8 +187,7 @@ def _check_worker_liveness(self, session: Session) -> bool:
189187
),
190188
EdgeWorkerModel.last_update < (timezone.utcnow() - timedelta(seconds=heartbeat_interval * 5)),
191189
)
192-
.all()
193-
)
190+
).all()
194191

195192
for worker in lifeless_workers:
196193
changed = True
@@ -212,15 +209,14 @@ def _check_worker_liveness(self, session: Session) -> bool:
212209
def _update_orphaned_jobs(self, session: Session) -> bool:
213210
"""Update status of jobs when workers die and don't update anymore."""
214211
heartbeat_interval: int = conf.getint("scheduler", "task_instance_heartbeat_timeout")
215-
lifeless_jobs: list[EdgeJobModel] = (
216-
session.query(EdgeJobModel)
212+
lifeless_jobs: Sequence[EdgeJobModel] = session.scalars(
213+
select(EdgeJobModel)
217214
.with_for_update(skip_locked=True)
218-
.filter(
215+
.where(
219216
EdgeJobModel.state == TaskInstanceState.RUNNING,
220217
EdgeJobModel.last_update < (timezone.utcnow() - timedelta(seconds=heartbeat_interval)),
221218
)
222-
.all()
223-
)
219+
).all()
224220

225221
for job in lifeless_jobs:
226222
ti = TaskInstance.get_task_instance(
@@ -254,10 +250,10 @@ def _purge_jobs(self, session: Session) -> bool:
254250
purged_marker = False
255251
job_success_purge = conf.getint("edge", "job_success_purge")
256252
job_fail_purge = conf.getint("edge", "job_fail_purge")
257-
jobs: list[EdgeJobModel] = (
258-
session.query(EdgeJobModel)
253+
jobs: Sequence[EdgeJobModel] = session.scalars(
254+
select(EdgeJobModel)
259255
.with_for_update(skip_locked=True)
260-
.filter(
256+
.where(
261257
EdgeJobModel.state.in_(
262258
[
263259
TaskInstanceState.RUNNING,
@@ -269,8 +265,7 @@ def _purge_jobs(self, session: Session) -> bool:
269265
]
270266
)
271267
)
272-
.all()
273-
)
268+
).all()
274269

275270
# Sync DB with executor otherwise runs out of sync in multi scheduler deployment
276271
already_removed = self.running - set(job.key for job in jobs)

providers/edge3/tests/unit/edge3/executors/test_edge_executor.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import pytest
2323
import time_machine
24+
from sqlalchemy import delete, select
2425

2526
from airflow.configuration import conf
2627
from airflow.models.taskinstancekey import TaskInstanceKey
@@ -40,7 +41,7 @@ class TestEdgeExecutor:
4041
@pytest.fixture(autouse=True)
4142
def setup_test_cases(self):
4243
with create_session() as session:
43-
session.query(EdgeJobModel).delete()
44+
session.execute(delete(EdgeJobModel))
4445

4546
def get_test_executor(self, pool_slots=1):
4647
key = TaskInstanceKey(
@@ -104,7 +105,7 @@ def test_sync_orphaned_tasks(self, mock_stats_incr):
104105
mock_stats_incr.call_count == 2
105106

106107
with create_session() as session:
107-
jobs = session.query(EdgeJobModel).all()
108+
jobs = session.scalars(select(EdgeJobModel)).all()
108109
assert len(jobs) == 1
109110
assert jobs[0].task_id == "started_running_orphaned"
110111
assert jobs[0].state == TaskInstanceState.REMOVED
@@ -154,8 +155,8 @@ def remove_from_running(key: TaskInstanceKey):
154155
executor.sync()
155156

156157
with create_session() as session:
157-
jobs = session.query(EdgeJobModel).all()
158-
assert len(session.query(EdgeJobModel).all()) == 1
158+
jobs = session.scalars(select(EdgeJobModel)).all()
159+
assert len(session.scalars(select(EdgeJobModel)).all()) == 1
159160
assert jobs[0].task_id == "started_running"
160161
assert jobs[0].state == TaskInstanceState.RUNNING
161162

@@ -215,7 +216,7 @@ def test_sync_active_worker(self):
215216
# Prepare some data
216217
with create_session() as session:
217218
# Clear existing workers to avoid unique constraint violation
218-
session.query(EdgeWorkerModel).delete()
219+
session.execute(delete(EdgeWorkerModel))
219220
session.commit()
220221

221222
# Add workers with different states
@@ -253,7 +254,7 @@ def test_sync_active_worker(self):
253254
executor.sync()
254255

255256
with create_session() as session:
256-
for worker in session.query(EdgeWorkerModel).all():
257+
for worker in session.scalars(select(EdgeWorkerModel)).all():
257258
print(worker.worker_name)
258259
if "maintenance_" in worker.worker_name:
259260
EdgeWorkerState.OFFLINE_MAINTENANCE
@@ -304,7 +305,7 @@ def test_revoke_task(self):
304305

305306
# Verify job exists before revoke
306307
with create_session() as session:
307-
jobs = session.query(EdgeJobModel).all()
308+
jobs = session.scalars(select(EdgeJobModel)).all()
308309
assert len(jobs) == 1
309310

310311
# Revoke the task
@@ -317,7 +318,7 @@ def test_revoke_task(self):
317318

318319
# Verify job is removed from database
319320
with create_session() as session:
320-
jobs = session.query(EdgeJobModel).all()
321+
jobs = session.scalars(select(EdgeJobModel)).all()
321322
assert len(jobs) == 0
322323

323324
def test_revoke_task_nonexistent(self):

providers/edge3/tests/unit/edge3/worker_api/routes/test_jobs.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from unittest.mock import patch
2121

2222
import pytest
23+
from sqlalchemy import delete, select
2324

2425
from airflow.providers.common.compat.sdk import Stats
2526
from airflow.providers.edge3.models.edge_job import EdgeJobModel
@@ -42,7 +43,7 @@
4243
class TestJobsApiRoutes:
4344
@pytest.fixture(autouse=True)
4445
def setup_test_cases(self, dag_maker, session: Session):
45-
session.query(EdgeJobModel).delete()
46+
session.execute(delete(EdgeJobModel))
4647
session.commit()
4748

4849
@patch(f"{Stats.__module__}.Stats.incr")
@@ -94,4 +95,6 @@ def test_state(self, mock_stats_incr, session: Session):
9495
)
9596
mock_stats_incr.call_count == 2
9697

97-
assert session.query(EdgeJobModel).scalar().state == TaskInstanceState.SUCCESS
98+
db_job: EdgeJobModel | None = session.scalar(select(EdgeJobModel))
99+
assert db_job is not None
100+
assert db_job.state == TaskInstanceState.SUCCESS

providers/edge3/tests/unit/edge3/worker_api/routes/test_logs.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
from collections.abc import Sequence
1920
from typing import TYPE_CHECKING
2021

2122
import pytest
23+
from sqlalchemy import delete, select
2224

2325
from airflow.providers.common.compat.sdk import timezone
2426
from airflow.providers.edge3.models.edge_logs import EdgeLogsModel
@@ -45,7 +47,7 @@ def setup_test_cases(self, dag_maker, session: Session):
4547
EmptyOperator(task_id=TASK_ID)
4648
dag_maker.create_dagrun(run_id=RUN_ID)
4749

48-
session.query(EdgeLogsModel).delete()
50+
session.execute(delete(EdgeLogsModel))
4951
session.commit()
5052

5153
def test_logfile_path(self, session: Session):
@@ -68,7 +70,7 @@ def test_push_logs(self, session: Session):
6870
body=log_data,
6971
session=session,
7072
)
71-
logs: list[EdgeLogsModel] = session.query(EdgeLogsModel).all()
73+
logs: Sequence[EdgeLogsModel] = session.scalars(select(EdgeLogsModel)).all()
7274
assert len(logs) == 1
7375
assert logs[0].dag_id == DAG_ID
7476
assert logs[0].task_id == TASK_ID

providers/edge3/tests/unit/edge3/worker_api/routes/test_ui.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from typing import TYPE_CHECKING
2020

2121
import pytest
22+
from sqlalchemy import delete
2223

2324
from airflow.providers.edge3.models.edge_worker import EdgeWorkerModel, EdgeWorkerState
2425

@@ -34,7 +35,7 @@
3435
class TestUiApiRoutes:
3536
@pytest.fixture(autouse=True)
3637
def setup_test_cases(self, session: Session):
37-
session.query(EdgeWorkerModel).delete()
38+
session.execute(delete(EdgeWorkerModel))
3839
session.add(EdgeWorkerModel(worker_name="worker1", queues=["default"], state=EdgeWorkerState.RUNNING))
3940
session.commit()
4041

providers/edge3/tests/unit/edge3/worker_api/routes/test_worker.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@
1616
# under the License.
1717
from __future__ import annotations
1818

19+
from collections.abc import Sequence
1920
from pathlib import Path
2021
from typing import TYPE_CHECKING
2122

2223
import pytest
2324
from fastapi import HTTPException
25+
from sqlalchemy import delete, select
2426

2527
from airflow.providers.common.compat.sdk import timezone
2628
from airflow.providers.edge3.cli.worker import EdgeWorker
@@ -47,7 +49,7 @@ def cli_worker(self, tmp_path: Path) -> EdgeWorker:
4749

4850
@pytest.fixture(autouse=True)
4951
def setup_test_cases(self, session: Session):
50-
session.query(EdgeWorkerModel).delete()
52+
session.execute(delete(EdgeWorkerModel))
5153

5254
def test_assert_version(self):
5355
from airflow import __version__ as airflow_version
@@ -87,7 +89,7 @@ def test_register(self, session: Session, input_queues: list[str] | None, cli_wo
8789
register("test_worker", body, session)
8890
session.commit()
8991

90-
worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
92+
worker: Sequence[EdgeWorkerModel] = session.scalars(select(EdgeWorkerModel)).all()
9193
assert len(worker) == 1
9294
assert worker[0].worker_name == "test_worker"
9395
if input_queues:
@@ -138,7 +140,9 @@ def test_register_duplicate_worker(
138140
# Should succeed for offline/unknown states
139141
register("test_worker", body, session)
140142
session.commit()
141-
worker = session.query(EdgeWorkerModel).filter_by(worker_name="test_worker").first()
143+
worker = session.execute(
144+
select(EdgeWorkerModel).where(EdgeWorkerModel.worker_name == "test_worker")
145+
).scalar_one_or_none()
142146
assert worker is not None
143147
# State should be updated (or redefined based on redefine_state logic)
144148
assert worker.state is not None
@@ -237,7 +241,7 @@ def test_set_state(self, session: Session, cli_worker: EdgeWorker):
237241
)
238242
return_queues = set_state("test2_worker", body, session).queues
239243

240-
worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
244+
worker: Sequence[EdgeWorkerModel] = session.scalars(select(EdgeWorkerModel)).all()
241245
assert len(worker) == 1
242246
assert worker[0].worker_name == "test2_worker"
243247
assert worker[0].state == EdgeWorkerState.RUNNING
@@ -271,7 +275,7 @@ def test_update_queues(
271275
session.commit()
272276
body = WorkerQueueUpdateBody(new_queues=add_queues, remove_queues=remove_queues)
273277
update_queues("test2_worker", body, session)
274-
worker: list[EdgeWorkerModel] = session.query(EdgeWorkerModel).all()
278+
worker: Sequence[EdgeWorkerModel] = session.scalars(select(EdgeWorkerModel)).all()
275279
assert len(worker) == 1
276280
assert worker[0].worker_name == "test2_worker"
277281
assert len(expected_queues) == len(worker[0].queues or [])

0 commit comments

Comments
 (0)