Skip to content

Commit 5ca3f4a

Browse files
author
Andrei Neagu
committed
add task notification utility
1 parent db3aacf commit 5ca3f4a

File tree

2 files changed

+107
-10
lines changed

2 files changed

+107
-10
lines changed

packages/service-library/src/servicelib/async_utils.py

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import asyncio
22
import logging
3-
from collections import deque
3+
from collections.abc import Awaitable, Callable
44
from contextlib import suppress
55
from dataclasses import dataclass
6+
from datetime import timedelta
67
from functools import wraps
7-
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Deque
8+
from typing import TYPE_CHECKING, Any, Coroutine, Final, TypeVar
9+
10+
from pydantic import NonNegativeFloat
11+
from servicelib.background_task import cancel_task
812

913
from .utils_profiling_middleware import dont_profile, is_profiling, profile_context
1014

11-
logger = logging.getLogger(__name__)
15+
_logger = logging.getLogger(__name__)
1216

1317
if TYPE_CHECKING:
1418
Queue = asyncio.Queue
@@ -51,7 +55,7 @@ async def _safe_cancel(context: Context) -> None:
5155
await context.task
5256
except RuntimeError as e:
5357
if "Event loop is closed" in f"{e}":
54-
logger.warning("event loop is closed and could not cancel %s", context)
58+
_logger.warning("event loop is closed and could not cancel %s", context)
5559
else:
5660
raise
5761

@@ -62,7 +66,7 @@ async def cancel_sequential_workers() -> None:
6266
await _safe_cancel(context)
6367

6468
_sequential_jobs_contexts.clear()
65-
logger.info("All run_sequentially_in_context pending workers stopped")
69+
_logger.info("All run_sequentially_in_context pending workers stopped")
6670

6771

6872
# NOTE: If you get funny mismatches with mypy in returned values it might be due to this decorator.
@@ -118,22 +122,22 @@ def _get_context(args: Any, kwargs: dict) -> Context:
118122
search_args = dict(zip(arg_names, args))
119123
search_args.update(kwargs)
120124

121-
key_parts: Deque[str] = deque()
125+
key_parts: list[str] = []
122126
for arg in target_args:
123127
sub_args = arg.split(".")
124128
main_arg = sub_args[0]
125129
if main_arg not in search_args:
126-
raise ValueError(
130+
msg = (
127131
f"Expected '{main_arg}' in '{decorated_function.__name__}'"
128132
f" arguments. Got '{search_args}'"
129133
)
134+
raise ValueError(msg)
130135
context_key = search_args[main_arg]
131136
for attribute in sub_args[1:]:
132137
potential_key = getattr(context_key, attribute)
133138
if not potential_key:
134-
raise ValueError(
135-
f"Expected '{attribute}' attribute in '{context_key.__name__}' arguments."
136-
)
139+
msg = f"Expected '{attribute}' attribute in '{context_key.__name__}' arguments."
140+
raise ValueError(msg)
137141
context_key = potential_key
138142

139143
key_parts.append(f"{decorated_function.__name__}_{context_key}")
@@ -200,3 +204,35 @@ async def worker(in_q: Queue[QueueElement], out_q: Queue) -> None:
200204
return wrapper
201205

202206
return decorator
207+
208+
209+
T = TypeVar("T")
210+
_CANCELLATION_TIMEOUT: Final[NonNegativeFloat] = 0.1
211+
212+
213+
async def _monitor_task(
214+
notification_hook: Callable[[], Awaitable[None]], notify_after: timedelta
215+
) -> None:
216+
await asyncio.sleep(notify_after.total_seconds())
217+
await notification_hook()
218+
219+
220+
async def notify_when_over_threshold(
221+
task: Coroutine[Any, Any, T],
222+
*,
223+
notification_hook: Callable[[], Awaitable[None]],
224+
notify_after: timedelta,
225+
) -> T:
226+
monitor_task = asyncio.create_task(_monitor_task(notification_hook, notify_after))
227+
228+
try:
229+
result = await task
230+
await cancel_task(monitor_task, timeout=_CANCELLATION_TIMEOUT)
231+
except asyncio.CancelledError:
232+
await cancel_task(monitor_task, timeout=_CANCELLATION_TIMEOUT)
233+
raise
234+
except Exception:
235+
await cancel_task(monitor_task, timeout=_CANCELLATION_TIMEOUT)
236+
raise
237+
238+
return result

packages/service-library/tests/test_async_utils.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77
import random
88
from collections import deque
99
from dataclasses import dataclass
10+
from datetime import timedelta
1011
from time import time
1112
from typing import Any
13+
from unittest.mock import Mock
1214

1315
import pytest
1416
from faker import Faker
1517
from servicelib.async_utils import (
1618
_sequential_jobs_contexts,
19+
notify_when_over_threshold,
1720
run_sequentially_in_context,
1821
)
1922

@@ -224,3 +227,61 @@ async def test_multiple_context_calls(context_param: int) -> int:
224227
assert i == await test_multiple_context_calls(i)
225228

226229
assert len(_sequential_jobs_contexts) == RETRIES
230+
231+
232+
async def test_notify_when_over_threshold():
233+
234+
notification_spy = Mock()
235+
236+
async def notification() -> None:
237+
notification_spy()
238+
print("notified")
239+
240+
async def _worker(*, sleep_for: float, raise_error: bool = False) -> float:
241+
await asyncio.sleep(sleep_for)
242+
243+
if raise_error:
244+
msg = "raise error as planned"
245+
raise RuntimeError(msg)
246+
247+
return sleep_for
248+
249+
# 1. finish after
250+
result = await notify_when_over_threshold(
251+
_worker(sleep_for=0.5),
252+
notification_hook=notification,
253+
notify_after=timedelta(seconds=0.1),
254+
)
255+
assert isinstance(result, float)
256+
assert result == 0.5
257+
assert notification_spy.call_count == 1
258+
259+
# 2. finish before
260+
notification_spy.reset_mock()
261+
await notify_when_over_threshold(
262+
_worker(sleep_for=0.1),
263+
notification_hook=notification,
264+
notify_after=timedelta(seconds=0.2),
265+
)
266+
await asyncio.sleep(0.2)
267+
assert notification_spy.call_count == 0
268+
269+
# 3. raise error before notification
270+
notification_spy.reset_mock()
271+
with pytest.raises(RuntimeError):
272+
await notify_when_over_threshold(
273+
_worker(sleep_for=0, raise_error=True),
274+
notification_hook=notification,
275+
notify_after=timedelta(seconds=0.2),
276+
)
277+
assert notification_spy.call_count == 0
278+
279+
# 4. raise after notification
280+
notification_spy.reset_mock()
281+
with pytest.raises(RuntimeError):
282+
await notify_when_over_threshold(
283+
_worker(sleep_for=0.2, raise_error=True),
284+
notification_hook=notification,
285+
notify_after=timedelta(seconds=0.1),
286+
)
287+
assert notification_spy.call_count == 1

0 commit comments

Comments
 (0)