Skip to content

Commit 7c430fe

Browse files
committed
Add poll function for LRO which follows linear backoff with jitter.
1 parent 88f1047 commit 7c430fe

File tree

2 files changed

+204
-3
lines changed

2 files changed

+204
-3
lines changed

databricks/sdk/retries.py

Lines changed: 101 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import functools
22
import logging
33
from datetime import timedelta
4-
from random import random
5-
from typing import Callable, Optional, Sequence, Type
4+
from random import random, uniform
5+
from typing import Callable, Optional, Sequence, Type, TypeVar
66

77
from .clock import Clock, RealClock
88

99
logger = logging.getLogger(__name__)
1010

11+
T = TypeVar("T")
12+
1113

1214
def retried(
1315
*,
@@ -67,3 +69,100 @@ def wrapper(*args, **kwargs):
6769
return wrapper
6870

6971
return decorator
72+
73+
74+
class RetryError(Exception):
    """Signal a poll callback returns to tell the polling loop what to do next.

    Each instance wraps an underlying exception in ``err`` and carries a
    boolean ``halt`` flag: ``False`` asks the caller to keep polling,
    ``True`` asks it to stop. The exception message is ``str(err)``.
    """

    def __init__(self, err: Exception, halt: bool = False):
        super().__init__(str(err))
        # NOTE: on instances this boolean shadows the ``halt`` static factory
        # defined below; the factory stays reachable through the class
        # (``RetryError.halt(...)``).
        self.halt = halt
        self.err = err

    @staticmethod
    def continues(msg: str) -> "RetryError":
        """Build a non-halting error whose wrapped exception carries ``msg``."""
        wrapped = Exception(msg)
        return RetryError(wrapped, halt=False)

    @staticmethod
    def halt(err: Exception) -> "RetryError":
        """Build a halting error that wraps ``err``."""
        return RetryError(err, True)
91+
92+
93+
def _backoff(attempt: int) -> float:
    """Return how many seconds to wait after the given attempt.

    The base delay grows linearly, one second per attempt, and never exceeds
    10 seconds. A random jitter drawn uniformly from 50ms to 750ms is added
    on top of the base delay.
    """
    linear_part = attempt if attempt < 10 else 10
    return linear_part + uniform(0.05, 0.75)
102+
103+
104+
# This function is not meant to be used directly by users.
105+
# It is used internally by the SDK to poll for the result of an operation.
106+
# It can be changed in the future without any notice.
107+
def poll(
    fn: Callable[[], tuple[Optional[T], Optional[RetryError]]],
    timeout: timedelta = timedelta(minutes=20),
    clock: Optional[Clock] = None,
) -> T:
    """Poll a function until it succeeds, halts with an error, or times out.

    Between attempts the loop sleeps for a linear backoff with jitter
    (see ``_backoff``).

    The deadline is only checked before each attempt. As a consequence:

    * the sleep after the last attempt may run past the deadline, so
      ``TimeoutError`` can be raised somewhat later than ``timeout``;
    * if the deadline has already passed on entry (for example a
      non-positive ``timeout``), ``fn`` is never called.

    :param fn: Function that returns (result, error).
        Return (None, RetryError.continues("msg")) to continue polling.
        Return (None, RetryError.halt(err)) to stop with error.
        Return (result, None) on success.
    :param timeout: Maximum time to poll (default: 20 minutes)
    :param clock: Clock implementation for testing (default: RealClock)
    :returns: The result of the successful function call
    :raises TimeoutError: If the timeout is reached; chained (``from``) to the
        error wrapped by the most recent non-halting ``RetryError``
    :raises Exception: If a halting error is encountered (the wrapped error is
        raised), or if ``fn`` itself raises (the exception propagates as-is)

    Example:
        def check_operation():
            op = get_operation()
            if not op.done:
                return None, RetryError.continues("operation still in progress")
            if op.error:
                return None, RetryError.halt(Exception(f"operation failed: {op.error}"))
            return op.result, None

        result = poll(check_operation, timeout=timedelta(minutes=5))
    """
    if clock is None:
        clock = RealClock()

    deadline = clock.time() + timeout.total_seconds()
    attempt = 0
    last_err = None

    while clock.time() < deadline:
        attempt += 1

        # Deliberately no try/except: anything fn raises is unexpected and
        # must reach the caller untouched, which is what plain propagation does.
        result, err = fn()

        if err is None:
            return result

        if err.halt:
            # Surface the wrapped error, not the RetryError envelope.
            raise err.err

        # Continue polling; remember the reason so a timeout can be chained to it.
        last_err = err.err
        wait = _backoff(attempt)
        logger.debug(f"{str(err.err).rstrip('.')}. Sleeping {wait:.3f}s")
        clock.sleep(wait)

    raise TimeoutError(f"Timed out after {timeout}") from last_err

tests/test_retries.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
from datetime import timedelta
2+
from typing import Any, Literal, Optional, Type
23

34
import pytest
45

56
from databricks.sdk.errors import NotFound, ResourceDoesNotExist
6-
from databricks.sdk.retries import retried
7+
from databricks.sdk.retries import poll, retried, RetryError
78
from tests.clock import FakeClock
89

910

@@ -73,3 +74,104 @@ def foo():
7374
raise KeyError(1)
7475

7576
foo()
77+
78+
79+
@pytest.mark.parametrize(
    "scenario,attempts,result_value,exception_type,exception_msg,timeout,min_time,max_time",
    [
        pytest.param("success", 1, "immediate", None, None, 60, 0.0, 0.0,
                     id="returns string immediately on first attempt with no sleep"),
        pytest.param("success", 2, 42, None, None, 60, 1.05, 1.75,
                     id="returns integer after 1 retry with ~1s backoff"),
        # Two sleeps: (1 + 2) seconds plus 2 x [0.05, 0.75] of jitter.
        pytest.param("success", 3, {"key": "val"}, None, None, 60, 3.10, 4.50,
                     id="returns dict after 2 retries with linear backoff (1s+2s)"),
        # Four sleeps: (1 + 2 + 3 + 4) seconds plus 4 x [0.05, 0.75] of jitter.
        pytest.param("success", 5, [1, 2], None, None, 60, 10.2, 13.0,
                     id="returns list after 4 retries with linear backoff (1s+2s+3s+4s)"),
        pytest.param("success", 1, None, None, None, 60, 0.0, 0.0,
                     id="returns None as valid result immediately (None is acceptable)"),
        pytest.param("success", 5, "ok", None, None, 200, 10.2, 13.0,
                     id="verifies linear backoff increase over 4 retries"),
        pytest.param("success", 11, "ok", None, None, 200, 55.5, 62.5,
                     id="verifies linear backoff approaching 10s cap over 10 retries"),
        pytest.param("success", 15, "ok", None, None, 200, 95.7, 105.5,
                     id="verifies backoff is capped at 10s after 10th retry"),
        pytest.param("timeout", None, None, TimeoutError, "Timed out after", 1, 1, None,
                     id="raises TimeoutError after 1 second of continuous retries"),
        pytest.param("timeout", None, None, TimeoutError, "Timed out after", 5, 5, None,
                     id="raises TimeoutError after 5 seconds of continuous retries"),
        pytest.param("timeout", None, None, TimeoutError, "Timed out after", 15, 15, None,
                     id="raises TimeoutError after 15 seconds of continuous retries"),
        pytest.param("halt", 1, None, ValueError, "halt error", 60, None, None,
                     id="raises ValueError immediately when halt error on first attempt"),
        pytest.param("halt", 2, None, ValueError, "halt error", 60, None, None,
                     id="raises ValueError after 1 retry when halt error on second attempt"),
        pytest.param("halt", 3, None, ValueError, "halt error", 60, None, None,
                     id="raises ValueError after 2 retries when halt error on third attempt"),
        pytest.param("unexpected", 1, None, RuntimeError, "unexpected", 60, None, None,
                     id="raises RuntimeError immediately on unexpected exception"),
        pytest.param("unexpected", 3, None, RuntimeError, "unexpected", 60, None, None,
                     id="raises RuntimeError after 2 retries on unexpected exception"),
    ],
)
def test_poll_behavior(
    scenario: Literal["success", "timeout", "halt", "unexpected"],
    attempts: Optional[int],
    result_value: Any,
    exception_type: Optional[Type[Exception]],
    exception_msg: Optional[str],
    timeout: int,
    min_time: Optional[float],
    max_time: Optional[float],
) -> None:
    """
    Comprehensive test for poll function covering all scenarios:
    - Success cases with various return types and retry counts
    - Backoff timing behavior (linear increase, 10s cap)
    - Timeout behavior
    - Halting errors
    - Unexpected exceptions

    Time bounds for the success cases follow from the backoff formula: a
    success on attempt N is preceded by N - 1 sleeps, the k-th of which lasts
    min(10, k) seconds plus a jitter in [0.05, 0.75]. The expected window is
    therefore sum(min(10, k)) + (N - 1) * 0.05 up to
    sum(min(10, k)) + (N - 1) * 0.75; anything narrower makes the test flaky.
    """
    clock: FakeClock = FakeClock()
    call_count: int = 0

    def fn() -> tuple[Any, Optional[RetryError]]:
        # Scripted callback: asks to continue until the configured attempt,
        # then produces the scenario's final outcome.
        nonlocal call_count
        call_count += 1

        if scenario == "success":
            if call_count < attempts:
                return None, RetryError.continues(f"attempt {call_count}")
            return result_value, None

        elif scenario == "timeout":
            return None, RetryError.continues("retrying")

        elif scenario == "halt":
            if call_count < attempts:
                return None, RetryError.continues("retrying")
            return None, RetryError.halt(ValueError(exception_msg))

        elif scenario == "unexpected":
            if call_count < attempts:
                return None, RetryError.continues("retrying")
            raise RuntimeError(exception_msg)

    if scenario == "success":
        result: Any = poll(fn, timeout=timedelta(seconds=timeout), clock=clock)
        assert result == result_value
        assert call_count == attempts
        if min_time is not None:
            assert clock.time() >= min_time
        if max_time is not None:
            assert clock.time() <= max_time
    else:
        with pytest.raises(exception_type) as exc_info:
            poll(fn, timeout=timedelta(seconds=timeout), clock=clock)

        assert exception_msg in str(exc_info.value)
        assert call_count >= 1

        if scenario == "timeout":
            assert clock.time() >= min_time - 1
        elif scenario in ("halt", "unexpected"):
            assert call_count == attempts

0 commit comments

Comments
 (0)