Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 32 additions & 8 deletions src/crawlee/_utils/recoverable_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@

from pydantic import BaseModel

from crawlee._utils.raise_if_too_many_kwargs import raise_if_too_many_kwargs
from crawlee.events._types import Event, EventPersistStateData

if TYPE_CHECKING:
import logging
from collections.abc import Callable, Coroutine

from crawlee.storages._key_value_store import KeyValueStore
from crawlee.storages import KeyValueStore

TStateModel = TypeVar('TStateModel', bound=BaseModel)

Expand Down Expand Up @@ -37,6 +39,7 @@ def __init__(
persistence_enabled: Literal[True, False, 'explicit_only'] = False,
persist_state_kvs_name: str | None = None,
persist_state_kvs_id: str | None = None,
persist_state_kvs_factory: Callable[[], Coroutine[None, None, KeyValueStore]] | None = None,
logger: logging.Logger,
) -> None:
"""Initialize a new recoverable state object.
Expand All @@ -51,16 +54,40 @@ def __init__(
If neither a name nor and id are supplied, the default store will be used.
persist_state_kvs_id: The identifier of the KeyValueStore to use for persistence.
If neither a name nor and id are supplied, the default store will be used.
persist_state_kvs_factory: Factory that can be awaited to create KeyValueStore to use for persistence. If
not provided, a system-wide KeyValueStore will be used, based on service locator configuration.
logger: A logger instance for logging operations related to state persistence
"""
raise_if_too_many_kwargs(
persist_state_kvs_name=persist_state_kvs_name,
persist_state_kvs_id=persist_state_kvs_id,
persist_state_kvs_factory=persist_state_kvs_factory,
)
if not persist_state_kvs_factory:
logger.debug(
'No explicit key_value_store set for recoverable state. Recovery will use a system-wide KeyValueStore '
'based on service_locator configuration, potentially calling service_locator.set_storage_client in the '
'process. It is recommended to initialize RecoverableState with explicit key_value_store to avoid '
'global side effects.'
)

self._default_state = default_state
self._state_type: type[TStateModel] = self._default_state.__class__
self._state: TStateModel | None = None
self._persistence_enabled = persistence_enabled
self._persist_state_key = persist_state_key
self._persist_state_kvs_name = persist_state_kvs_name
self._persist_state_kvs_id = persist_state_kvs_id
self._key_value_store: 'KeyValueStore | None' = None # noqa: UP037
if persist_state_kvs_factory is None:

async def kvs_factory() -> KeyValueStore:
from crawlee.storages import KeyValueStore # noqa: PLC0415 avoid circular import

return await KeyValueStore.open(name=persist_state_kvs_name, id=persist_state_kvs_id)

self._persist_state_kvs_factory = kvs_factory
else:
self._persist_state_kvs_factory = persist_state_kvs_factory

self._key_value_store: KeyValueStore | None = None
self._log = logger

async def initialize(self) -> TStateModel:
Expand All @@ -77,11 +104,8 @@ async def initialize(self) -> TStateModel:
return self.current_value

# Import here to avoid circular imports.
from crawlee.storages._key_value_store import KeyValueStore # noqa: PLC0415

self._key_value_store = await KeyValueStore.open(
name=self._persist_state_kvs_name, id=self._persist_state_kvs_id
)
self._key_value_store = await self._persist_state_kvs_factory()

await self._load_saved_state()

Expand Down
7 changes: 6 additions & 1 deletion src/crawlee/statistics/_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
from crawlee.statistics._error_tracker import ErrorTracker

if TYPE_CHECKING:
from collections.abc import Callable, Coroutine
from types import TracebackType

from crawlee.storages import KeyValueStore

TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState)
TNewStatisticsState = TypeVar('TNewStatisticsState', bound=StatisticsState, default=StatisticsState)
logger = getLogger(__name__)
Expand Down Expand Up @@ -70,6 +73,7 @@ def __init__(
persistence_enabled: bool | Literal['explicit_only'] = False,
persist_state_kvs_name: str | None = None,
persist_state_key: str | None = None,
persist_state_kvs_factory: Callable[[], Coroutine[None, None, KeyValueStore]] | None = None,
log_message: str = 'Statistics',
periodic_message_logger: Logger | None = None,
log_interval: timedelta = timedelta(minutes=1),
Expand All @@ -95,6 +99,7 @@ def __init__(
persist_state_key=persist_state_key or f'SDK_CRAWLER_STATISTICS_{self._id}',
persistence_enabled=persistence_enabled,
persist_state_kvs_name=persist_state_kvs_name,
persist_state_kvs_factory=persist_state_kvs_factory,
logger=logger,
)

Expand All @@ -110,8 +115,8 @@ def replace_state_model(self, state_model: type[TNewStatisticsState]) -> Statist
"""Create near copy of the `Statistics` with replaced `state_model`."""
new_statistics: Statistics[TNewStatisticsState] = Statistics(
persistence_enabled=self._state._persistence_enabled, # noqa: SLF001
persist_state_kvs_name=self._state._persist_state_kvs_name, # noqa: SLF001
persist_state_key=self._state._persist_state_key, # noqa: SLF001
persist_state_kvs_factory=self._state._persist_state_kvs_factory, # noqa: SLF001
log_message=self._log_message,
periodic_message_logger=self._periodic_message_logger,
state_model=state_model,
Expand Down
30 changes: 24 additions & 6 deletions src/crawlee/storage_clients/_file_system/_request_queue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from collections.abc import Sequence

from crawlee.configuration import Configuration
from crawlee.storages import KeyValueStore

logger = getLogger(__name__)

Expand Down Expand Up @@ -92,6 +93,7 @@ def __init__(
metadata: RequestQueueMetadata,
path_to_rq: Path,
lock: asyncio.Lock,
recoverable_state: RecoverableState[RequestQueueState],
) -> None:
"""Initialize a new instance.

Expand All @@ -114,12 +116,7 @@ def __init__(
self._is_empty_cache: bool | None = None
"""Cache for is_empty result: None means unknown, True/False is cached state."""

self._state = RecoverableState[RequestQueueState](
default_state=RequestQueueState(),
persist_state_key=f'__RQ_STATE_{self._metadata.id}',
persistence_enabled=True,
logger=logger,
)
self._state = recoverable_state
"""Recoverable state to maintain request ordering, in-progress status, and handled status."""

@override
Expand All @@ -136,6 +133,22 @@ def path_to_metadata(self) -> Path:
"""The full path to the request queue metadata file."""
return self.path_to_rq / METADATA_FILENAME

@classmethod
async def _create_recoverable_state(cls, id: str, configuration: Configuration) -> RecoverableState:
async def kvs_factory() -> KeyValueStore:
from crawlee.storage_clients import FileSystemStorageClient # noqa: PLC0415 avoid circular import
from crawlee.storages import KeyValueStore # noqa: PLC0415 avoid circular import

return await KeyValueStore.open(storage_client=FileSystemStorageClient(), configuration=configuration)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating a fresh filesystem storage client to open a request queue feels wrong - at this point, we can be pretty sure that another one already exists. Is there a specific reason to do this or is it just because you don't have access to the existing one?

Copy link
Collaborator Author

@Pijukatel Pijukatel Oct 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At that point, we are not sure it exists. It could have been created through a class method without the client and even when created with the help of client, it is out of scope:
await FileSystemRequestQueueClient.open(...)

And why not open the KVS through such a class method as well? Because that way, you bypass the storage instance manager - and that is generally something we do not want.

FileSystemStorageClient is just a helper factory class, which is mainly for convenience and for registering the storage instance manager.


return RecoverableState[RequestQueueState](
default_state=RequestQueueState(),
persist_state_key=f'__RQ_STATE_{id}',
persist_state_kvs_factory=kvs_factory,
persistence_enabled=True,
logger=logger,
)

@classmethod
async def open(
cls,
Expand Down Expand Up @@ -194,6 +207,9 @@ async def open(
metadata=metadata,
path_to_rq=rq_base_path / rq_dir,
lock=asyncio.Lock(),
recoverable_state=await cls._create_recoverable_state(
id=id, configuration=configuration
),
)
await client._state.initialize()
await client._discover_existing_requests()
Expand Down Expand Up @@ -230,6 +246,7 @@ async def open(
metadata=metadata,
path_to_rq=path_to_rq,
lock=asyncio.Lock(),
recoverable_state=await cls._create_recoverable_state(id=metadata.id, configuration=configuration),
)

await client._state.initialize()
Expand All @@ -254,6 +271,7 @@ async def open(
metadata=metadata,
path_to_rq=path_to_rq,
lock=asyncio.Lock(),
recoverable_state=await cls._create_recoverable_state(id=metadata.id, configuration=configuration),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of creating it three times, we can create it once, store it in a variable, and just pass it where needed.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That creation requires metadata to get the RQ.id, so we have to repeat this call, as in all three branches, we get metadata in a different way.

)
await client._state.initialize()
await client._update_metadata()
Expand Down
7 changes: 5 additions & 2 deletions src/crawlee/storages/_key_value_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,14 @@ async def get_auto_saved_value(
if key in cache:
return cache[key].current_value.root

async def kvs_factory() -> KeyValueStore:
return self

cache[key] = recoverable_state = RecoverableState(
default_state=AutosavedValue(default_value),
persistence_enabled=True,
persist_state_kvs_id=self.id,
persist_state_key=key,
persistence_enabled=True,
persist_state_kvs_factory=kvs_factory,
logger=logger,
)

Expand Down
5 changes: 5 additions & 0 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from crawlee.fingerprint_suite._browserforge_adapter import get_available_header_network
from crawlee.http_clients import CurlImpersonateHttpClient, HttpxHttpClient, ImpitHttpClient
from crawlee.proxy_configuration import ProxyInfo
from crawlee.statistics import Statistics
from crawlee.storages import KeyValueStore
from tests.unit.server import TestServer, app, serve_in_thread

Expand Down Expand Up @@ -69,6 +70,10 @@ def _prepare_test_env() -> None:
# Verify that the test environment was set up correctly.
assert os.environ.get('CRAWLEE_STORAGE_DIR') == str(tmp_path)

# Reset global class variables to ensure test isolation.
KeyValueStore._autosaved_values = {}
Statistics._Statistics__next_id = 0 # type:ignore[attr-defined] # Mangled attribute
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How does this contribute towards test isolation? Is there anything that depends on the persist state key that is derived from the ID?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The answer is I am not sure. But we have so many tests, so I think it is best if we restore all we can to the same state at the beginning of the test. This reduces the chance of some weird behavior based on the order of the test execution.


return _prepare_test_env


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,6 @@ async def test_adaptive_playwright_crawler_statistics_in_init() -> None:
assert type(crawler._statistics.state) is AdaptivePlaywrightCrawlerStatisticState

assert crawler._statistics._state._persistence_enabled == persistence_enabled
assert crawler._statistics._state._persist_state_kvs_name == persist_state_kvs_name
assert crawler._statistics._state._persist_state_key == persist_state_key

assert crawler._statistics._log_message == log_message
Expand Down
10 changes: 8 additions & 2 deletions tests/unit/storage_clients/_file_system/test_fs_rq_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

import pytest

from crawlee import Request
from crawlee import Request, service_locator
from crawlee.configuration import Configuration
from crawlee.storage_clients import FileSystemStorageClient
from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient

if TYPE_CHECKING:
from collections.abc import AsyncGenerator
Expand Down Expand Up @@ -78,6 +78,12 @@ async def test_request_file_persistence(rq_client: FileSystemRequestQueueClient)
assert request_data['url'].startswith('https://example.com/')


async def test_opening_rq_does_not_have_side_effect_on_service_locator(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add some explanation here? Also, inlining the rq_client fixture could lead to better readable code

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok

rq_client: FileSystemRequestQueueClient, # noqa: ARG001
) -> None:
service_locator.set_storage_client(MemoryStorageClient())


async def test_drop_removes_directory(rq_client: FileSystemRequestQueueClient) -> None:
"""Test that drop removes the entire RQ directory from disk."""
await rq_client.add_batch_of_requests([Request.from_url('https://example.com')])
Expand Down
44 changes: 42 additions & 2 deletions tests/unit/storages/test_key_value_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@

from crawlee import service_locator
from crawlee.configuration import Configuration
from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient, SqlStorageClient, StorageClient
from crawlee.storages import KeyValueStore
from crawlee.storages._storage_instance_manager import StorageInstanceManager

if TYPE_CHECKING:
from collections.abc import AsyncGenerator

from crawlee.storage_clients import StorageClient
from pathlib import Path


@pytest.fixture
Expand Down Expand Up @@ -1063,3 +1063,43 @@ async def test_name_default_not_allowed(storage_client: StorageClient) -> None:
f'it is reserved for default alias.',
):
await KeyValueStore.open(name=StorageInstanceManager._DEFAULT_STORAGE_ALIAS, storage_client=storage_client)


@pytest.mark.parametrize(
'tested_storage_client',
[
pytest.param(MemoryStorageClient(), id='tested=MemoryStorageClient'),
pytest.param(FileSystemStorageClient(), id='tested=FileSystemStorageClient'),
pytest.param(SqlStorageClient(), id='tested=SqlStorageClient'),
],
)
@pytest.mark.parametrize(
'global_storage_client',
[
pytest.param(MemoryStorageClient(), id='global=MemoryStorageClient'),
pytest.param(FileSystemStorageClient(), id='global=FileSystemStorageClient'),
pytest.param(SqlStorageClient(), id='global=SqlStorageClient'),
],
)
async def test_get_auto_saved_value_various_global_clients(
tmp_path: Path, tested_storage_client: StorageClient, global_storage_client: StorageClient
) -> None:
"""Ensure that persistence is working for all clients regardless of what is set in service locator."""
service_locator.set_configuration(
Configuration(
crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg]
purge_on_start=True,
)
)
service_locator.set_storage_client(global_storage_client)

kvs = await KeyValueStore.open(storage_client=tested_storage_client)
values_kvs = {'key': 'some_value'}
test_key = 'test_key'

autosaved_value_kvs = await kvs.get_auto_saved_value(test_key)
assert autosaved_value_kvs == {}
autosaved_value_kvs.update(values_kvs)
await kvs.persist_autosaved_values()

assert await kvs.get_value(test_key) == autosaved_value_kvs
Loading