Skip to content

Commit 2da1c3f

Browse files
committed
fix(providers): 🐛 fix race condition in lock creation
Implement a thread-safe `_get_lock` method in auth providers to handle the retrieval and creation of refresh locks. This ensures that the `_refresh_locks` dictionary is modified under a master lock (`_locks_lock`), preventing Time-of-check to time-of-use (TOCTOU) bugs where multiple coroutines could simultaneously create duplicate locks for the same path.
1 parent f6e88ae commit 2da1c3f

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

src/rotator_library/providers/iflow_auth_base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,6 +428,14 @@ def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
428428

429429
return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS
430430

431+
async def _get_lock(self, path: str) -> asyncio.Lock:
432+
# [FIX RACE CONDITION] Protect lock creation with a master lock
433+
# This prevents TOCTOU bug where multiple coroutines check and create simultaneously
434+
async with self._locks_lock:
435+
if path not in self._refresh_locks:
436+
self._refresh_locks[path] = asyncio.Lock()
437+
return self._refresh_locks[path]
438+
431439
def _is_token_truly_expired(self, creds: Dict[str, Any]) -> bool:
432440
"""Check if token is TRULY expired (past actual expiry, not just threshold).
433441

src/rotator_library/providers/qwen_auth_base.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,14 @@ def _is_token_expired(self, creds: Dict[str, Any]) -> bool:
273273
expiry_timestamp = creds.get("expiry_date", 0) / 1000
274274
return expiry_timestamp < time.time() + REFRESH_EXPIRY_BUFFER_SECONDS
275275

276+
async def _get_lock(self, path: str) -> asyncio.Lock:
277+
# [FIX RACE CONDITION] Protect lock creation with a master lock
278+
# This prevents TOCTOU bug where multiple coroutines check and create simultaneously
279+
async with self._locks_lock:
280+
if path not in self._refresh_locks:
281+
self._refresh_locks[path] = asyncio.Lock()
282+
return self._refresh_locks[path]
283+
276284
def _is_token_truly_expired(self, creds: Dict[str, Any]) -> bool:
277285
"""Check if token is TRULY expired (past actual expiry, not just threshold).
278286

0 commit comments

Comments
 (0)