Skip to content

Commit 1b643a4

Browse files
insomnes and dstandish
authored
Retry k8s API requests in KubernetesPodTrigger (apache#47187)
--------- Co-authored-by: Daniel Standish <[email protected]>
1 parent 0675231 commit 1b643a4

File tree

4 files changed

+74
-8
lines changed

4 files changed

+74
-8
lines changed

providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/pod.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -834,6 +834,10 @@ def trigger_reentry(self, context: Context, event: dict[str, Any]) -> Any:
834834
last_log_time = event.get("last_log_time")
835835

836836
if event["status"] in ("error", "failed", "timeout"):
837+
event_message = event.get("message", "No message provided")
838+
self.log.error(
839+
"Trigger emitted an %s event, failing the task: %s", event["status"], event_message
840+
)
837841
# fetch some logs when pod is failed
838842
if self.get_logs:
839843
self._write_logs(self.pod, follow=follow, since_time=last_log_time)

providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py

Lines changed: 14 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,10 @@
2222
from collections.abc import AsyncIterator
2323
from enum import Enum
2424
from functools import cached_property
25-
from typing import TYPE_CHECKING, Any
25+
from typing import TYPE_CHECKING, Any, cast
26+
27+
import tenacity
28+
from kubernetes_asyncio.client.models import V1Pod
2629

2730
from airflow.providers.cncf.kubernetes.hooks.kubernetes import AsyncKubernetesHook
2831
from airflow.providers.cncf.kubernetes.utils.pod_manager import (
@@ -33,7 +36,6 @@
3336
from airflow.triggers.base import BaseTrigger, TriggerEvent
3437

3538
if TYPE_CHECKING:
36-
from kubernetes_asyncio.client.models import V1Pod
3739
from pendulum import DateTime
3840

3941

@@ -200,7 +202,7 @@ def _format_exception_description(self, exc: Exception) -> Any:
200202
async def _wait_for_pod_start(self) -> ContainerState:
201203
"""Loops until pod phase leaves ``PENDING`` If timeout is reached, throws error."""
202204
while True:
203-
pod = await self.hook.get_pod(self.pod_name, self.pod_namespace)
205+
pod = await self._get_pod()
204206
if not pod.status.phase == "Pending":
205207
return self.define_container_state(pod)
206208

@@ -223,7 +225,7 @@ async def _wait_for_container_completion(self) -> TriggerEvent:
223225
if self.logging_interval is not None:
224226
time_get_more_logs = time_begin + datetime.timedelta(seconds=self.logging_interval)
225227
while True:
226-
pod = await self.hook.get_pod(self.pod_name, self.pod_namespace)
228+
pod = await self._get_pod()
227229
container_state = self.define_container_state(pod)
228230
if container_state == ContainerState.TERMINATED:
229231
return TriggerEvent(
@@ -257,6 +259,14 @@ async def _wait_for_container_completion(self) -> TriggerEvent:
257259
self.log.debug("Sleeping for %s seconds.", self.poll_interval)
258260
await asyncio.sleep(self.poll_interval)
259261

262+
@tenacity.retry(stop=tenacity.stop_after_attempt(3), wait=tenacity.wait_exponential(), reraise=True)
263+
async def _get_pod(self) -> V1Pod:
264+
"""Get the pod from Kubernetes with retries."""
265+
pod = await self.hook.get_pod(name=self.pod_name, namespace=self.pod_namespace)
266+
# Due to AsyncKubernetesHook overriding get_pod, we need to cast the return
267+
# value to kubernetes_asyncio.V1Pod, because it's perceived as different type
268+
return cast(V1Pod, pod)
269+
260270
def _get_async_hook(self) -> AsyncKubernetesHook:
261271
# TODO: Remove this method when the min version of kubernetes provider is 7.12.0 in Google provider.
262272
return AsyncKubernetesHook(

providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_pod.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -2091,10 +2091,12 @@ def test_async_create_pod_should_execute_successfully(
20912091
ti_mock.xcom_push.assert_any_call(key="pod_namespace", value=TEST_NAMESPACE)
20922092
assert isinstance(exc.value.trigger, KubernetesPodTrigger)
20932093

2094+
@pytest.mark.parametrize("status", ["error", "failed", "timeout"])
2095+
@patch(KUB_OP_PATH.format("log"))
20942096
@patch(KUB_OP_PATH.format("cleanup"))
20952097
@patch(HOOK_CLASS)
2096-
def test_async_create_pod_should_throw_exception(self, mocked_hook, mocked_cleanup):
2097-
"""Tests that an AirflowException is raised in case of error event"""
2098+
def test_async_create_pod_should_throw_exception(self, mocked_hook, mocked_cleanup, mocked_log, status):
2099+
"""Tests that an AirflowException is raised in case of error event and event is logged"""
20982100

20992101
mocked_hook.return_value.get_pod.return_value = MagicMock()
21002102
k = KubernetesPodOperator(
@@ -2111,17 +2113,21 @@ def test_async_create_pod_should_throw_exception(self, mocked_hook, mocked_clean
21112113
deferrable=True,
21122114
)
21132115

2116+
message = "Some message"
21142117
with pytest.raises(AirflowException):
21152118
k.trigger_reentry(
21162119
context=None,
21172120
event={
2118-
"status": "error",
2119-
"message": "Some error",
2121+
"status": status,
2122+
"message": message,
21202123
"name": TEST_NAME,
21212124
"namespace": TEST_NAMESPACE,
21222125
},
21232126
)
21242127

2128+
log_message = "Trigger emitted an %s event, failing the task: %s"
2129+
mocked_log.error.assert_called_once_with(log_message, status, message)
2130+
21252131
@pytest.mark.parametrize(
21262132
"kwargs, actual_exit_code, expected_exc, pod_status, event_status",
21272133
[

providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py

Lines changed: 46 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818
from __future__ import annotations
1919

2020
import asyncio
21+
import contextlib
2122
import datetime
2223
import logging
2324
from asyncio import Future
@@ -386,3 +387,48 @@ async def test_run_loop_return_success_for_completed_pod_after_timeout(
386387
)
387388
== actual
388389
)
390+
391+
@pytest.mark.asyncio
392+
@mock.patch(f"{TRIGGER_PATH}.hook")
393+
async def test__get_pod(self, mock_hook, trigger):
394+
"""
395+
Test that KubernetesPodTrigger _get_pod is called with the correct arguments.
396+
"""
397+
398+
mock_hook.get_pod.return_value = self._mock_pod_result(mock.MagicMock())
399+
400+
await trigger._get_pod()
401+
mock_hook.get_pod.assert_called_with(name=POD_NAME, namespace=NAMESPACE)
402+
403+
@pytest.mark.asyncio
404+
@pytest.mark.parametrize(
405+
"exc_count, call_count",
406+
[
407+
pytest.param(0, 1, id="no exception"),
408+
pytest.param(2, 3, id="2 exc, 1 success"),
409+
pytest.param(3, 3, id="max retries"),
410+
],
411+
)
412+
@mock.patch(f"{TRIGGER_PATH}.hook")
413+
async def test__get_pod_retries(
414+
self,
415+
mock_hook,
416+
trigger,
417+
exc_count,
418+
call_count,
419+
):
420+
"""
421+
Test that KubernetesPodTrigger _get_pod retries in case of an exception during
422+
the hook.get_pod call.
423+
"""
424+
425+
side_effects = [Exception("Test exception") for _ in range(exc_count)] + [MagicMock()]
426+
427+
mock_hook.get_pod.side_effect = mock.AsyncMock(side_effect=side_effects)
428+
# We expect the exception to be raised only if the number of retries is exceeded
429+
context = (
430+
pytest.raises(Exception, match="Test exception") if exc_count > 2 else contextlib.nullcontext()
431+
)
432+
with context:
433+
await trigger._get_pod()
434+
assert mock_hook.get_pod.call_count == call_count

0 commit comments

Comments
 (0)