55from contextlib import AbstractAsyncContextManager
66from datetime import timedelta
77from itertools import chain
8+ from typing import Awaitable
89from unittest .mock import Mock
910
1011import arrow
@@ -32,39 +33,117 @@ async def _is_locked(redis_client_sdk: RedisClientSDK, lock_name: str) -> bool:
3233
3334@pytest .fixture
3435def lock_name (faker : Faker ) -> str :
35- return faker .uuid4 () # type: ignore
36+ return faker .pystr ()
37+
38+
39+ def _exclusive_sleeping_task (
40+ redis_client_sdk : RedisClientSDK | Callable [..., RedisClientSDK ],
41+ lock_name : str | Callable [..., str ],
42+ sleep_duration : float ,
43+ ) -> Callable [..., Awaitable [float ]]:
44+ @exclusive (redis_client_sdk , lock_key = lock_name )
45+ async def _ () -> float :
46+ resolved_client = (
47+ redis_client_sdk () if callable (redis_client_sdk ) else redis_client_sdk
48+ )
49+ resolved_lock_name = lock_name () if callable (lock_name ) else lock_name
50+ assert await _is_locked (resolved_client , resolved_lock_name )
51+ await asyncio .sleep (sleep_duration )
52+ assert await _is_locked (resolved_client , resolved_lock_name )
53+ return sleep_duration
54+
55+ return _
56+
57+
58+ @pytest .fixture
59+ def sleep_duration (faker : Faker ) -> float :
60+ return faker .pyfloat (positive = True , min_value = 0.2 , max_value = 0.8 )
3661
3762
38- async def _contained_client (
63+ async def test_exclusive_decorator (
3964 get_redis_client_sdk : Callable [
4065 [RedisDatabase ], AbstractAsyncContextManager [RedisClientSDK ]
4166 ],
4267 lock_name : str ,
43- task_duration : float ,
44- ) -> None :
45- async with get_redis_client_sdk (RedisDatabase .RESOURCES ) as redis_client_sdk :
46- assert not await _is_locked (redis_client_sdk , lock_name )
47-
48- @exclusive (redis_client_sdk , lock_key = lock_name )
49- async def _some_task () -> None :
50- assert await _is_locked (redis_client_sdk , lock_name )
51- await asyncio .sleep (task_duration )
52- assert await _is_locked (redis_client_sdk , lock_name )
53-
54- await _some_task ()
68+ sleep_duration : float ,
69+ ):
5570
56- assert not await _is_locked (redis_client_sdk , lock_name )
71+ async with get_redis_client_sdk (RedisDatabase .RESOURCES ) as redis_client :
72+ for _ in range (3 ):
73+ assert (
74+ await _exclusive_sleeping_task (
75+ redis_client , lock_name , sleep_duration
76+ )()
77+ == sleep_duration
78+ )
5779
5880
59- @pytest .mark .parametrize ("task_duration" , [0.1 , 1 , 2 ])
60- async def test_exclusive_sequentially (
81+ async def test_exclusive_decorator_with_key_builder (
6182 get_redis_client_sdk : Callable [
6283 [RedisDatabase ], AbstractAsyncContextManager [RedisClientSDK ]
6384 ],
6485 lock_name : str ,
65- task_duration : float ,
86+ sleep_duration : float ,
6687):
67- await _contained_client (get_redis_client_sdk , lock_name , task_duration )
88+ def _get_lock_name (* args , ** kwargs ) -> str :
89+ assert args is not None
90+ assert kwargs is not None
91+ return lock_name
92+
93+ async with get_redis_client_sdk (RedisDatabase .RESOURCES ) as redis_client :
94+ for _ in range (3 ):
95+ assert (
96+ await _exclusive_sleeping_task (
97+ redis_client , _get_lock_name , sleep_duration
98+ )()
99+ == sleep_duration
100+ )
101+
102+
103+ async def test_exclusive_decorator_with_client_builder (
104+ get_redis_client_sdk : Callable [
105+ [RedisDatabase ], AbstractAsyncContextManager [RedisClientSDK ]
106+ ],
107+ lock_name : str ,
108+ sleep_duration : float ,
109+ ):
110+ async with get_redis_client_sdk (RedisDatabase .RESOURCES ) as redis_client :
111+
112+ def _get_redis_client_builder (* args , ** kwargs ) -> RedisClientSDK :
113+ assert args is not None
114+ assert kwargs is not None
115+ return redis_client
116+
117+ for _ in range (3 ):
118+ assert (
119+ await _exclusive_sleeping_task (
120+ _get_redis_client_builder , lock_name , sleep_duration
121+ )()
122+ == sleep_duration
123+ )
124+
125+
126+ async def _acquire_lock_and_exclusively_sleep (
127+ get_redis_client_sdk : Callable [
128+ [RedisDatabase ], AbstractAsyncContextManager [RedisClientSDK ]
129+ ],
130+ lock_name : str | Callable [..., str ],
131+ sleep_duration : float ,
132+ ) -> None :
133+ async with get_redis_client_sdk (RedisDatabase .RESOURCES ) as redis_client_sdk :
134+ redis_lock_name = lock_name () if callable (lock_name ) else lock_name
135+ assert not await _is_locked (redis_client_sdk , redis_lock_name )
136+
137+ @exclusive (redis_client_sdk , lock_key = lock_name )
138+ async def _ () -> float :
139+ assert await _is_locked (redis_client_sdk , redis_lock_name )
140+ await asyncio .sleep (sleep_duration )
141+ assert await _is_locked (redis_client_sdk , redis_lock_name )
142+ return sleep_duration
143+
144+ assert await _ () == sleep_duration
145+
146+ assert not await _is_locked (redis_client_sdk , redis_lock_name )
68147
69148
70149async def test_exclusive_parallel_lock_is_released_and_reacquired (
@@ -76,17 +155,19 @@ async def test_exclusive_parallel_lock_is_released_and_reacquired(
76155 parallel_tasks = 10
77156 results = await logged_gather (
78157 * [
79- _contained_client (get_redis_client_sdk , lock_name , task_duration = 0.1 )
158+ _acquire_lock_and_exclusively_sleep (
159+ get_redis_client_sdk , lock_name , sleep_duration = 0.1
160+ )
80161 for _ in range (parallel_tasks )
81162 ],
82- reraise = False
163+ reraise = False ,
83164 )
84165 assert results .count (None ) == 1
85166 assert [isinstance (x , CouldNotAcquireLockError ) for x in results ].count (
86167 True
87168 ) == parallel_tasks - 1
88169
89- # check lock is being released
170+ # check lock is released
90171 async with get_redis_client_sdk (RedisDatabase .RESOURCES ) as redis_client_sdk :
91172 assert not await _is_locked (redis_client_sdk , lock_name )
92173
@@ -168,7 +249,7 @@ async def test_start_exclusive_periodic_task_parallel_all_finish(
168249 _assert_task_completes_once (get_redis_client_sdk , stop_after = 60 )
169250 for _ in range (parallel_tasks )
170251 ],
171- reraise = False
252+ reraise = False ,
172253 )
173254
174255 # check no error occurred
0 commit comments