diff --git a/databricks/sdk/retries.py b/databricks/sdk/retries.py index 5528a8978..a6cf5d8dc 100644 --- a/databricks/sdk/retries.py +++ b/databricks/sdk/retries.py @@ -1,13 +1,15 @@ import functools import logging from datetime import timedelta -from random import random -from typing import Callable, Optional, Sequence, Type +from random import random, uniform +from typing import Callable, Optional, Sequence, Tuple, Type, TypeVar from .clock import Clock, RealClock logger = logging.getLogger(__name__) +T = TypeVar("T") + def retried( *, @@ -67,3 +69,101 @@ def wrapper(*args, **kwargs): return wrapper return decorator + + +class RetryError(Exception): + """Error that can be returned from poll functions to control retry behavior.""" + + def __init__(self, err: Exception, halt: bool = False): + self.err = err + self.halt = halt + super().__init__(str(err)) + + @staticmethod + def continues(msg: str) -> "RetryError": + """Create a non-halting retry error with a message.""" + return RetryError(Exception(msg), halt=False) + + @staticmethod + def halt(err: Exception) -> "RetryError": + """Create a halting retry error.""" + return RetryError(err, halt=True) + + +def _backoff(attempt: int) -> float: + """Calculate backoff time with jitter. + + Linear backoff: attempt * 1 second, capped at 10 seconds + Plus random jitter between 50ms and 750ms. + """ + wait = min(10, attempt) + jitter = uniform(0.05, 0.75) + return wait + jitter + + +def poll( + fn: Callable[[], Tuple[Optional[T], Optional[RetryError]]], + timeout: timedelta = timedelta(minutes=20), + clock: Optional[Clock] = None, +) -> T: + """Poll a function until it succeeds or times out. + + The backoff is linear backoff and jitter. + + This function is not meant to be used directly by users. + It is used internally by the SDK to poll for the result of an operation. + It can be changed in the future without any notice. + + :param fn: Function that returns (result, error). + Return (None, RetryError.continues("msg")) to continue polling. + Return (None, RetryError.halt(err)) to stop with error. + Return (result, None) on success. + :param timeout: Maximum time to poll (default: 20 minutes) + :param clock: Clock implementation for testing (default: RealClock) + :returns: The result of the successful function call + :raises TimeoutError: If the timeout is reached + :raises Exception: If a halting error is encountered + + Example: + def check_operation(): + op = get_operation() + if not op.done: + return None, RetryError.continues("operation still in progress") + if op.error: + return None, RetryError.halt(Exception(f"operation failed: {op.error}")) + return op.result, None + + result = poll(check_operation, timeout=timedelta(minutes=5)) + """ + if clock is None: + clock = RealClock() + + deadline = clock.time() + timeout.total_seconds() + attempt = 0 + last_err = None + + while clock.time() < deadline: + attempt += 1 + + try: + result, err = fn() + + if err is None: + return result + + if err.halt: + raise err.err + + # Continue polling. + last_err = err.err + wait = _backoff(attempt) + logger.debug(f"{str(err.err).rstrip('.')}. Sleeping {wait:.3f}s") + clock.sleep(wait) + + except RetryError: + raise + except Exception as e: + # Unexpected error, halt immediately. + raise e + + raise TimeoutError(f"Timed out after {timeout}") from last_err diff --git a/tests/test_retries.py b/tests/test_retries.py index 2ad6e4ef6..3fc97114d 100644 --- a/tests/test_retries.py +++ b/tests/test_retries.py @@ -1,9 +1,10 @@ from datetime import timedelta +from typing import Any, Literal, Optional, Tuple, Type import pytest from databricks.sdk.errors import NotFound, ResourceDoesNotExist -from databricks.sdk.retries import retried +from databricks.sdk.retries import RetryError, poll, retried from tests.clock import FakeClock @@ -73,3 +74,222 @@ def foo(): raise KeyError(1) foo() + + +@pytest.mark.parametrize( + "scenario,attempts,result_value,exception_type,exception_msg,timeout,min_time,max_time", + [ + pytest.param( + "success", + 1, + "immediate", + None, + None, + 60, + 0.0, + 0.0, + id="returns string immediately on first attempt with no sleep", + ), + pytest.param("success", 2, 42, None, None, 60, 1.05, 1.75, id="returns integer after 1 retry with ~1s backoff"), + pytest.param( + "success", + 3, + {"key": "val"}, + None, + None, + 60, + 3.10, + 4.50, + id="returns dict after 2 retries with linear backoff (1s+2s)", + ), + pytest.param( + "success", + 5, + [1, 2], + None, + None, + 60, + 10.2, + 13.0, + id="returns list after 4 retries with linear backoff (1s+2s+3s+4s)", + ), + pytest.param( + "success", + 1, + None, + None, + None, + 60, + 0.0, + 0.0, + id="returns None as valid result immediately (None is acceptable)", + ), + pytest.param( + "success", 5, "ok", None, None, 200, 10.2, 13.0, id="verifies linear backoff increase over 4 retries" + ), + pytest.param( + "success", + 11, + "ok", + None, + None, + 200, + 55.5, + 62.5, + id="verifies linear backoff approaching 10s cap over 10 retries", + ), + pytest.param( + "success", 15, "ok", None, None, 200, 95.7, 105.5, id="verifies backoff is capped at 10s after 10th retry" + ), + pytest.param( + "timeout", + None, + None, + TimeoutError, + "Timed out after", + 1, + 1, + None, + id="raises TimeoutError after 1 second of continuous retries", + ), + pytest.param( + "timeout", + None, + None, + TimeoutError, + "Timed out after", + 5, + 5, + None, + id="raises TimeoutError after 5 seconds of continuous retries", + ), + pytest.param( + "timeout", + None, + None, + TimeoutError, + "Timed out after", + 15, + 15, + None, + id="raises TimeoutError after 15 seconds of continuous retries", + ), + pytest.param( + "halt", + 1, + None, + ValueError, + "halt error", + 60, + None, + None, + id="raises ValueError immediately when halt error on first attempt", + ), + pytest.param( + "halt", + 2, + None, + ValueError, + "halt error", + 60, + None, + None, + id="raises ValueError after 1 retry when halt error on second attempt", + ), + pytest.param( + "halt", + 3, + None, + ValueError, + "halt error", + 60, + None, + None, + id="raises ValueError after 2 retries when halt error on third attempt", + ), + pytest.param( + "unexpected", + 1, + None, + RuntimeError, + "unexpected", + 60, + None, + None, + id="raises RuntimeError immediately on unexpected exception", + ), + pytest.param( + "unexpected", + 3, + None, + RuntimeError, + "unexpected", + 60, + None, + None, + id="raises RuntimeError after 2 retries on unexpected exception", + ), + ], +) +def test_poll_behavior( + scenario: Literal["success", "timeout", "halt", "unexpected"], + attempts: Optional[int], + result_value: Any, + exception_type: Optional[Type[Exception]], + exception_msg: Optional[str], + timeout: int, + min_time: Optional[float], + max_time: Optional[float], +) -> None: + """ + Comprehensive test for poll function covering all scenarios: + - Success cases with various return types and retry counts + - Backoff timing behavior (linear increase, 10s cap) + - Timeout behavior + - Halting errors + - Unexpected exceptions + """ + clock: FakeClock = FakeClock() + call_count: int = 0 + + def fn() -> Tuple[Any, Optional[RetryError]]: + nonlocal call_count + call_count += 1 + + if scenario == "success": + if call_count < attempts: + return None, RetryError.continues(f"attempt {call_count}") + return result_value, None + + elif scenario == "timeout": + return None, RetryError.continues("retrying") + + elif scenario == "halt": + if call_count < attempts: + return None, RetryError.continues("retrying") + return None, RetryError.halt(ValueError(exception_msg)) + + elif scenario == "unexpected": + if call_count < attempts: + return None, RetryError.continues("retrying") + raise RuntimeError(exception_msg) + + if scenario == "success": + result: Any = poll(fn, timeout=timedelta(seconds=timeout), clock=clock) + assert result == result_value + assert call_count == attempts + if min_time is not None: + assert clock.time() >= min_time + if max_time is not None: + assert clock.time() <= max_time + else: + with pytest.raises(exception_type) as exc_info: + poll(fn, timeout=timedelta(seconds=timeout), clock=clock) + + assert exception_msg in str(exc_info.value) + assert call_count >= 1 + + if scenario == "timeout": + assert clock.time() >= min_time + elif scenario in ("halt", "unexpected"): + assert call_count == attempts