1- # pylint:disable=unused-variable
2- # pylint:disable=unused-argument
3- # pylint:disable=redefined-outer-name
4- # pylint:disable=protected-access
1+ # pylint: disable=no-value-for-parameter
2+ # pylint: disable=redefined-outer-name
3+ # pylint: disable=unused-argument
4+ # pylint: disable=unused-variable
55
66
77import asyncio
8- from datetime import timedelta
8+ import datetime
9+ from collections .abc import AsyncIterator , Callable
10+ from contextlib import AbstractAsyncContextManager
911from itertools import chain
10- from unittest . mock import Mock
12+ from unittest import mock
1113
1214import arrow
15+ import pytest
1316from servicelib .async_utils import cancel_wait_task
14- from servicelib .redis . _client import RedisClientSDK
15- from servicelib .redis . _distributed_locks_utils import create_exclusive_periodic_task
16- from servicelib . utils import logged_gather
17+ from servicelib .background_task_utils import exclusive_periodic
18+ from servicelib .redis import RedisClientSDK
19+ from settings_library . redis import RedisDatabase
1720from tenacity import (
1821 AsyncRetrying ,
1922 retry_if_exception_type ,
2932]
3033
3134
32- async def _sleep_task (sleep_interval : float , on_sleep_events : Mock ) -> None :
33- on_sleep_events (arrow .utcnow ())
34- await asyncio .sleep (sleep_interval )
35- print ("Slept for" , sleep_interval )
36- on_sleep_events (arrow .utcnow ())
35+ @pytest .fixture
36+ async def redis_client_sdk (
37+ get_redis_client_sdk : Callable [
38+ [RedisDatabase ], AbstractAsyncContextManager [RedisClientSDK ]
39+ ],
40+ ) -> AsyncIterator [RedisClientSDK ]:
41+ async with get_redis_client_sdk (RedisDatabase .RESOURCES ) as client :
42+ yield client
3743
3844
39- async def _assert_on_sleep_done (on_sleep_events : Mock , * , stop_after : float ):
45+ async def _assert_on_sleep_done (on_sleep_events : mock . Mock , * , stop_after : float ):
4046 async for attempt in AsyncRetrying (
4147 wait = wait_fixed (0.1 ),
4248 stop = stop_after_delay (stop_after ),
@@ -52,20 +58,20 @@ async def _assert_task_completes_once(
5258 redis_client_sdk : RedisClientSDK ,
5359 stop_after : float ,
5460) -> tuple [float , ...]:
55- sleep_events = Mock ( )
56-
57- started_task = create_exclusive_periodic_task (
58- redis_client_sdk ,
59- _sleep_task ,
60- task_period = timedelta ( seconds = 1 ),
61- task_name = "pytest_sleep_task" ,
62- sleep_interval = 1 ,
63- on_sleep_events = sleep_events ,
64- )
61+ @ exclusive_periodic ( redis_client_sdk , task_interval = datetime . timedelta ( seconds = 1 ) )
62+ async def _sleep_task ( sleep_interval : float , on_sleep_events : mock . Mock ) -> None :
63+ on_sleep_events ( arrow . utcnow ())
64+ await asyncio . sleep ( sleep_interval )
65+ print ( "Slept for" , sleep_interval )
66+ on_sleep_events ( arrow . utcnow ())
67+
68+ sleep_events = mock . Mock ()
69+
70+ task = asyncio . create_task ( _sleep_task ( 1 , sleep_events ), name = "pytest_sleep_task" )
6571
6672 await _assert_on_sleep_done (sleep_events , stop_after = stop_after )
6773
68- await cancel_wait_task (started_task , max_delay = 5 )
74+ await cancel_wait_task (task , max_delay = 5 )
6975
7076 events_timestamps : tuple [float , ...] = tuple (
7177 x .args [0 ].timestamp () for x in sleep_events .call_args_list
@@ -86,33 +92,33 @@ def test__check_elements_lower():
8692 assert not _check_elements_lower ([1 , 2 , 4 , 3 , 5 ])
8793
8894
89- async def test_create_exclusive_periodic_task_single (
95+ async def test_exclusive_periodic_decorator_single (
9096 redis_client_sdk : RedisClientSDK ,
9197):
9298 await _assert_task_completes_once (redis_client_sdk , stop_after = 2 )
9399
94100
95- async def test_create_exclusive_periodic_task_parallel_all_finish (
101+ async def test_exclusive_periodic_decorator_parallel_all_finish (
96102 redis_client_sdk : RedisClientSDK ,
97103):
98104 parallel_tasks = 10
99- results : list [ tuple [ float , float ]] = await logged_gather (
105+ results = await asyncio . gather (
100106 * [
101107 _assert_task_completes_once (redis_client_sdk , stop_after = 60 )
102108 for _ in range (parallel_tasks )
103109 ],
104- reraise = False ,
110+ return_exceptions = True ,
105111 )
106112
107113 # check no error occurred
108114 assert [isinstance (x , tuple ) for x in results ].count (True ) == parallel_tasks
109- assert [x [0 ] < x [1 ] for x in results ].count (True ) == parallel_tasks
115+ assert [isinstance (x , Exception ) for x in results ].count (True ) == 0
116+ valid_results = [x for x in results if isinstance (x , tuple )]
117+ assert [x [0 ] < x [1 ] for x in valid_results ].count (True ) == parallel_tasks
110118
111119 # sort by start time (task start order is not equal to the task lock acquisition order)
112- sorted_results : list [tuple [float , float ]] = sorted (results , key = lambda x : x [0 ])
113-
114- # pylint:disable=unnecessary-comprehension
115- flattened_results : list [float ] = [x for x in chain (* sorted_results )] # noqa: C416
120+ sorted_results = sorted (valid_results , key = lambda x : x [0 ])
121+ flattened_results = list (chain (* sorted_results ))
116122
117123 # NOTE all entries should be in increasing order;
118124 # this means that the `_sleep_task` ran sequentially
0 commit comments