Skip to content

Commit 7f17a43

Browse files
authored
fix: Prevent race condition in concurrent storage creation (#1626)
### Description - Fix a race condition when creating storages concurrently. ### Issues - Closes: #1621 ### Testing - Add new tests for `StorageInstanceManager`.
1 parent 86faab9 commit 7f17a43

File tree

2 files changed

+164
-44
lines changed

2 files changed

+164
-44
lines changed

src/crawlee/storages/_storage_instance_manager.py

Lines changed: 103 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
from __future__ import annotations
22

3+
from asyncio import Lock
34
from collections import defaultdict
45
from collections.abc import Coroutine, Hashable
56
from dataclasses import dataclass, field
67
from typing import TYPE_CHECKING, TypeVar
8+
from weakref import WeakValueDictionary
79

810
from crawlee._utils.raise_if_too_many_kwargs import raise_if_too_many_kwargs
911
from crawlee.storage_clients._base import DatasetClient, KeyValueStoreClient, RequestQueueClient
@@ -76,6 +78,7 @@ class StorageInstanceManager:
7678

7779
def __init__(self) -> None:
7880
self._cache: _StorageCache = _StorageCache()
81+
self._opener_locks: WeakValueDictionary[tuple, Lock] = WeakValueDictionary()
7982

8083
async def open_storage_instance(
8184
self,
@@ -119,63 +122,71 @@ async def open_storage_instance(
119122
if not any([name, alias, id]):
120123
alias = self._DEFAULT_STORAGE_ALIAS
121124

122-
# Check cache
123-
if id is not None and (cached_instance := self._cache.by_id[cls][id].get(storage_client_cache_key)):
124-
if isinstance(cached_instance, cls):
125-
return cached_instance
126-
raise RuntimeError('Cached instance type mismatch.')
125+
# Check cache without lock first for performance.
126+
if cached_instance := self._get_from_cache(
127+
cls,
128+
id=id,
129+
name=name,
130+
alias=alias,
131+
storage_client_cache_key=storage_client_cache_key,
132+
):
133+
return cached_instance
127134

128-
if name is not None and (cached_instance := self._cache.by_name[cls][name].get(storage_client_cache_key)):
129-
if isinstance(cached_instance, cls):
130-
return cached_instance
131-
raise RuntimeError('Cached instance type mismatch.')
135+
# Validate storage name
136+
if name is not None:
137+
validate_storage_name(name)
132138

133-
if alias is not None and (
134-
cached_instance := self._cache.by_alias[cls][alias].get(storage_client_cache_key)
135-
):
136-
if isinstance(cached_instance, cls):
139+
# Acquire lock for this opener
140+
opener_lock_key = (cls, str(id or name or alias), storage_client_cache_key)
141+
if not (lock := self._opener_locks.get(opener_lock_key)):
142+
lock = Lock()
143+
self._opener_locks[opener_lock_key] = lock
144+
145+
async with lock:
146+
# Another task could have created the storage while we were waiting for the lock - check if that
147+
# happened
148+
if cached_instance := self._get_from_cache(
149+
cls,
150+
id=id,
151+
name=name,
152+
alias=alias,
153+
storage_client_cache_key=storage_client_cache_key,
154+
):
137155
return cached_instance
138-
raise RuntimeError('Cached instance type mismatch.')
139156

140-
# Check for conflicts between named and alias storages
141-
if alias and (self._cache.by_name[cls][alias].get(storage_client_cache_key)):
142-
raise ValueError(
143-
f'Cannot create alias storage "{alias}" because a named storage with the same name already exists. '
144-
f'Use a different alias or drop the existing named storage first.'
157+
# Check for conflicts between named and alias storages
158+
self._check_name_alias_conflict(
159+
cls,
160+
name=name,
161+
alias=alias,
162+
storage_client_cache_key=storage_client_cache_key,
145163
)
146164

147-
if name and (self._cache.by_alias[cls][name].get(storage_client_cache_key)):
148-
raise ValueError(
149-
f'Cannot create named storage "{name}" because an alias storage with the same name already exists. '
150-
f'Use a different name or drop the existing alias storage first.'
151-
)
165+
# Create new instance
166+
client: KeyValueStoreClient | DatasetClient | RequestQueueClient
167+
client = await client_opener_coro
152168

153-
# Validate storage name
154-
if name is not None:
155-
validate_storage_name(name)
156-
157-
# Create new instance
158-
client: KeyValueStoreClient | DatasetClient | RequestQueueClient
159-
client = await client_opener_coro
169+
metadata = await client.get_metadata()
160170

161-
metadata = await client.get_metadata()
171+
instance = cls(client, metadata.id, metadata.name) # type: ignore[call-arg]
172+
instance_name = getattr(instance, 'name', None)
162173

163-
instance = cls(client, metadata.id, metadata.name) # type: ignore[call-arg]
164-
instance_name = getattr(instance, 'name', None)
174+
# Cache the instance.
175+
# Note: No awaits in this section. All cache entries must be written
176+
# atomically to ensure pre-checks outside the lock see consistent state.
165177

166-
# Cache the instance.
167-
# Always cache by id.
168-
self._cache.by_id[cls][instance.id][storage_client_cache_key] = instance
178+
# Always cache by id.
179+
self._cache.by_id[cls][instance.id][storage_client_cache_key] = instance
169180

170-
# Cache named storage.
171-
if instance_name is not None:
172-
self._cache.by_name[cls][instance_name][storage_client_cache_key] = instance
181+
# Cache named storage.
182+
if instance_name is not None:
183+
self._cache.by_name[cls][instance_name][storage_client_cache_key] = instance
173184

174-
# Cache unnamed storage.
175-
if alias is not None:
176-
self._cache.by_alias[cls][alias][storage_client_cache_key] = instance
185+
# Cache unnamed storage.
186+
if alias is not None:
187+
self._cache.by_alias[cls][alias][storage_client_cache_key] = instance
177188

178-
return instance
189+
return instance
179190

180191
finally:
181192
# Make sure the client opener is closed.
@@ -193,3 +204,51 @@ def remove_from_cache(self, storage_instance: Storage) -> None:
193204
def clear_cache(self) -> None:
194205
"""Clear all cached storage instances."""
195206
self._cache = _StorageCache()
207+
208+
def _get_from_cache(
209+
self,
210+
cls: type[T],
211+
*,
212+
id: str | None = None,
213+
name: str | None = None,
214+
alias: str | None = None,
215+
storage_client_cache_key: Hashable = '',
216+
) -> T | None:
217+
"""Get a storage instance from the cache."""
218+
if id is not None and (cached_instance := self._cache.by_id[cls][id].get(storage_client_cache_key)):
219+
if isinstance(cached_instance, cls):
220+
return cached_instance
221+
raise RuntimeError('Cached instance type mismatch.')
222+
223+
if name is not None and (cached_instance := self._cache.by_name[cls][name].get(storage_client_cache_key)):
224+
if isinstance(cached_instance, cls):
225+
return cached_instance
226+
raise RuntimeError('Cached instance type mismatch.')
227+
228+
if alias is not None and (cached_instance := self._cache.by_alias[cls][alias].get(storage_client_cache_key)):
229+
if isinstance(cached_instance, cls):
230+
return cached_instance
231+
raise RuntimeError('Cached instance type mismatch.')
232+
233+
return None
234+
235+
def _check_name_alias_conflict(
236+
self,
237+
cls: type[T],
238+
*,
239+
name: str | None = None,
240+
alias: str | None = None,
241+
storage_client_cache_key: Hashable = '',
242+
) -> None:
243+
"""Check for conflicts between named and alias storages."""
244+
if alias and (self._cache.by_name[cls][alias].get(storage_client_cache_key)):
245+
raise ValueError(
246+
f'Cannot create alias storage "{alias}" because a named storage with the same name already exists. '
247+
f'Use a different alias or drop the existing named storage first.'
248+
)
249+
250+
if name and (self._cache.by_alias[cls][name].get(storage_client_cache_key)):
251+
raise ValueError(
252+
f'Cannot create named storage "{name}" because an alias storage with the same name already exists. '
253+
f'Use a different name or drop the existing alias storage first.'
254+
)

tests/unit/storages/test_storage_instance_manager.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import asyncio
2+
import sys
13
from pathlib import Path
24
from typing import cast
5+
from unittest.mock import AsyncMock
36

47
import pytest
58

@@ -128,3 +131,61 @@ async def test_preexisting_unnamed_storage_open_by_id(storage_type: type[Storage
128131
storage_1_again = await storage_type.open(id=storage_1.id, storage_client=storage_client)
129132

130133
assert storage_1.id == storage_1_again.id
134+
135+
136+
@pytest.mark.skipif(sys.version_info[:3] < (3, 11), reason='asyncio.Barrier was introduced in Python 3.11.')
async def test_concurrent_open_datasets() -> None:
    """Test that two tasks concurrently opening the same named dataset both write to it.

    Each task waits on a shared barrier, opens the dataset by name and pushes one item. The dataset is then
    opened once more and must contain both items.
    """
    from asyncio import Barrier  # type:ignore[attr-defined] # noqa: PLC0415

    start_line = Barrier(2)

    async def open_and_push(item: dict) -> None:
        # Hold each task here until both have arrived, so the two `open` calls start together.
        await start_line.wait()
        opened = await Dataset.open(name='concurrent-storage')
        await opened.push_data(item)

    first_task = open_and_push({'test_1': '1'})
    second_task = open_and_push({'test_2': '2'})
    await asyncio.gather(first_task, second_task)

    dataset = await Dataset.open(name='concurrent-storage')

    stored = await dataset.get_data()
    assert len(stored.items) == 2

    await dataset.drop()
159+
160+
161+
@pytest.mark.skipif(sys.version_info[:3] < (3, 11), reason='asyncio.Barrier was introduced in Python 3.11.')
async def test_concurrent_open_datasets_with_same_name_and_alias() -> None:
    """Test that concurrently opening a named and an alias storage with the same identifier conflicts.

    One task opens the dataset by alias and the other by name, both using 'concurrent-storage'. Exactly one
    of the two opens is expected to raise `ValueError`; the arguments of the successful open are recorded so
    the created storage can be dropped at the end.
    """
    from asyncio import Barrier  # type:ignore[attr-defined] # noqa: PLC0415

    valid_kwargs: dict[str, str | None] = {}

    exception_calls = AsyncMock()

    barrier = Barrier(2)

    async def open_dataset(name: str | None, alias: str | None) -> None:
        # Hold each task here until both have arrived, so the two `open` calls start together.
        await barrier.wait()
        try:
            await Dataset.open(name=name, alias=alias)
        except ValueError:
            # Calling an `AsyncMock` returns a coroutine; it has to be awaited, otherwise Python emits a
            # "coroutine was never awaited" RuntimeWarning.
            await exception_calls()
        else:
            valid_kwargs['name'] = name
            valid_kwargs['alias'] = alias

    await asyncio.gather(
        open_dataset(name=None, alias='concurrent-storage'),
        open_dataset(name='concurrent-storage', alias=None),
    )

    # Ensure that a ValueError was raised exactly once due to the name/alias conflict.
    exception_calls.assert_awaited_once()

    dataset = await Dataset.open(name=valid_kwargs.get('name'), alias=valid_kwargs.get('alias'))

    await dataset.drop()

0 commit comments

Comments
 (0)