diff --git a/src/crawlee/_utils/recoverable_state.py b/src/crawlee/_utils/recoverable_state.py index fc115eb403..f01b8e3456 100644 --- a/src/crawlee/_utils/recoverable_state.py +++ b/src/crawlee/_utils/recoverable_state.py @@ -38,6 +38,7 @@ def __init__( persist_state_kvs_name: str | None = None, persist_state_kvs_id: str | None = None, logger: logging.Logger, + key_value_store: KeyValueStore | None = None, ) -> None: """Initialize a new recoverable state object. @@ -52,7 +53,14 @@ def __init__( persist_state_kvs_id: The identifier of the KeyValueStore to use for persistence. If neither a name nor and id are supplied, the default store will be used. logger: A logger instance for logging operations related to state persistence + key_value_store: KeyValueStore to use for persistence. If not provided, a system-wide KeyValueStore will be + used, based on service locator configuration. Mutually exclusive with the name and id options. """ + if key_value_store is not None and (persist_state_kvs_name or persist_state_kvs_id): + raise ValueError( + 'Cannot provide explicit key_value_store and persist_state_kvs_name or persist_state_kvs_id.' + ) + self._default_state = default_state self._state_type: type[TStateModel] = self._default_state.__class__ self._state: TStateModel | None = None @@ -60,8 +68,8 @@ def __init__( self._persist_state_key = persist_state_key self._persist_state_kvs_name = persist_state_kvs_name self._persist_state_kvs_id = persist_state_kvs_id - self._key_value_store: 'KeyValueStore | None' = None # noqa: UP037 self._log = logger + self._key_value_store = key_value_store async def initialize(self) -> TStateModel: """Initialize the recoverable state. @@ -79,9 +87,10 @@ async def initialize(self) -> TStateModel: # Import here to avoid circular imports. 
from crawlee.storages._key_value_store import KeyValueStore # noqa: PLC0415 - self._key_value_store = await KeyValueStore.open( - name=self._persist_state_kvs_name, id=self._persist_state_kvs_id - ) + if self._key_value_store is None: + self._key_value_store = await KeyValueStore.open( + name=self._persist_state_kvs_name, id=self._persist_state_kvs_id + ) await self._load_saved_state() diff --git a/src/crawlee/storage_clients/_file_system/_request_queue_client.py b/src/crawlee/storage_clients/_file_system/_request_queue_client.py index f5e0165d68..2bed1188e5 100644 --- a/src/crawlee/storage_clients/_file_system/_request_queue_client.py +++ b/src/crawlee/storage_clients/_file_system/_request_queue_client.py @@ -90,6 +90,7 @@ def __init__( metadata: RequestQueueMetadata, storage_dir: Path, lock: asyncio.Lock, + recoverable_state: RecoverableState[RequestQueueState], ) -> None: """Initialize a new instance. @@ -112,13 +113,7 @@ def __init__( self._is_empty_cache: bool | None = None """Cache for is_empty result: None means unknown, True/False is cached state.""" - self._state = RecoverableState[RequestQueueState]( - default_state=RequestQueueState(), - persist_state_key='request_queue_state', - persistence_enabled=True, - persist_state_kvs_name=f'__RQ_STATE_{self._metadata.id}', - logger=logger, - ) + self._state = recoverable_state """Recoverable state to maintain request ordering, in-progress status, and handled status.""" @override @@ -187,14 +182,9 @@ async def open( metadata = RequestQueueMetadata(**file_content) if metadata.id == id: - client = cls( - metadata=metadata, - storage_dir=storage_dir, - lock=asyncio.Lock(), + client = await cls._create_client( + metadata=metadata, storage_dir=storage_dir, update_accessed_at=True ) - await client._state.initialize() - await client._discover_existing_requests() - await client._update_metadata(update_accessed_at=True) found = True break finally: @@ -224,15 +214,7 @@ async def open( metadata.name = name - client = cls( - 
metadata=metadata, - storage_dir=storage_dir, - lock=asyncio.Lock(), - ) - - await client._state.initialize() - await client._discover_existing_requests() - await client._update_metadata(update_accessed_at=True) + client = await cls._create_client(metadata=metadata, storage_dir=storage_dir, update_accessed_at=True) # Otherwise, create a new dataset client. else: @@ -248,13 +230,40 @@ async def open( pending_request_count=0, total_request_count=0, ) - client = cls( - metadata=metadata, - storage_dir=storage_dir, - lock=asyncio.Lock(), - ) - await client._state.initialize() - await client._update_metadata() + client = await cls._create_client(metadata=metadata, storage_dir=storage_dir) + + return client + + @classmethod + async def _create_client( + cls, metadata: RequestQueueMetadata, storage_dir: Path, *, update_accessed_at: bool = False + ) -> FileSystemRequestQueueClient: + """Create and initialize a client whose recoverable state is persisted in a dedicated file-system KVS.""" + from crawlee.storage_clients import FileSystemStorageClient # noqa: PLC0415 avoid circular imports + from crawlee.storages._key_value_store import KeyValueStore # noqa: PLC0415 avoid circular imports + + # Explicit file-system KVS for state (not the service-locator one). NOTE(review): no configuration is passed. + kvs_client = await FileSystemStorageClient().create_kvs_client(name=f'__RQ_STATE_{metadata.id}') + kvs_client_metadata = await kvs_client.get_metadata() + kvs = KeyValueStore(client=kvs_client, id=kvs_client_metadata.id, name=kvs_client_metadata.name) + + # Create the recoverable state bound to that KVS + recoverable_state = RecoverableState[RequestQueueState]( + default_state=RequestQueueState(), + persist_state_key='request_queue_state', + persistence_enabled=True, + logger=logger, + key_value_store=kvs, + ) + + # Create the client, then load persisted state and existing requests before updating metadata + client = cls( + metadata=metadata, storage_dir=storage_dir, lock=asyncio.Lock(), recoverable_state=recoverable_state + ) + + await client._state.initialize() + await client._discover_existing_requests() + await client._update_metadata(update_accessed_at=update_accessed_at) return client diff --git 
a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index 5297925a37..66e5bc666b 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -12,7 +12,6 @@ from crawlee._types import JsonSerializable # noqa: TC001 from crawlee._utils.docs import docs_group from crawlee._utils.recoverable_state import RecoverableState -from crawlee.storage_clients.models import KeyValueStoreMetadata from ._base import Storage @@ -23,8 +22,7 @@ from crawlee.storage_clients import StorageClient from crawlee.storage_clients._base import KeyValueStoreClient from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecordMetadata -else: - from crawlee._utils.recoverable_state import RecoverableState + T = TypeVar('T') @@ -274,9 +272,9 @@ async def get_auto_saved_value( cache[key] = recoverable_state = RecoverableState( default_state=AutosavedValue(default_value), persistence_enabled=True, - persist_state_kvs_id=self.id, persist_state_key=key, logger=logger, + key_value_store=self, # Use self for RecoverableState. 
) await recoverable_state.initialize() diff --git a/src/crawlee/storages/_storage_instance_manager.py b/src/crawlee/storages/_storage_instance_manager.py index 130a2eec63..a9860135a2 100644 --- a/src/crawlee/storages/_storage_instance_manager.py +++ b/src/crawlee/storages/_storage_instance_manager.py @@ -1,15 +1,17 @@ from __future__ import annotations +from collections import defaultdict from collections.abc import Awaitable, Callable +from dataclasses import dataclass, field from typing import TYPE_CHECKING, TypeVar, cast from crawlee.storage_clients._base import DatasetClient, KeyValueStoreClient, RequestQueueClient -from ._base import Storage - if TYPE_CHECKING: from crawlee.configuration import Configuration + from ._base import Storage + T = TypeVar('T', bound='Storage') StorageClientType = DatasetClient | KeyValueStoreClient | RequestQueueClient @@ -19,6 +21,22 @@ """Type alias for the client opener function.""" +@dataclass +class _StorageClientCache: + """Cached storage instances that belong to one specific storage client.""" + + by_id: defaultdict[type[Storage], dict[str, Storage]] = field( + default_factory=lambda: defaultdict(dict) + ) + """Cache for storage instances by ID, separated by storage type.""" + by_name: defaultdict[type[Storage], dict[str, Storage]] = field( + default_factory=lambda: defaultdict(dict) + ) + """Cache for storage instances by name, separated by storage type.""" + default_instances: dict[type[Storage], Storage] = field(default_factory=dict) + """Cache for default instances of each storage type.""" + + class StorageInstanceManager: """Manager for caching and managing storage instances. 
@@ -27,14 +45,7 @@ class StorageInstanceManager: """ def __init__(self) -> None: - self._cache_by_id = dict[type[Storage], dict[str, Storage]]() - """Cache for storage instances by ID, separated by storage type.""" - - self._cache_by_name = dict[type[Storage], dict[str, Storage]]() - """Cache for storage instances by name, separated by storage type.""" - - self._default_instances = dict[type[Storage], Storage]() - """Cache for default instances of each storage type.""" + self._cache_by_storage_client: defaultdict[str, _StorageClientCache] = defaultdict(_StorageClientCache) async def open_storage_instance( self, @@ -64,19 +75,23 @@ async def open_storage_instance( raise ValueError('Only one of "id" or "name" can be specified, not both.') # Check for default instance - if id is None and name is None and cls in self._default_instances: - return cast('T', self._default_instances[cls]) + if ( + id is None + and name is None + and cls in self._cache_by_storage_client[client_opener.__qualname__].default_instances + ): + return cast('T', self._cache_by_storage_client[client_opener.__qualname__].default_instances[cls]) # Check cache if id is not None: - type_cache_by_id = self._cache_by_id.get(cls, {}) + type_cache_by_id = self._cache_by_storage_client[client_opener.__qualname__].by_id[cls] if id in type_cache_by_id: cached_instance = type_cache_by_id[id] if isinstance(cached_instance, cls): return cached_instance if name is not None: - type_cache_by_name = self._cache_by_name.get(cls, {}) + type_cache_by_name = self._cache_by_storage_client[client_opener.__qualname__].by_name[cls] if name in type_cache_by_name: cached_instance = type_cache_by_name[name] if isinstance(cached_instance, cls): @@ -90,16 +105,13 @@ async def open_storage_instance( instance_name = getattr(instance, 'name', None) # Cache the instance - type_cache_by_id = self._cache_by_id.setdefault(cls, {}) - type_cache_by_name = self._cache_by_name.setdefault(cls, {}) - - type_cache_by_id[instance.id] = instance + 
self._cache_by_storage_client[client_opener.__qualname__].by_id[cls][instance.id] = instance if instance_name is not None: - type_cache_by_name[instance_name] = instance + self._cache_by_storage_client[client_opener.__qualname__].by_name[cls][instance_name] = instance # Set as default if no id/name specified if id is None and name is None: - self._default_instances[cls] = instance + self._cache_by_storage_client[client_opener.__qualname__].default_instances[cls] = instance return instance @@ -112,22 +124,23 @@ def remove_from_cache(self, storage_instance: Storage) -> None: storage_type = type(storage_instance) # Remove from ID cache - type_cache_by_id = self._cache_by_id.get(storage_type, {}) - if storage_instance.id in type_cache_by_id: - del type_cache_by_id[storage_instance.id] - - # Remove from name cache - if storage_instance.name is not None: - type_cache_by_name = self._cache_by_name.get(storage_type, {}) - if storage_instance.name in type_cache_by_name: + for client_cache in self._cache_by_storage_client.values(): + type_cache_by_id = client_cache.by_id.get(storage_type, {}) + if storage_instance.id in type_cache_by_id: + del type_cache_by_id[storage_instance.id] + + # Remove from name cache + type_cache_by_name = client_cache.by_name.get(storage_type, {}) + if storage_instance.name is not None and storage_instance.name in type_cache_by_name: del type_cache_by_name[storage_instance.name] - # Remove from default instances - if storage_type in self._default_instances and self._default_instances[storage_type] is storage_instance: - del self._default_instances[storage_type] + # Remove from default instances + if ( + storage_type in client_cache.default_instances + and client_cache.default_instances[storage_type] is storage_instance + ): + del client_cache.default_instances[storage_type] def clear_cache(self) -> None: """Clear all cached storage instances.""" - self._cache_by_id.clear() - self._cache_by_name.clear() - self._default_instances.clear() + self._cache_by_storage_client = 
defaultdict(_StorageClientCache) diff --git a/tests/unit/storage_clients/_file_system/test_fs_rq_client.py b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py index 0be182fcd8..9687f2df5e 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_rq_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py @@ -7,8 +7,9 @@ import pytest from crawlee import Request +from crawlee._service_locator import service_locator from crawlee.configuration import Configuration -from crawlee.storage_clients import FileSystemStorageClient +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -55,7 +56,10 @@ async def test_file_and_directory_creation(configuration: Configuration) -> None await client.drop() -async def test_request_file_persistence(rq_client: FileSystemRequestQueueClient) -> None: +@pytest.mark.parametrize('set_different_storage_client_in_service_locator', [True, False]) +async def test_request_file_persistence( + rq_client: FileSystemRequestQueueClient, *, set_different_storage_client_in_service_locator: bool +) -> None: """Test that requests are properly persisted to files.""" requests = [ Request.from_url('https://example.com/1'), @@ -63,6 +67,9 @@ async def test_request_file_persistence(rq_client: FileSystemRequestQueueClient) Request.from_url('https://example.com/3'), ] + if set_different_storage_client_in_service_locator: + service_locator.set_storage_client(MemoryStorageClient()) + await rq_client.add_batch_of_requests(requests) # Verify request files are created diff --git a/tests/unit/storages/test_key_value_store.py b/tests/unit/storages/test_key_value_store.py index 25bbcb4fc0..a0d0e9b308 100644 --- a/tests/unit/storages/test_key_value_store.py +++ b/tests/unit/storages/test_key_value_store.py @@ -598,3 +598,30 @@ async def test_record_exists_after_purge(kvs: KeyValueStore) -> None: # Should no longer exist assert await 
kvs.record_exists('key1') is False assert await kvs.record_exists('key2') is False + + +async def test_get_auto_saved_value_with_multiple_storage_clients(tmp_path: Path) -> None: + """Test that setting storage client through service locator does not break autosaved values in other clients.""" + config = Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + purge_on_start=True, + ) + + kvs1 = await KeyValueStore.open(storage_client=MemoryStorageClient(), configuration=config) + + kvs2 = await KeyValueStore.open( + storage_client=FileSystemStorageClient(), + configuration=config, + ) + assert kvs1 is not kvs2 + + expected_values = {'key': 'value'} + test_key = 'test_key' + + autosaved_value = await kvs2.get_auto_saved_value(test_key) + assert autosaved_value == {} + autosaved_value.update(expected_values) + + await kvs2.persist_autosaved_values() + + assert await kvs2.get_value(test_key) == expected_values diff --git a/tests/unit/storages/test_storage_instance_manager.py b/tests/unit/storages/test_storage_instance_manager.py new file mode 100644 index 0000000000..eaa168df34 --- /dev/null +++ b/tests/unit/storages/test_storage_instance_manager.py @@ -0,0 +1,96 @@ +from pathlib import Path + +from crawlee import service_locator +from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient +from crawlee.storages import Dataset, KeyValueStore + + +async def test_unique_storage_by_storage_client(tmp_path: Path) -> None: + config = Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + purge_on_start=True, + ) + + kvs1 = await KeyValueStore.open(storage_client=MemoryStorageClient(), configuration=config) + + kvs2 = await KeyValueStore.open( + storage_client=FileSystemStorageClient(), + configuration=config, + ) + assert kvs1 is not kvs2 + + +async def test_unique_storage_by_storage_type(tmp_path: Path) -> None: + config = Configuration( + 
crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + purge_on_start=True, + ) + + kvs = await KeyValueStore.open(configuration=config) + dataset = await Dataset.open(configuration=config) + assert kvs is not dataset + + +async def test_unique_storage_by_name(tmp_path: Path) -> None: + """Test that StorageInstanceManager keeps storages with different names (and storage clients) separate.""" + config = Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + purge_on_start=True, + ) + + kvs1 = await KeyValueStore.open(configuration=config, name='kvs1') + kvs2 = await KeyValueStore.open(storage_client=FileSystemStorageClient(), configuration=config, name='kvs2') + assert kvs1 is not kvs2 + + +async def test_identical_storage(tmp_path: Path) -> None: + """Test that StorageInstanceManager correctly caches storage based on the storage client.""" + config = Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + purge_on_start=True, + ) + + kvs1 = await KeyValueStore.open(storage_client=MemoryStorageClient(), configuration=config) + + kvs2 = await KeyValueStore.open( + storage_client=MemoryStorageClient(), + configuration=config, + ) + assert kvs1 is kvs2 + + +async def test_identical_storage_clear_cache(tmp_path: Path) -> None: + config = Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + purge_on_start=True, + ) + + kvs1 = await KeyValueStore.open(storage_client=MemoryStorageClient(), configuration=config) + + # Clearing cache, so expect different instances + service_locator.storage_instance_manager.clear_cache() + + kvs2 = await KeyValueStore.open( + storage_client=MemoryStorageClient(), + configuration=config, + ) + assert kvs1 is not kvs2 + + +async def test_identical_storage_remove_from_cache(tmp_path: Path) -> None: + config = Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + purge_on_start=True, + ) + + kvs1 = await 
KeyValueStore.open(storage_client=MemoryStorageClient(), configuration=config) + + # Removing from cache, so expect different instances + service_locator.storage_instance_manager.remove_from_cache(kvs1) + + kvs2 = await KeyValueStore.open( + storage_client=MemoryStorageClient(), + configuration=config, + ) + assert kvs1 is not kvs2