diff --git a/src/crawlee/_utils/recoverable_state.py b/src/crawlee/_utils/recoverable_state.py index fc115eb403..9916083a72 100644 --- a/src/crawlee/_utils/recoverable_state.py +++ b/src/crawlee/_utils/recoverable_state.py @@ -4,12 +4,14 @@ from pydantic import BaseModel +from crawlee._utils.raise_if_too_many_kwargs import raise_if_too_many_kwargs from crawlee.events._types import Event, EventPersistStateData if TYPE_CHECKING: import logging + from collections.abc import Callable, Coroutine - from crawlee.storages._key_value_store import KeyValueStore + from crawlee.storages import KeyValueStore TStateModel = TypeVar('TStateModel', bound=BaseModel) @@ -37,6 +39,7 @@ def __init__( persistence_enabled: Literal[True, False, 'explicit_only'] = False, persist_state_kvs_name: str | None = None, persist_state_kvs_id: str | None = None, + persist_state_kvs_factory: Callable[[], Coroutine[None, None, KeyValueStore]] | None = None, logger: logging.Logger, ) -> None: """Initialize a new recoverable state object. @@ -51,16 +54,40 @@ def __init__( If neither a name nor and id are supplied, the default store will be used. persist_state_kvs_id: The identifier of the KeyValueStore to use for persistence. If neither a name nor and id are supplied, the default store will be used. + persist_state_kvs_factory: Factory that can be awaited to create KeyValueStore to use for persistence. If + not provided, a system-wide KeyValueStore will be used, based on service locator configuration. logger: A logger instance for logging operations related to state persistence """ + raise_if_too_many_kwargs( + persist_state_kvs_name=persist_state_kvs_name, + persist_state_kvs_id=persist_state_kvs_id, + persist_state_kvs_factory=persist_state_kvs_factory, + ) + if not persist_state_kvs_factory: + logger.debug( + 'No explicit key_value_store set for recoverable state. Recovery will use a system-wide KeyValueStore ' + 'based on service_locator configuration, potentially calling service_locator.set_storage_client in the ' + 'process. It is recommended to initialize RecoverableState with explicit key_value_store to avoid ' + 'global side effects.' + ) + self._default_state = default_state self._state_type: type[TStateModel] = self._default_state.__class__ self._state: TStateModel | None = None self._persistence_enabled = persistence_enabled self._persist_state_key = persist_state_key - self._persist_state_kvs_name = persist_state_kvs_name - self._persist_state_kvs_id = persist_state_kvs_id - self._key_value_store: 'KeyValueStore | None' = None # noqa: UP037 + if persist_state_kvs_factory is None: + + async def kvs_factory() -> KeyValueStore: + from crawlee.storages import KeyValueStore # noqa: PLC0415 avoid circular import + + return await KeyValueStore.open(name=persist_state_kvs_name, id=persist_state_kvs_id) + + self._persist_state_kvs_factory = kvs_factory + else: + self._persist_state_kvs_factory = persist_state_kvs_factory + + self._key_value_store: KeyValueStore | None = None self._log = logger async def initialize(self) -> TStateModel: @@ -77,11 +104,8 @@ async def initialize(self) -> TStateModel: return self.current_value # Import here to avoid circular imports. - from crawlee.storages._key_value_store import KeyValueStore # noqa: PLC0415 - self._key_value_store = await KeyValueStore.open( - name=self._persist_state_kvs_name, id=self._persist_state_kvs_id - ) + self._key_value_store = await self._persist_state_kvs_factory() await self._load_saved_state() diff --git a/src/crawlee/statistics/_statistics.py b/src/crawlee/statistics/_statistics.py index 2386986001..3e95932c2a 100644 --- a/src/crawlee/statistics/_statistics.py +++ b/src/crawlee/statistics/_statistics.py @@ -17,8 +17,11 @@ from crawlee.statistics._error_tracker import ErrorTracker if TYPE_CHECKING: + from collections.abc import Callable, Coroutine from types import TracebackType + from crawlee.storages import KeyValueStore + TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState) TNewStatisticsState = TypeVar('TNewStatisticsState', bound=StatisticsState, default=StatisticsState) logger = getLogger(__name__) @@ -70,6 +73,7 @@ def __init__( persistence_enabled: bool | Literal['explicit_only'] = False, persist_state_kvs_name: str | None = None, persist_state_key: str | None = None, + persist_state_kvs_factory: Callable[[], Coroutine[None, None, KeyValueStore]] | None = None, log_message: str = 'Statistics', periodic_message_logger: Logger | None = None, log_interval: timedelta = timedelta(minutes=1), @@ -95,6 +99,7 @@ def __init__( persist_state_key=persist_state_key or f'SDK_CRAWLER_STATISTICS_{self._id}', persistence_enabled=persistence_enabled, persist_state_kvs_name=persist_state_kvs_name, + persist_state_kvs_factory=persist_state_kvs_factory, logger=logger, ) @@ -110,8 +115,8 @@ def replace_state_model(self, state_model: type[TNewStatisticsState]) -> Statist """Create near copy of the `Statistics` with replaced `state_model`.""" new_statistics: Statistics[TNewStatisticsState] = Statistics( persistence_enabled=self._state._persistence_enabled, # noqa: SLF001 - persist_state_kvs_name=self._state._persist_state_kvs_name, # noqa: SLF001 persist_state_key=self._state._persist_state_key, # noqa: SLF001 + persist_state_kvs_factory=self._state._persist_state_kvs_factory, # noqa: SLF001 log_message=self._log_message, periodic_message_logger=self._periodic_message_logger, state_model=state_model, diff --git a/src/crawlee/storage_clients/_file_system/_request_queue_client.py b/src/crawlee/storage_clients/_file_system/_request_queue_client.py index 426d78d3e1..b831a4cfb4 100644 --- a/src/crawlee/storage_clients/_file_system/_request_queue_client.py +++ b/src/crawlee/storage_clients/_file_system/_request_queue_client.py @@ -31,6 +31,7 @@ from collections.abc import Sequence from crawlee.configuration import Configuration + from crawlee.storages import KeyValueStore logger = getLogger(__name__) @@ -92,6 +93,7 @@ def __init__( metadata: RequestQueueMetadata, path_to_rq: Path, lock: asyncio.Lock, + recoverable_state: RecoverableState[RequestQueueState], ) -> None: """Initialize a new instance. @@ -114,12 +116,7 @@ def __init__( self._is_empty_cache: bool | None = None """Cache for is_empty result: None means unknown, True/False is cached state.""" - self._state = RecoverableState[RequestQueueState]( - default_state=RequestQueueState(), - persist_state_key=f'__RQ_STATE_{self._metadata.id}', - persistence_enabled=True, - logger=logger, - ) + self._state = recoverable_state """Recoverable state to maintain request ordering, in-progress status, and handled status.""" @override @@ -136,6 +133,22 @@ def path_to_metadata(self) -> Path: """The full path to the request queue metadata file.""" return self.path_to_rq / METADATA_FILENAME + @classmethod + async def _create_recoverable_state(cls, id: str, configuration: Configuration) -> RecoverableState: + async def kvs_factory() -> KeyValueStore: + from crawlee.storage_clients import FileSystemStorageClient # noqa: PLC0415 avoid circular import + from crawlee.storages import KeyValueStore # noqa: PLC0415 avoid circular import + + return await KeyValueStore.open(storage_client=FileSystemStorageClient(), configuration=configuration) + + return RecoverableState[RequestQueueState]( + default_state=RequestQueueState(), + persist_state_key=f'__RQ_STATE_{id}', + persist_state_kvs_factory=kvs_factory, + persistence_enabled=True, + logger=logger, + ) + @classmethod async def open( cls, @@ -194,6 +207,9 @@ async def open( metadata=metadata, path_to_rq=rq_base_path / rq_dir, lock=asyncio.Lock(), + recoverable_state=await cls._create_recoverable_state( + id=id, configuration=configuration + ), ) await client._state.initialize() await client._discover_existing_requests() @@ -230,6 +246,7 @@ async def open( metadata=metadata, path_to_rq=path_to_rq, lock=asyncio.Lock(), + recoverable_state=await cls._create_recoverable_state(id=metadata.id, configuration=configuration), ) await client._state.initialize() @@ -254,6 +271,7 @@ async def open( metadata=metadata, path_to_rq=path_to_rq, lock=asyncio.Lock(), + recoverable_state=await cls._create_recoverable_state(id=metadata.id, configuration=configuration), ) await client._state.initialize() await client._update_metadata() diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py index 3260e4f91e..96bac8f7b3 100644 --- a/src/crawlee/storages/_key_value_store.py +++ b/src/crawlee/storages/_key_value_store.py @@ -278,11 +278,14 @@ async def get_auto_saved_value( if key in cache: return cache[key].current_value.root + async def kvs_factory() -> KeyValueStore: + return self + cache[key] = recoverable_state = RecoverableState( default_state=AutosavedValue(default_value), - persistence_enabled=True, - persist_state_kvs_id=self.id, persist_state_key=key, + persistence_enabled=True, + persist_state_kvs_factory=kvs_factory, logger=logger, ) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index cc2d3e8769..9956167931 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -17,6 +17,7 @@ from crawlee.fingerprint_suite._browserforge_adapter import get_available_header_network from crawlee.http_clients import CurlImpersonateHttpClient, HttpxHttpClient, ImpitHttpClient from crawlee.proxy_configuration import ProxyInfo +from crawlee.statistics import Statistics from crawlee.storages import KeyValueStore from tests.unit.server import TestServer, app, serve_in_thread @@ -69,6 +70,10 @@ def _prepare_test_env() -> None: # Verify that the test environment was set up correctly. assert os.environ.get('CRAWLEE_STORAGE_DIR') == str(tmp_path) + # Reset global class variables to ensure test isolation. + KeyValueStore._autosaved_values = {} + Statistics._Statistics__next_id = 0 # type:ignore[attr-defined] # Mangled attribute + return _prepare_test_env diff --git a/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py b/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py index 5fdb621718..5a9070e9ca 100644 --- a/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py +++ b/tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py @@ -493,7 +493,6 @@ async def test_adaptive_playwright_crawler_statistics_in_init() -> None: assert type(crawler._statistics.state) is AdaptivePlaywrightCrawlerStatisticState assert crawler._statistics._state._persistence_enabled == persistence_enabled - assert crawler._statistics._state._persist_state_kvs_name == persist_state_kvs_name assert crawler._statistics._state._persist_state_key == persist_state_key assert crawler._statistics._log_message == log_message diff --git a/tests/unit/storage_clients/_file_system/test_fs_rq_client.py b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py index dc2937a259..6dc8c837d5 100644 --- a/tests/unit/storage_clients/_file_system/test_fs_rq_client.py +++ b/tests/unit/storage_clients/_file_system/test_fs_rq_client.py @@ -6,9 +6,9 @@ import pytest -from crawlee import Request +from crawlee import Request, service_locator from crawlee.configuration import Configuration -from crawlee.storage_clients import FileSystemStorageClient +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -78,6 +78,14 @@ async def test_request_file_persistence(rq_client: FileSystemRequestQueueClient) assert request_data['url'].startswith('https://example.com/') +async def test_opening_rq_does_not_have_side_effect_on_service_locator(configuration: Configuration) -> None: + """Opening request queue client should cause setting storage client in the global service locator.""" + await FileSystemStorageClient().create_rq_client(name='test_request_queue', configuration=configuration) + + # Set some specific storage client in the service locator. There should be no `ServiceConflictError`. + service_locator.set_storage_client(MemoryStorageClient()) + + async def test_drop_removes_directory(rq_client: FileSystemRequestQueueClient) -> None: """Test that drop removes the entire RQ directory from disk.""" await rq_client.add_batch_of_requests([Request.from_url('https://example.com')]) diff --git a/tests/unit/storages/test_key_value_store.py b/tests/unit/storages/test_key_value_store.py index 21cdf6ad1b..f02d5a72cf 100644 --- a/tests/unit/storages/test_key_value_store.py +++ b/tests/unit/storages/test_key_value_store.py @@ -10,13 +10,13 @@ from crawlee import service_locator from crawlee.configuration import Configuration +from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient, SqlStorageClient, StorageClient from crawlee.storages import KeyValueStore from crawlee.storages._storage_instance_manager import StorageInstanceManager if TYPE_CHECKING: from collections.abc import AsyncGenerator - - from crawlee.storage_clients import StorageClient + from pathlib import Path @pytest.fixture @@ -1063,3 +1063,43 @@ async def test_name_default_not_allowed(storage_client: StorageClient) -> None: f'it is reserved for default alias.', ): await KeyValueStore.open(name=StorageInstanceManager._DEFAULT_STORAGE_ALIAS, storage_client=storage_client) + + +@pytest.mark.parametrize( + 'tested_storage_client', + [ + pytest.param(MemoryStorageClient(), id='tested=MemoryStorageClient'), + pytest.param(FileSystemStorageClient(), id='tested=FileSystemStorageClient'), + pytest.param(SqlStorageClient(), id='tested=SqlStorageClient'), + ], +) +@pytest.mark.parametrize( + 'global_storage_client', + [ + pytest.param(MemoryStorageClient(), id='global=MemoryStorageClient'), + pytest.param(FileSystemStorageClient(), id='global=FileSystemStorageClient'), + pytest.param(SqlStorageClient(), id='global=SqlStorageClient'), + ], +) +async def test_get_auto_saved_value_various_global_clients( + tmp_path: Path, tested_storage_client: StorageClient, global_storage_client: StorageClient +) -> None: + """Ensure that persistence is working for all clients regardless of what is set in service locator.""" + service_locator.set_configuration( + Configuration( + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + purge_on_start=True, + ) + ) + service_locator.set_storage_client(global_storage_client) + + kvs = await KeyValueStore.open(storage_client=tested_storage_client) + values_kvs = {'key': 'some_value'} + test_key = 'test_key' + + autosaved_value_kvs = await kvs.get_auto_saved_value(test_key) + assert autosaved_value_kvs == {} + autosaved_value_kvs.update(values_kvs) + await kvs.persist_autosaved_values() + + assert await kvs.get_value(test_key) == autosaved_value_kvs