Skip to content

Commit 0d2e9ce

Browse files
authored
Merge pull request #229 from aiokitchen/backoff-class
Backoff as a class
2 parents 3a78b12 + 7d47f9c commit 0d2e9ce

File tree

8 files changed

+353
-296
lines changed

8 files changed

+353
-296
lines changed

aiomisc/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from . import io, log
22
from ._context_vars import StrictContextVar
33
from .aggregate import aggregate, aggregate_async
4-
from .backoff import asyncbackoff, asyncretry
4+
from .backoff import asyncbackoff, asyncretry, Backoff, BackoffExecution
55
from .circuit_breaker import CircuitBreaker, CircuitBroken, cutout
66
from .context import Context, get_context
77
from .counters import Statistic, get_statistics
@@ -31,6 +31,8 @@
3131

3232

3333
__all__ = (
34+
"Backoff",
35+
"BackoffExecution",
3436
"CURRENT_ENTRYPOINT",
3537
"CircuitBreaker",
3638
"CircuitBroken",

aiomisc/backoff.py

Lines changed: 187 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,12 @@
11
import asyncio
2-
import sys
2+
33
from functools import wraps
44
from typing import (
5-
Any, Callable, Coroutine, Optional, Tuple, Type, TypeVar, Union,
5+
Any, Callable, Coroutine, Optional, Tuple, Type, TypeVar, Union, Generic,
6+
ParamSpec
67
)
78

89
from .counters import Statistic
9-
from .timeout import timeout
10-
11-
12-
if sys.version_info >= (3, 10):
13-
from typing import ParamSpec
14-
else:
15-
from typing_extensions import ParamSpec
1610

1711

1812
Number = Union[int, float]
@@ -32,6 +26,179 @@ class RetryStatistic(BackoffStatistic):
3226
pass
3327

3428

29+
class Backoff:
30+
__slots__ = (
31+
"attempt_timeout",
32+
"deadline",
33+
"pause",
34+
"max_tries",
35+
"giveup",
36+
"exceptions",
37+
"statistic",
38+
)
39+
40+
def __init__(
41+
self,
42+
attempt_timeout: Optional[Number],
43+
deadline: Optional[Number],
44+
pause: Number = 0,
45+
exceptions: Tuple[Type[Exception], ...] = (),
46+
max_tries: Optional[int] = None,
47+
giveup: Optional[Callable[[Exception], bool]] = None,
48+
statistic_name: Optional[str] = None,
49+
statistic_class: Type[BackoffStatistic] = BackoffStatistic
50+
):
51+
if not pause:
52+
pause = 0
53+
elif pause < 0:
54+
raise ValueError("'pause' must be positive")
55+
56+
if attempt_timeout is not None and attempt_timeout < 0:
57+
raise ValueError("'attempt_timeout' must be positive or None")
58+
59+
if deadline is not None and deadline < 0:
60+
raise ValueError("'deadline' must be positive or None")
61+
62+
if max_tries is not None and max_tries < 1:
63+
raise ValueError("'max_retries' must be >= 1 or None")
64+
65+
if giveup is not None and not callable(giveup):
66+
raise ValueError("'giveup' must be a callable or None")
67+
68+
exceptions = tuple(exceptions) or ()
69+
exceptions += asyncio.TimeoutError,
70+
71+
self.attempt_timeout = attempt_timeout
72+
self.deadline = deadline
73+
self.pause = pause
74+
self.max_tries = max_tries
75+
self.giveup = giveup
76+
self.exceptions = exceptions
77+
self.statistic = statistic_class(statistic_name)
78+
79+
def prepare(
80+
self,
81+
func: Callable[P, Coroutine[Any, Any, T]]
82+
) -> "BackoffExecution[P, T]":
83+
return BackoffExecution(
84+
function=func,
85+
statistic=self.statistic,
86+
attempt_timeout=self.attempt_timeout,
87+
deadline=self.deadline,
88+
pause=self.pause,
89+
max_tries=self.max_tries,
90+
giveup=self.giveup,
91+
exceptions=self.exceptions,
92+
)
93+
94+
async def execute(
95+
self,
96+
func: Callable[P, Coroutine[Any, Any, T]],
97+
*args: P.args,
98+
**kwargs: P.kwargs
99+
) -> T:
100+
execution = self.prepare(func)
101+
return await execution(*args, **kwargs)
102+
103+
def __call__(
104+
self,
105+
func: Callable[P, Coroutine[Any, Any, T]]
106+
) -> Callable[P, Coroutine[Any, Any, T]]:
107+
if not asyncio.iscoroutinefunction(func):
108+
raise TypeError("Function must be a coroutine function")
109+
110+
@wraps(func)
111+
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
112+
return await self.execute(func, *args, **kwargs)
113+
114+
return wrapper
115+
116+
117+
class BackoffExecution(Generic[P, T]):
118+
__slots__ = (
119+
"attempt_timeout",
120+
"deadline",
121+
"exceptions",
122+
"function",
123+
"giveup",
124+
"last_exception",
125+
"max_tries",
126+
"pause",
127+
"statistic",
128+
"total_tries",
129+
)
130+
131+
def __init__(
132+
self,
133+
function: Callable[P, Coroutine[Any, Any, T]],
134+
statistic: BackoffStatistic,
135+
attempt_timeout: Optional[Number],
136+
deadline: Optional[Number],
137+
pause: Number = 0,
138+
exceptions: Tuple[Type[Exception], ...] = (),
139+
max_tries: Optional[int] = None,
140+
giveup: Optional[Callable[[Exception], bool]] = None,
141+
):
142+
self.function = function
143+
self.statistic = statistic
144+
self.attempt_timeout = attempt_timeout
145+
self.deadline = deadline
146+
self.pause = pause
147+
self.max_tries = max_tries
148+
self.giveup = giveup
149+
self.exceptions = exceptions
150+
151+
self.last_exception: Optional[Exception] = None
152+
self.total_tries: int = 0
153+
154+
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
155+
return await self.execute(*args, **kwargs)
156+
157+
async def execute(self, *args: P.args, **kwargs: P.kwargs) -> T:
158+
async def run() -> Any:
159+
loop = asyncio.get_running_loop()
160+
161+
while True:
162+
self.statistic.attempts += 1
163+
self.total_tries += 1
164+
delta = -loop.time()
165+
166+
try:
167+
return await asyncio.wait_for(
168+
self.function(*args, **kwargs),
169+
timeout=self.attempt_timeout,
170+
)
171+
except asyncio.CancelledError:
172+
self.statistic.cancels += 1
173+
raise
174+
except self.exceptions as e:
175+
self.statistic.errors += 1
176+
self.last_exception = e
177+
if (
178+
self.max_tries is not None and
179+
self.total_tries >= self.max_tries
180+
):
181+
raise
182+
183+
if self.giveup and self.giveup(e):
184+
raise
185+
await asyncio.sleep(self.pause)
186+
except Exception as e:
187+
self.last_exception = e
188+
raise
189+
finally:
190+
delta += loop.time()
191+
self.statistic.sum_time += delta
192+
self.statistic.done += 1
193+
194+
try:
195+
return await asyncio.wait_for(run(), timeout=self.deadline)
196+
except Exception:
197+
if self.last_exception is not None:
198+
raise self.last_exception
199+
raise
200+
201+
35202
# noinspection SpellCheckingInspection
36203
def asyncbackoff(
37204
attempt_timeout: Optional[Number],
@@ -58,90 +225,23 @@ def asyncbackoff(
58225
execution attempt.
59226
:param deadline: is maximum execution time for all execution attempts.
60227
:param pause: is time gap between execution attempts.
61-
:param exc: retrying when this exceptions was raised.
228+
:param exc: retrying when these exceptions were raised.
62229
:param exceptions: similar as exc but keyword only.
63230
:param max_tries: is maximum count of execution attempts (>= 1).
64231
:param giveup: is a predicate function which can decide by a given
65232
:param statistic_class: statistic class
66233
"""
67234

68-
exceptions = exc + tuple(exceptions)
69-
statistic = statistic_class(statistic_name)
70-
71-
if not pause:
72-
pause = 0
73-
elif pause < 0:
74-
raise ValueError("'pause' must be positive")
75-
76-
if attempt_timeout is not None and attempt_timeout < 0:
77-
raise ValueError("'attempt_timeout' must be positive or None")
78-
79-
if deadline is not None and deadline < 0:
80-
raise ValueError("'deadline' must be positive or None")
81-
82-
if max_tries is not None and max_tries < 1:
83-
raise ValueError("'max_retries' must be >= 1 or None")
84-
85-
if giveup is not None and not callable(giveup):
86-
raise ValueError("'giveup' must be a callable or None")
87-
88-
exceptions = tuple(exceptions) or ()
89-
exceptions += asyncio.TimeoutError,
90-
91-
def decorator(
92-
func: Callable[P, Coroutine[Any, Any, T]],
93-
) -> Callable[P, Coroutine[Any, Any, T]]:
94-
if attempt_timeout is not None:
95-
func = timeout(attempt_timeout)(func)
96-
97-
@wraps(func)
98-
async def wrap(*args: P.args, **kwargs: P.kwargs) -> T:
99-
last_exc = None
100-
tries = 0
101-
102-
async def run() -> Any:
103-
nonlocal last_exc, tries
104-
105-
loop = asyncio.get_running_loop()
106-
107-
while True:
108-
statistic.attempts += 1
109-
tries += 1
110-
delta = -loop.time()
111-
112-
try:
113-
return await asyncio.wait_for(
114-
func(*args, **kwargs),
115-
timeout=attempt_timeout,
116-
)
117-
except asyncio.CancelledError:
118-
statistic.cancels += 1
119-
raise
120-
except exceptions as e:
121-
statistic.errors += 1
122-
last_exc = e
123-
if max_tries is not None and tries >= max_tries:
124-
raise
125-
if giveup and giveup(e):
126-
raise
127-
await asyncio.sleep(pause)
128-
except Exception as e:
129-
last_exc = e
130-
raise
131-
finally:
132-
delta += loop.time()
133-
statistic.sum_time += delta
134-
statistic.done += 1
135-
136-
try:
137-
return await asyncio.wait_for(run(), timeout=deadline)
138-
except Exception:
139-
if last_exc:
140-
raise last_exc
141-
raise
142-
143-
return wrap
144-
return decorator
235+
return Backoff(
236+
attempt_timeout=attempt_timeout,
237+
deadline=deadline,
238+
pause=pause,
239+
exceptions=exceptions or exc,
240+
max_tries=max_tries,
241+
giveup=giveup,
242+
statistic_name=statistic_name,
243+
statistic_class=statistic_class,
244+
)
145245

146246

147247
def asyncretry(

aiomisc/compat.py

Lines changed: 6 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22
import logging
33
import os
44
import socket
5-
import sys
6-
from typing import Any, Iterator, Optional
5+
from typing import (
6+
Any, Iterator, Optional, final, TypeAlias, ParamSpec,
7+
Protocol
8+
)
79

810
from ._context_vars import EVENT_LOOP
911

10-
1112
log = logging.getLogger(__name__)
1213

1314
try:
@@ -18,29 +19,6 @@
1819
def time_ns() -> int:
1920
return int(time() * 1000000000)
2021

21-
try:
22-
from typing import final
23-
except ImportError:
24-
from typing_extensions import final # type: ignore
25-
26-
27-
try:
28-
from typing import TypeAlias
29-
except ImportError:
30-
from typing_extensions import TypeAlias
31-
32-
33-
if sys.version_info >= (3, 10):
34-
from typing import ParamSpec
35-
else:
36-
from typing_extensions import ParamSpec
37-
38-
39-
if sys.version_info >= (3, 8):
40-
from typing import Protocol
41-
else:
42-
from typing_extensions import Protocol
43-
4422

4523
class EntrypointProtocol(Protocol):
4624
@property
@@ -77,12 +55,13 @@ class EventLoopMixin:
7755
def loop(self) -> asyncio.AbstractEventLoop:
7856
if not getattr(self, "_loop", None):
7957
self._loop = asyncio.get_running_loop()
80-
return self._loop # type: ignore
58+
return self._loop # type: ignore
8159

8260

8361
event_loop_policy: asyncio.AbstractEventLoopPolicy
8462
try:
8563
import uvloop
64+
8665
if (
8766
os.getenv("AIOMISC_USE_UVLOOP", "1").lower() in
8867
("yes", "1", "enabled", "enable", "on", "true")
@@ -93,7 +72,6 @@ def loop(self) -> asyncio.AbstractEventLoop:
9372
except ImportError:
9473
event_loop_policy = asyncio.DefaultEventLoopPolicy()
9574

96-
9775
if hasattr(socket, "TCP_NODELAY"):
9876
def sock_set_nodelay(sock: socket.socket) -> None:
9977
if sock.proto != socket.IPPROTO_TCP:

0 commit comments

Comments
 (0)