Skip to content
This repository was archived by the owner on Apr 3, 2024. It is now read-only.

Commit 73f70e0

Browse files
authored
Merge pull request #67 from github/feature/gh-2.0.2_scheduler
backport of fix for scheduler adoption
2 parents e61b945 + 9636f32 commit 73f70e0

File tree

4 files changed, with 92 additions and 5 deletions.

airflow/executors/kubernetes_executor.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -586,15 +586,15 @@ def _change_state(self, key: TaskInstanceKey, state: Optional[str], pod_id: str,
586586
self.event_buffer[key] = state, None
587587

588588
def try_adopt_task_instances(self, tis: List[TaskInstance]) -> List[TaskInstance]:
589-
tis_to_flush = [ti for ti in tis if not ti.external_executor_id]
590-
scheduler_job_ids = [ti.external_executor_id for ti in tis]
589+
tis_to_flush = [ti for ti in tis if not ti.queued_by_job_id]
590+
scheduler_job_ids = {ti.queued_by_job_id for ti in tis}
591591
pod_ids = {
592592
create_pod_id(
593593
dag_id=pod_generator.make_safe_label_value(ti.dag_id),
594594
task_id=pod_generator.make_safe_label_value(ti.task_id),
595595
): ti
596596
for ti in tis
597-
if ti.external_executor_id
597+
if ti.queued_by_job_id
598598
}
599599
kube_client: client.CoreV1Api = self.kube_client
600600
for scheduler_job_id in scheduler_job_ids:

airflow/jobs/scheduler_job.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131
from contextlib import redirect_stderr, redirect_stdout, suppress
3232
from datetime import timedelta
3333
from multiprocessing.connection import Connection as MultiprocessingConnection
34-
from typing import Any, Callable, DefaultDict, Dict, Iterable, List, Optional, Set, Tuple
34+
from typing import Any, Callable, DefaultDict, Dict, Iterator, Iterable, List, Optional, Set, Tuple
3535

3636
from setproctitle import setproctitle
3737
from sqlalchemy import and_, func, not_, or_, tuple_
@@ -1218,7 +1218,15 @@ def _process_executor_events(self, session: Session = None) -> int:
12181218

12191219
# Check state of finished tasks
12201220
filter_for_tis = TI.filter_for_tis(tis_with_right_state)
1221-
tis: List[TI] = session.query(TI).filter(filter_for_tis).options(selectinload('dag_model')).all()
1221+
query = session.query(TI).filter(filter_for_tis).options(selectinload('dag_model'))
1222+
# row lock this entire set of taskinstances to make sure the scheduler doesn't fail when we have
1223+
# multi-schedulers
1224+
tis: Iterator[TI] = with_row_locks(
1225+
query,
1226+
of=TI,
1227+
session=session,
1228+
**skip_locked(session=session),
1229+
)
12221230
for ti in tis:
12231231
try_number = ti_primary_key_to_try_number_map[ti.key.primary]
12241232
buffer_key = ti.key.with_try_number(try_number)

airflow/models/taskinstance.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -623,7 +623,10 @@ def refresh_from_db(self, session=None, lock_for_update=False) -> None:
623623
self.priority_weight = ti.priority_weight
624624
self.operator = ti.operator
625625
self.queued_dttm = ti.queued_dttm
626+
self.queued_by_job_id = ti.queued_by_job_id
626627
self.pid = ti.pid
628+
self.executor_config = ti.executor_config
629+
self.external_executor_id = ti.external_executor_id
627630
else:
628631
self.state = None
629632

tests/executors/test_kubernetes_executor.py

Lines changed: 76 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -419,6 +419,82 @@ def test_not_adopt_unassigned_task(self, mock_kube_client):
419419
assert not mock_kube_client.patch_namespaced_pod.called
420420
assert pod_ids == {"foobar": {}}
421421

422+
@mock.patch('airflow.executors.kubernetes_executor.KubernetesExecutor.adopt_launched_task')
423+
@mock.patch('airflow.executors.kubernetes_executor.KubernetesExecutor._adopt_completed_pods')
424+
def test_try_adopt_task_instances(self, mock_adopt_completed_pods, mock_adopt_launched_task):
425+
executor = self.kubernetes_executor
426+
executor.scheduler_job_id = "10"
427+
mock_ti = mock.MagicMock(queued_by_job_id="1", external_executor_id="1", dag_id="dag", task_id="task")
428+
pod = k8s.V1Pod(metadata=k8s.V1ObjectMeta(name="foo", labels={"dag_id": "dag", "task_id": "task"}))
429+
pod_id = create_pod_id(dag_id="dag", task_id="task")
430+
mock_kube_client = mock.MagicMock()
431+
mock_kube_client.list_namespaced_pod.return_value.items = [pod]
432+
executor.kube_client = mock_kube_client
433+
434+
# First adoption
435+
executor.try_adopt_task_instances([mock_ti])
436+
mock_kube_client.list_namespaced_pod.assert_called_once_with(
437+
namespace='default', label_selector='airflow-worker=1'
438+
)
439+
mock_adopt_launched_task.assert_called_once_with(mock_kube_client, pod, {pod_id: mock_ti})
440+
mock_adopt_completed_pods.assert_called_once()
441+
# We aren't checking the return value of `try_adopt_task_instances` because it relies on
442+
# `adopt_launched_task` mutating its arg. This should be refactored, but not right now.
443+
444+
# Second adoption (queued_by_job_id and external_executor_id no longer match)
445+
mock_kube_client.reset_mock()
446+
mock_adopt_launched_task.reset_mock()
447+
mock_adopt_completed_pods.reset_mock()
448+
449+
mock_ti.queued_by_job_id = "10" # scheduler_job would have updated this after the first adoption
450+
executor.scheduler_job_id = "20"
451+
452+
executor.try_adopt_task_instances([mock_ti])
453+
mock_kube_client.list_namespaced_pod.assert_called_once_with(
454+
namespace='default', label_selector='airflow-worker=10'
455+
)
456+
mock_adopt_launched_task.assert_called_once_with(mock_kube_client, pod, {pod_id: mock_ti})
457+
mock_adopt_completed_pods.assert_called_once()
458+
459+
@mock.patch('airflow.executors.kubernetes_executor.KubernetesExecutor._adopt_completed_pods')
460+
def test_try_adopt_task_instances_multiple_scheduler_ids(self, mock_adopt_completed_pods):
461+
"""We try to find pods only once per scheduler id"""
462+
executor = self.kubernetes_executor
463+
mock_kube_client = mock.MagicMock()
464+
executor.kube_client = mock_kube_client
465+
466+
mock_tis = [
467+
mock.MagicMock(queued_by_job_id="10", external_executor_id="1", dag_id="dag", task_id="task"),
468+
mock.MagicMock(queued_by_job_id="40", external_executor_id="1", dag_id="dag", task_id="task2"),
469+
mock.MagicMock(queued_by_job_id="40", external_executor_id="1", dag_id="dag", task_id="task3"),
470+
]
471+
472+
executor.try_adopt_task_instances(mock_tis)
473+
assert mock_kube_client.list_namespaced_pod.call_count == 2
474+
mock_kube_client.list_namespaced_pod.assert_has_calls(
475+
[
476+
mock.call(namespace='default', label_selector='airflow-worker=10'),
477+
mock.call(namespace='default', label_selector='airflow-worker=40'),
478+
],
479+
any_order=True,
480+
)
481+
482+
@mock.patch('airflow.executors.kubernetes_executor.KubernetesExecutor.adopt_launched_task')
483+
@mock.patch('airflow.executors.kubernetes_executor.KubernetesExecutor._adopt_completed_pods')
484+
def test_try_adopt_task_instances_no_matching_pods(
485+
self, mock_adopt_completed_pods, mock_adopt_launched_task
486+
):
487+
executor = self.kubernetes_executor
488+
mock_ti = mock.MagicMock(queued_by_job_id="1", external_executor_id="1", dag_id="dag", task_id="task")
489+
mock_kube_client = mock.MagicMock()
490+
mock_kube_client.list_namespaced_pod.return_value.items = []
491+
executor.kube_client = mock_kube_client
492+
493+
tis_to_flush = executor.try_adopt_task_instances([mock_ti])
494+
assert tis_to_flush == [mock_ti]
495+
mock_adopt_launched_task.assert_not_called()
496+
mock_adopt_completed_pods.assert_called_once()
497+
422498

423499
class TestKubernetesJobWatcher(unittest.TestCase):
424500
def setUp(self):

0 commit comments

Comments
 (0)