|
9 | 9 | from collections.abc import Awaitable, Callable |
10 | 10 | from datetime import timedelta |
11 | 11 | from itertools import chain |
| 12 | +from typing import Final |
12 | 13 | from unittest.mock import Mock |
13 | 14 |
|
14 | 15 | import arrow |
|
22 | 23 | start_exclusive_periodic_task, |
23 | 24 | ) |
24 | 25 | from servicelib.redis._errors import LockLostError |
25 | | -from servicelib.utils import logged_gather |
| 26 | +from servicelib.utils import limited_gather, logged_gather |
26 | 27 | from tenacity.asyncio import AsyncRetrying |
27 | 28 | from tenacity.retry import retry_if_exception_type |
28 | 29 | from tenacity.stop import stop_after_delay |
@@ -332,46 +333,41 @@ async def _() -> None: |
332 | 333 | assert await redis_client_sdk.lock_value(lock_name) is None |
333 | 334 |
|
334 | 335 |
|
335 | | -# async def test_lock_acquired_in_parallel_to_update_same_resource( |
336 | | -# with_short_default_redis_lock_ttl: None, |
337 | | -# get_redis_client_sdk: Callable[ |
338 | | -# [RedisDatabase], AbstractAsyncContextManager[RedisClientSDK] |
339 | | -# ], |
340 | | -# faker: Faker, |
341 | | -# ): |
342 | | -# INCREASE_OPERATIONS: Final[int] = 250 |
343 | | -# INCREASE_BY: Final[int] = 10 |
344 | | - |
345 | | -# class RaceConditionCounter: |
346 | | -# def __init__(self): |
347 | | -# self.value: int = 0 |
348 | | - |
349 | | -# async def race_condition_increase(self, by: int) -> None: |
350 | | -# current_value = self.value |
351 | | -# current_value += by |
352 | | -# # most likely situation which creates issues |
353 | | -# await asyncio.sleep(redis_constants.DEFAULT_LOCK_TTL.total_seconds() / 2) |
354 | | -# self.value = current_value |
355 | | - |
356 | | -# counter = RaceConditionCounter() |
357 | | -# lock_name: str = faker.pystr() |
358 | | -# # ensures it does nto time out before acquiring the lock |
359 | | -# time_for_all_inc_counter_calls_to_finish_s: float = ( |
360 | | -# redis_constants.DEFAULT_LOCK_TTL.total_seconds() * INCREASE_OPERATIONS * 10 |
361 | | -# ) |
362 | | - |
363 | | -# async def _inc_counter() -> None: |
364 | | -# async with get_redis_client_sdk( |
365 | | -# RedisDatabase.RESOURCES |
366 | | -# ) as redis_client_sdk: |
367 | | -# async with redis_client_sdk.lock_context( |
368 | | -# lock_key=lock_name, |
369 | | -# blocking=True, |
370 | | -# blocking_timeout_s=time_for_all_inc_counter_calls_to_finish_s, |
371 | | -# ): |
372 | | -# await counter.race_condition_increase(INCREASE_BY) |
373 | | - |
374 | | -# await limited_gather( |
375 | | -# *(_inc_counter() for _ in range(INCREASE_OPERATIONS)), limit=15 |
376 | | -# ) |
377 | | -# assert counter.value == INCREASE_BY * INCREASE_OPERATIONS |
| 336 | +async def test_lock_acquired_in_parallel_to_update_same_resource( |
| 337 | + with_short_default_redis_lock_ttl: datetime.timedelta, |
| 338 | + redis_client_sdk: RedisClientSDK, |
| 339 | + lock_name: str, |
| 340 | +): |
| 341 | + INCREASE_OPERATIONS: Final[int] = 250 |
| 342 | + INCREASE_BY: Final[int] = 10 |
| 343 | + |
| 344 | + class RaceConditionCounter: |
| 345 | + def __init__(self) -> None: |
| 346 | + self.value: int = 0 |
| 347 | + |
| 348 | + async def race_condition_increase(self, by: int) -> None: |
| 349 | + current_value = self.value |
| 350 | + current_value += by |
| 351 | + # most likely situation which creates issues |
| 352 | + await asyncio.sleep(with_short_default_redis_lock_ttl.total_seconds() / 2) |
| 353 | + self.value = current_value |
| 354 | + |
| 355 | + counter = RaceConditionCounter() |
| 356 | + # ensures it does nto time out before acquiring the lock |
| 357 | + time_for_all_inc_counter_calls_to_finish = ( |
| 358 | + with_short_default_redis_lock_ttl * INCREASE_OPERATIONS * 10 |
| 359 | + ) |
| 360 | + |
| 361 | + @exclusive( |
| 362 | + redis_client_sdk, |
| 363 | + lock_key=lock_name, |
| 364 | + blocking=True, |
| 365 | + blocking_timeout=time_for_all_inc_counter_calls_to_finish, |
| 366 | + ) |
| 367 | + async def _inc_counter() -> None: |
| 368 | + await counter.race_condition_increase(INCREASE_BY) |
| 369 | + |
| 370 | + await limited_gather( |
| 371 | + *(_inc_counter() for _ in range(INCREASE_OPERATIONS)), limit=15 |
| 372 | + ) |
| 373 | + assert counter.value == INCREASE_BY * INCREASE_OPERATIONS |
0 commit comments