|
1 | 1 | import functools |
2 | 2 | import logging |
3 | 3 | from datetime import timedelta |
4 | | -from random import random |
5 | | -from typing import Callable, Optional, Sequence, Type |
| 4 | +from random import random, uniform |
| 5 | +from typing import Callable, Optional, Sequence, Tuple, Type, TypeVar |
6 | 6 |
|
7 | 7 | from .clock import Clock, RealClock |
8 | 8 |
|
9 | 9 | logger = logging.getLogger(__name__) |
10 | 10 |
|
| 11 | +T = TypeVar("T") |
| 12 | + |
11 | 13 |
|
12 | 14 | def retried( |
13 | 15 | *, |
@@ -67,3 +69,101 @@ def wrapper(*args, **kwargs): |
67 | 69 | return wrapper |
68 | 70 |
|
69 | 71 | return decorator |
| 72 | + |
| 73 | + |
class RetryError(Exception):
    """Signal returned by a poll callback to steer the polling loop.

    The wrapped exception is stored in ``err`` and the boolean ``halt`` flag
    tells the caller whether polling must stop (``True``) or may go on
    (``False``). The exception message is ``str(err)``.

    NOTE: the ``halt`` attribute assigned in ``__init__`` shadows the ``halt``
    static factory on instances; the factory stays reachable through the
    class itself (``RetryError.halt(...)``).
    """

    def __init__(self, err: Exception, halt: bool = False):
        super().__init__(str(err))
        self.halt = halt
        self.err = err

    @staticmethod
    def halt(err: Exception) -> "RetryError":
        """Wrap ``err`` in a retry error that stops polling."""
        return RetryError(err, True)

    @staticmethod
    def continues(msg: str) -> "RetryError":
        """Build a retry error carrying ``msg`` that lets polling continue."""
        return RetryError(Exception(msg))
| 91 | + |
| 92 | + |
| 93 | +def _backoff(attempt: int) -> float: |
| 94 | + """Calculate backoff time with jitter. |
| 95 | +
|
| 96 | + Linear backoff: attempt * 1 second, capped at 10 seconds |
| 97 | + Plus random jitter between 50ms and 750ms. |
| 98 | + """ |
| 99 | + wait = min(10, attempt) |
| 100 | + jitter = uniform(0.05, 0.75) |
| 101 | + return wait + jitter |
| 102 | + |
| 103 | + |
def poll(
    fn: Callable[[], Tuple[Optional[T], Optional[RetryError]]],
    timeout: timedelta = timedelta(minutes=20),
    clock: Optional[Clock] = None,
) -> T:
    """Poll a function until it succeeds or times out.

    Between attempts the loop sleeps for a linear backoff plus jitter. The
    sleep never extends past the deadline, so the total polling time does not
    exceed ``timeout`` by a wasted final backoff.

    This function is not meant to be used directly by users.
    It is used internally by the SDK to poll for the result of an operation.
    It can be changed in the future without any notice.

    :param fn: Function that returns (result, error).
               Return (None, RetryError.continues("msg")) to continue polling.
               Return (None, RetryError.halt(err)) to stop with error.
               Return (result, None) on success.
               Any exception raised by ``fn`` itself propagates unchanged.
    :param timeout: Maximum time to poll (default: 20 minutes). If it is zero
                    or negative, ``fn`` is never called.
    :param clock: Clock implementation for testing (default: RealClock)
    :returns: The result of the successful function call
    :raises TimeoutError: If the timeout is reached; chained from the last
                          non-halting error, if any
    :raises Exception: If a halting error is encountered (the wrapped error
                       is raised)

    Example:
        def check_operation():
            op = get_operation()
            if not op.done:
                return None, RetryError.continues("operation still in progress")
            if op.error:
                return None, RetryError.halt(Exception(f"operation failed: {op.error}"))
            return op.result, None

        result = poll(check_operation, timeout=timedelta(minutes=5))
    """
    if clock is None:
        clock = RealClock()

    deadline = clock.time() + timeout.total_seconds()
    attempt = 0
    last_err = None

    while clock.time() < deadline:
        attempt += 1

        # No try/except here: an exception raised by fn is unexpected and
        # must halt polling immediately, which plain propagation already does.
        result, err = fn()

        if err is None:
            return result

        if err.halt:
            raise err.err

        # Continue polling; remember the reason for the timeout message chain.
        last_err = err.err

        remaining = deadline - clock.time()
        if remaining <= 0:
            # Deadline passed while fn was running; sleeping would be wasted.
            break

        # Never sleep past the deadline: the loop would exit right after the
        # sleep without polling again.
        wait = min(_backoff(attempt), remaining)
        logger.debug("%s. Sleeping %.3fs", str(err.err).rstrip("."), wait)
        clock.sleep(wait)

    raise TimeoutError(f"Timed out after {timeout}") from last_err
0 commit comments