
Commit 842922b

storage protection during parallel creation
1 parent 33e2ba9 commit 842922b

File tree

2 files changed: +166 additions, −42 deletions
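In short, the commit makes storage creation safe under concurrency: `open_storage_instance` now takes a per-opener `asyncio.Lock` (kept in a `WeakValueDictionary` keyed by storage class, id/name/alias and storage-client cache key) and re-checks both the instance cache and the name/alias conflict once the lock is held, so parallel `open()` calls for the same storage end up sharing a single instance. Below is a minimal, self-contained sketch of that double-checked-locking pattern; the `get_or_create` helper, the `factory` callable and the module-level dictionaries are illustrative only, not part of crawlee's API.

from __future__ import annotations

import asyncio
from collections.abc import Awaitable, Callable
from weakref import WeakValueDictionary

_cache: dict[str, object] = {}
_locks: WeakValueDictionary[str, asyncio.Lock] = WeakValueDictionary()


async def get_or_create(key: str, factory: Callable[[], Awaitable[object]]) -> object:
    """Return the cached instance for `key`, creating it at most once under concurrency."""
    # Fast path: no lock needed when the instance is already cached.
    if (cached := _cache.get(key)) is not None:
        return cached

    # One asyncio.Lock per key; the WeakValueDictionary drops locks nobody references anymore.
    if (lock := _locks.get(key)) is None:
        lock = asyncio.Lock()
        _locks[key] = lock

    async with lock:
        # Double-checked locking: another task may have created the instance while we waited.
        if (cached := _cache.get(key)) is not None:
            return cached
        instance = await factory()
        _cache[key] = instance
        return instance

The two diffs below implement this in `StorageInstanceManager` and add concurrency tests.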

src/crawlee/storages/_storage_instance_manager.py

Lines changed: 105 additions & 42 deletions
@@ -1,9 +1,11 @@
 from __future__ import annotations

+from asyncio import Lock
 from collections import defaultdict
 from collections.abc import Coroutine, Hashable
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, TypeVar
+from weakref import WeakValueDictionary

 from crawlee._utils.raise_if_too_many_kwargs import raise_if_too_many_kwargs
 from crawlee.storage_clients._base import DatasetClient, KeyValueStoreClient, RequestQueueClient
@@ -76,6 +78,7 @@ class StorageInstanceManager:

     def __init__(self) -> None:
         self._cache: _StorageCache = _StorageCache()
+        self._opener_locks: WeakValueDictionary[tuple, Lock] = WeakValueDictionary()

     async def open_storage_instance(
         self,
@@ -120,62 +123,74 @@ async def open_storage_instance(
             alias = self._DEFAULT_STORAGE_ALIAS

             # Check cache
-            if id is not None and (cached_instance := self._cache.by_id[cls][id].get(storage_client_cache_key)):
-                if isinstance(cached_instance, cls):
-                    return cached_instance
-                raise RuntimeError('Cached instance type mismatch.')
-
-            if name is not None and (cached_instance := self._cache.by_name[cls][name].get(storage_client_cache_key)):
-                if isinstance(cached_instance, cls):
-                    return cached_instance
-                raise RuntimeError('Cached instance type mismatch.')
-
-            if alias is not None and (
-                cached_instance := self._cache.by_alias[cls][alias].get(storage_client_cache_key)
+            if cached_instance := self._get_from_cache(
+                cls,
+                id=id,
+                name=name,
+                alias=alias,
+                storage_client_cache_key=storage_client_cache_key,
             ):
-                if isinstance(cached_instance, cls):
-                    return cached_instance
-                raise RuntimeError('Cached instance type mismatch.')
+                return cached_instance

             # Check for conflicts between named and alias storages
-            if alias and (self._cache.by_name[cls][alias].get(storage_client_cache_key)):
-                raise ValueError(
-                    f'Cannot create alias storage "{alias}" because a named storage with the same name already exists. '
-                    f'Use a different alias or drop the existing named storage first.'
-                )
-
-            if name and (self._cache.by_alias[cls][name].get(storage_client_cache_key)):
-                raise ValueError(
-                    f'Cannot create named storage "{name}" because an alias storage with the same name already exists. '
-                    f'Use a different name or drop the existing alias storage first.'
-                )
+            self._check_name_alias_conflict(
+                cls,
+                name=name,
+                alias=alias,
+                storage_client_cache_key=storage_client_cache_key,
+            )

             # Validate storage name
             if name is not None:
                 validate_storage_name(name)

-            # Create new instance
-            client: KeyValueStoreClient | DatasetClient | RequestQueueClient
-            client = await client_opener_coro
+            # Acquire lock for this opener
+            opener_lock_key = (cls, str(id or name or alias), storage_client_cache_key)
+            if not (lock := self._opener_locks.get(opener_lock_key)):
+                lock = Lock()
+                self._opener_locks[opener_lock_key] = lock
+
+            async with lock:
+                # Re-check cache inside the lock
+                if cached_instance := self._get_from_cache(
+                    cls,
+                    id=id,
+                    name=name,
+                    alias=alias,
+                    storage_client_cache_key=storage_client_cache_key,
+                ):
+                    return cached_instance

-            metadata = await client.get_metadata()
+                # Re-check for conflicts between named and alias storages
+                self._check_name_alias_conflict(
+                    cls,
+                    name=name,
+                    alias=alias,
+                    storage_client_cache_key=storage_client_cache_key,
+                )

-            instance = cls(client, metadata.id, metadata.name)  # type: ignore[call-arg]
-            instance_name = getattr(instance, 'name', None)
+                # Create new instance
+                client: KeyValueStoreClient | DatasetClient | RequestQueueClient
+                client = await client_opener_coro

-            # Cache the instance.
-            # Always cache by id.
-            self._cache.by_id[cls][instance.id][storage_client_cache_key] = instance
+                metadata = await client.get_metadata()

-            # Cache named storage.
-            if instance_name is not None:
-                self._cache.by_name[cls][instance_name][storage_client_cache_key] = instance
+                instance = cls(client, metadata.id, metadata.name)  # type: ignore[call-arg]
+                instance_name = getattr(instance, 'name', None)

-            # Cache unnamed storage.
-            if alias is not None:
-                self._cache.by_alias[cls][alias][storage_client_cache_key] = instance
+                # Cache the instance.
+                # Always cache by id.
+                self._cache.by_id[cls][instance.id][storage_client_cache_key] = instance

-            return instance
+                # Cache named storage.
+                if instance_name is not None:
+                    self._cache.by_name[cls][instance_name][storage_client_cache_key] = instance
+
+                # Cache unnamed storage.
+                if alias is not None:
+                    self._cache.by_alias[cls][alias][storage_client_cache_key] = instance
+
+                return instance

         finally:
             # Make sure the client opener is closed.
@@ -193,3 +208,51 @@ def remove_from_cache(self, storage_instance: Storage) -> None:
     def clear_cache(self) -> None:
         """Clear all cached storage instances."""
         self._cache = _StorageCache()
+
+    def _get_from_cache(
+        self,
+        cls: type[T],
+        *,
+        id: str | None = None,
+        name: str | None = None,
+        alias: str | None = None,
+        storage_client_cache_key: Hashable = '',
+    ) -> T | None:
+        """Get a storage instance from the cache."""
+        if id is not None and (cached_instance := self._cache.by_id[cls][id].get(storage_client_cache_key)):
+            if isinstance(cached_instance, cls):
+                return cached_instance
+            raise RuntimeError('Cached instance type mismatch.')
+
+        if name is not None and (cached_instance := self._cache.by_name[cls][name].get(storage_client_cache_key)):
+            if isinstance(cached_instance, cls):
+                return cached_instance
+            raise RuntimeError('Cached instance type mismatch.')
+
+        if alias is not None and (cached_instance := self._cache.by_alias[cls][alias].get(storage_client_cache_key)):
+            if isinstance(cached_instance, cls):
+                return cached_instance
+            raise RuntimeError('Cached instance type mismatch.')
+
+        return None
+
+    def _check_name_alias_conflict(
+        self,
+        cls: type[T],
+        *,
+        name: str | None = None,
+        alias: str | None = None,
+        storage_client_cache_key: Hashable = '',
+    ) -> None:
+        """Check for conflicts between named and alias storages."""
+        if alias and (self._cache.by_name[cls][alias].get(storage_client_cache_key)):
+            raise ValueError(
+                f'Cannot create alias storage "{alias}" because a named storage with the same name already exists. '
+                f'Use a different alias or drop the existing named storage first.'
+            )
+
+        if name and (self._cache.by_alias[cls][name].get(storage_client_cache_key)):
+            raise ValueError(
+                f'Cannot create named storage "{name}" because an alias storage with the same name already exists. '
+                f'Use a different name or drop the existing alias storage first.'
+            )
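Design note on the change above: the opener locks live in a `WeakValueDictionary`, so a lock object exists only while some task still holds a reference to it and is garbage-collected once the last opener for that key finishes. The cache lookup and the name/alias conflict check are repeated inside `async with lock` because another task may have created, or attempted to create, the same storage while this one was waiting for the lock.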

tests/unit/storages/test_storage_instance_manager.py

Lines changed: 61 additions & 0 deletions
@@ -1,5 +1,8 @@
+import asyncio
+import sys
 from pathlib import Path
 from typing import cast
+from unittest.mock import AsyncMock

 import pytest

@@ -128,3 +131,61 @@ async def test_preexisting_unnamed_storage_open_by_id(storage_type: type[Storage
     storage_1_again = await storage_type.open(id=storage_1.id, storage_client=storage_client)

     assert storage_1.id == storage_1_again.id
+
+
+@pytest.mark.skipif(sys.version_info[:3] < (3, 11), reason='asyncio.Barrier was introduced in Python 3.11.')
+async def test_concurrent_open_datasets() -> None:
+    """Test that concurrent open datasets with the same name return the same instance."""
+    from asyncio import Barrier  # type:ignore[attr-defined] # noqa: PLC0415
+
+    barrier = Barrier(2)
+
+    async def push_data(data: dict) -> None:
+        await barrier.wait()
+        dataset = await Dataset.open(name='concurrent-storage')
+        await dataset.push_data(data)
+
+    await asyncio.gather(
+        push_data({'test_1': '1'}),
+        push_data({'test_2': '2'}),
+    )
+
+    dataset = await Dataset.open(name='concurrent-storage')
+
+    items = await dataset.get_data()
+    assert len(items.items) == 2
+
+    await dataset.drop()
+
+
+@pytest.mark.skipif(sys.version_info[:3] < (3, 11), reason='asyncio.Barrier was introduced in Python 3.11.')
+async def test_concurrent_open_datasets_with_same_name_and_alias() -> None:
+    """Test that concurrent open requests for the same storage return the same instance."""
+    from asyncio import Barrier  # type:ignore[attr-defined] # noqa: PLC0415
+
+    valid_kwargs: dict[str, str | None] = {}
+
+    exception_calls = AsyncMock()
+
+    barrier = Barrier(2)
+
+    async def open_dataset(name: str | None, alias: str | None) -> None:
+        await barrier.wait()
+        try:
+            await Dataset.open(name=name, alias=alias)
+            valid_kwargs['name'] = name
+            valid_kwargs['alias'] = alias
+        except ValueError:
+            exception_calls()
+
+    await asyncio.gather(
+        open_dataset(name=None, alias='concurrent-storage'),
+        open_dataset(name='concurrent-storage', alias=None),
+    )
+
+    # Ensure that a ValueError was raised due to name/alias conflict
+    exception_calls.assert_called_once()
+
+    dataset = await Dataset.open(name=valid_kwargs.get('name'), alias=valid_kwargs.get('alias'))
+
+    await dataset.drop()
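For illustration, a small usage sketch of what the new behaviour means for callers, assuming the default storage client and configuration; `'my-dataset'` is an arbitrary example name, not something used by the commit.

import asyncio

from crawlee.storages import Dataset


async def main() -> None:
    # Two concurrent opens of the same named dataset now resolve to one cached instance
    # instead of racing to create two separate clients.
    first, second = await asyncio.gather(
        Dataset.open(name='my-dataset'),
        Dataset.open(name='my-dataset'),
    )
    assert first is second

    await first.drop()


asyncio.run(main())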
