Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 102 additions & 2 deletions databricks/sdk/retries.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import functools
import logging
from datetime import timedelta
from random import random
from typing import Callable, Optional, Sequence, Type
from random import random, uniform
from typing import Callable, Optional, Sequence, Tuple, Type, TypeVar

from .clock import Clock, RealClock

logger = logging.getLogger(__name__)

T = TypeVar("T")


def retried(
*,
Expand Down Expand Up @@ -67,3 +69,101 @@ def wrapper(*args, **kwargs):
return wrapper

return decorator


class RetryError(Exception):
"""Error that can be returned from poll functions to control retry behavior."""

def __init__(self, err: Exception, halt: bool = False):
self.err = err
self.halt = halt
super().__init__(str(err))

@staticmethod
def continues(msg: str) -> "RetryError":
"""Create a non-halting retry error with a message."""
return RetryError(Exception(msg), halt=False)

@staticmethod
def halt(err: Exception) -> "RetryError":
"""Create a halting retry error."""
return RetryError(err, halt=True)


def _backoff(attempt: int) -> float:
"""Calculate backoff time with jitter.

Linear backoff: attempt * 1 second, capped at 10 seconds
Plus random jitter between 50ms and 750ms.
"""
wait = min(10, attempt)
jitter = uniform(0.05, 0.75)
return wait + jitter


def poll(
fn: Callable[[], Tuple[Optional[T], Optional[RetryError]]],
timeout: timedelta = timedelta(minutes=20),
clock: Optional[Clock] = None,
) -> T:
"""Poll a function until it succeeds or times out.

The backoff is linear backoff and jitter.

This function is not meant to be used directly by users.
It is used internally by the SDK to poll for the result of an operation.
It can be changed in the future without any notice.

:param fn: Function that returns (result, error).
Return (None, RetryError.continues("msg")) to continue polling.
Return (None, RetryError.halt(err)) to stop with error.
Return (result, None) on success.
:param timeout: Maximum time to poll (default: 20 minutes)
:param clock: Clock implementation for testing (default: RealClock)
:returns: The result of the successful function call
:raises TimeoutError: If the timeout is reached
:raises Exception: If a halting error is encountered

Example:
def check_operation():
op = get_operation()
if not op.done:
return None, RetryError.continues("operation still in progress")
if op.error:
return None, RetryError.halt(Exception(f"operation failed: {op.error}"))
return op.result, None

result = poll(check_operation, timeout=timedelta(minutes=5))
"""
if clock is None:
clock = RealClock()

deadline = clock.time() + timeout.total_seconds()
attempt = 0
last_err = None

while clock.time() < deadline:
attempt += 1

try:
result, err = fn()

if err is None:
return result

if err.halt:
raise err.err

# Continue polling.
last_err = err.err
wait = _backoff(attempt)
logger.debug(f"{str(err.err).rstrip('.')}. Sleeping {wait:.3f}s")
clock.sleep(wait)

except RetryError:
raise
except Exception as e:
# Unexpected error, halt immediately.
raise e

raise TimeoutError(f"Timed out after {timeout}") from last_err
222 changes: 221 additions & 1 deletion tests/test_retries.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from datetime import timedelta
from typing import Any, Literal, Optional, Tuple, Type

import pytest

from databricks.sdk.errors import NotFound, ResourceDoesNotExist
from databricks.sdk.retries import retried
from databricks.sdk.retries import RetryError, poll, retried
from tests.clock import FakeClock


Expand Down Expand Up @@ -73,3 +74,222 @@ def foo():
raise KeyError(1)

foo()


@pytest.mark.parametrize(
"scenario,attempts,result_value,exception_type,exception_msg,timeout,min_time,max_time",
[
pytest.param(
"success",
1,
"immediate",
None,
None,
60,
0.0,
0.0,
id="returns string immediately on first attempt with no sleep",
),
pytest.param("success", 2, 42, None, None, 60, 1.05, 1.75, id="returns integer after 1 retry with ~1s backoff"),
pytest.param(
"success",
3,
{"key": "val"},
None,
None,
60,
3.10,
4.50,
id="returns dict after 2 retries with linear backoff (1s+2s)",
),
pytest.param(
"success",
5,
[1, 2],
None,
None,
60,
10.2,
13.0,
id="returns list after 4 retries with linear backoff (1s+2s+3s+4s)",
),
pytest.param(
"success",
1,
None,
None,
None,
60,
0.0,
0.0,
id="returns None as valid result immediately (None is acceptable)",
),
pytest.param(
"success", 5, "ok", None, None, 200, 10.2, 13.0, id="verifies linear backoff increase over 4 retries"
),
pytest.param(
"success",
11,
"ok",
None,
None,
200,
55.5,
62.5,
id="verifies linear backoff approaching 10s cap over 10 retries",
),
pytest.param(
"success", 15, "ok", None, None, 200, 95.7, 105.5, id="verifies backoff is capped at 10s after 10th retry"
),
pytest.param(
"timeout",
None,
None,
TimeoutError,
"Timed out after",
1,
1,
None,
id="raises TimeoutError after 1 second of continuous retries",
),
pytest.param(
"timeout",
None,
None,
TimeoutError,
"Timed out after",
5,
5,
None,
id="raises TimeoutError after 5 seconds of continuous retries",
),
pytest.param(
"timeout",
None,
None,
TimeoutError,
"Timed out after",
15,
15,
None,
id="raises TimeoutError after 15 seconds of continuous retries",
),
pytest.param(
"halt",
1,
None,
ValueError,
"halt error",
60,
None,
None,
id="raises ValueError immediately when halt error on first attempt",
),
pytest.param(
"halt",
2,
None,
ValueError,
"halt error",
60,
None,
None,
id="raises ValueError after 1 retry when halt error on second attempt",
),
pytest.param(
"halt",
3,
None,
ValueError,
"halt error",
60,
None,
None,
id="raises ValueError after 2 retries when halt error on third attempt",
),
pytest.param(
"unexpected",
1,
None,
RuntimeError,
"unexpected",
60,
None,
None,
id="raises RuntimeError immediately on unexpected exception",
),
pytest.param(
"unexpected",
3,
None,
RuntimeError,
"unexpected",
60,
None,
None,
id="raises RuntimeError after 2 retries on unexpected exception",
),
],
)
def test_poll_behavior(
scenario: Literal["success", "timeout", "halt", "unexpected"],
attempts: Optional[int],
result_value: Any,
exception_type: Optional[Type[Exception]],
exception_msg: Optional[str],
timeout: int,
min_time: Optional[float],
max_time: Optional[float],
) -> None:
"""
Comprehensive test for poll function covering all scenarios:
- Success cases with various return types and retry counts
- Backoff timing behavior (linear increase, 10s cap)
- Timeout behavior
- Halting errors
- Unexpected exceptions
"""
clock: FakeClock = FakeClock()
call_count: int = 0

def fn() -> Tuple[Any, Optional[RetryError]]:
nonlocal call_count
call_count += 1

if scenario == "success":
if call_count < attempts:
return None, RetryError.continues(f"attempt {call_count}")
return result_value, None

elif scenario == "timeout":
return None, RetryError.continues("retrying")

elif scenario == "halt":
if call_count < attempts:
return None, RetryError.continues("retrying")
return None, RetryError.halt(ValueError(exception_msg))

elif scenario == "unexpected":
if call_count < attempts:
return None, RetryError.continues("retrying")
raise RuntimeError(exception_msg)

if scenario == "success":
result: Any = poll(fn, timeout=timedelta(seconds=timeout), clock=clock)
assert result == result_value
assert call_count == attempts
if min_time is not None:
assert clock.time() >= min_time
if max_time is not None:
assert clock.time() <= max_time
else:
with pytest.raises(exception_type) as exc_info:
poll(fn, timeout=timedelta(seconds=timeout), clock=clock)

assert exception_msg in str(exc_info.value)
assert call_count >= 1

if scenario == "timeout":
assert clock.time() >= min_time
elif scenario in ("halt", "unexpected"):
assert call_count == attempts
Loading