Skip to content

Commit ac9c95f

Browse files
committed
Draft of minimizing side effects
Explicit kvs to RecoverableState
1 parent 0c5f7ca commit ac9c95f

File tree

6 files changed

+85
-20
lines changed

6 files changed

+85
-20
lines changed

src/crawlee/_utils/recoverable_state.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from pydantic import BaseModel
66

7+
from crawlee._utils.raise_if_too_many_kwargs import raise_if_too_many_kwargs
78
from crawlee.events._types import Event, EventPersistStateData
89

910
if TYPE_CHECKING:
@@ -37,6 +38,7 @@ def __init__(
3738
persistence_enabled: Literal[True, False, 'explicit_only'] = False,
3839
persist_state_kvs_name: str | None = None,
3940
persist_state_kvs_id: str | None = None,
41+
persist_state_kvs: KeyValueStore | None = None,
4042
logger: logging.Logger,
4143
) -> None:
4244
"""Initialize a new recoverable state object.
@@ -51,16 +53,28 @@ def __init__(
5153
If neither a name nor an id is supplied, the default store will be used.
5254
persist_state_kvs_id: The identifier of the KeyValueStore to use for persistence.
5355
If neither a name nor an id is supplied, the default store will be used.
56+
persist_state_kvs: KeyValueStore to use for persistence. If not provided, a system-wide KeyValueStore will
57+
be used, based on service locator configuration.
5458
logger: A logger instance for logging operations related to state persistence.
5559
"""
60+
raise_if_too_many_kwargs(persist_state_kvs_name=persist_state_kvs_name,
61+
persist_state_kvs_id=persist_state_kvs_id,
62+
key_value_store=persist_state_kvs)
63+
if not persist_state_kvs:
64+
logger.debug(
65+
'No explicit key_value_store set for recoverable state. Recovery will use a system-wide KeyValueStore '
66+
'based on service_locator configuration, potentially calling service_locator.set_storage_client in the '
67+
'process. It is recommended to initialize RecoverableState with explicit key_value_store to avoid '
68+
'global side effects.')
69+
5670
self._default_state = default_state
5771
self._state_type: type[TStateModel] = self._default_state.__class__
5872
self._state: TStateModel | None = None
5973
self._persistence_enabled = persistence_enabled
6074
self._persist_state_key = persist_state_key
6175
self._persist_state_kvs_name = persist_state_kvs_name
6276
self._persist_state_kvs_id = persist_state_kvs_id
63-
self._key_value_store: 'KeyValueStore | None' = None # noqa: UP037
77+
self._key_value_store: KeyValueStore | None = persist_state_kvs
6478
self._log = logger
6579

6680
async def initialize(self) -> TStateModel:
@@ -79,9 +93,10 @@ async def initialize(self) -> TStateModel:
7993
# Import here to avoid circular imports.
8094
from crawlee.storages._key_value_store import KeyValueStore # noqa: PLC0415
8195

82-
self._key_value_store = await KeyValueStore.open(
83-
name=self._persist_state_kvs_name, id=self._persist_state_kvs_id
84-
)
96+
if not self._key_value_store:
97+
self._key_value_store = await KeyValueStore.open(
98+
name=self._persist_state_kvs_name, id=self._persist_state_kvs_id
99+
)
85100

86101
await self._load_saved_state()
87102

src/crawlee/storage_clients/_file_system/_request_queue_client.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def __init__(
9292
metadata: RequestQueueMetadata,
9393
path_to_rq: Path,
9494
lock: asyncio.Lock,
95+
recoverable_state: RecoverableState[RequestQueueState],
9596
) -> None:
9697
"""Initialize a new instance.
9798
@@ -114,12 +115,7 @@ def __init__(
114115
self._is_empty_cache: bool | None = None
115116
"""Cache for is_empty result: None means unknown, True/False is cached state."""
116117

117-
self._state = RecoverableState[RequestQueueState](
118-
default_state=RequestQueueState(),
119-
persist_state_key=f'__RQ_STATE_{self._metadata.id}',
120-
persistence_enabled=True,
121-
logger=logger,
122-
)
118+
self._state = recoverable_state
123119
"""Recoverable state to maintain request ordering, in-progress status, and handled status."""
124120

125121
@override
@@ -136,6 +132,19 @@ def path_to_metadata(self) -> Path:
136132
"""The full path to the request queue metadata file."""
137133
return self.path_to_rq / METADATA_FILENAME
138134

135+
@classmethod
136+
async def _create_recoverable_state(cls, id: str, configuration: Configuration) -> RecoverableState:
137+
from crawlee.storage_clients import FileSystemStorageClient
138+
from crawlee.storages import KeyValueStore
139+
kvs = await KeyValueStore.open(storage_client=FileSystemStorageClient(),configuration=configuration)
140+
return RecoverableState[RequestQueueState](
141+
default_state=RequestQueueState(),
142+
persist_state_key=f'__RQ_STATE_{id}',
143+
persist_state_kvs=kvs,
144+
persistence_enabled=True,
145+
logger=logger,
146+
)
147+
139148
@classmethod
140149
async def open(
141150
cls,
@@ -194,6 +203,9 @@ async def open(
194203
metadata=metadata,
195204
path_to_rq=rq_base_path / rq_dir,
196205
lock=asyncio.Lock(),
206+
recoverable_state=await cls._create_recoverable_state(id=id,
207+
configuration=configuration),
208+
197209
)
198210
await client._state.initialize()
199211
await client._discover_existing_requests()
@@ -230,6 +242,7 @@ async def open(
230242
metadata=metadata,
231243
path_to_rq=path_to_rq,
232244
lock=asyncio.Lock(),
245+
recoverable_state=await cls._create_recoverable_state(id=metadata.id, configuration=configuration),
233246
)
234247

235248
await client._state.initialize()
@@ -254,6 +267,7 @@ async def open(
254267
metadata=metadata,
255268
path_to_rq=path_to_rq,
256269
lock=asyncio.Lock(),
270+
recoverable_state=await cls._create_recoverable_state(id=metadata.id, configuration=configuration),
257271
)
258272
await client._state.initialize()
259273
await client._update_metadata()

src/crawlee/storages/_key_value_store.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -278,13 +278,9 @@ async def get_auto_saved_value(
278278
if key in cache:
279279
return cache[key].current_value.root
280280

281-
cache[key] = recoverable_state = RecoverableState(
282-
default_state=AutosavedValue(default_value),
283-
persistence_enabled=True,
284-
persist_state_kvs_id=self.id,
285-
persist_state_key=key,
286-
logger=logger,
287-
)
281+
cache[key] = recoverable_state = RecoverableState(default_state=AutosavedValue(default_value),
282+
persist_state_key=key, persistence_enabled=True,
283+
persist_state_kvs=self, logger=logger)
288284

289285
await recoverable_state.initialize()
290286

tests/unit/conftest.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,9 @@ def _prepare_test_env() -> None:
6969
# Verify that the test environment was set up correctly.
7070
assert os.environ.get('CRAWLEE_STORAGE_DIR') == str(tmp_path)
7171

72+
# Clear global cache of autosaved values
73+
KeyValueStore._autosaved_values = {}
74+
7275
return _prepare_test_env
7376

7477

tests/unit/storage_clients/_file_system/test_fs_rq_client.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66

77
import pytest
88

9-
from crawlee import Request
9+
from crawlee import Request, service_locator
1010
from crawlee.configuration import Configuration
11-
from crawlee.storage_clients import FileSystemStorageClient
11+
from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient
1212

1313
if TYPE_CHECKING:
1414
from collections.abc import AsyncGenerator
@@ -77,6 +77,11 @@ async def test_request_file_persistence(rq_client: FileSystemRequestQueueClient)
7777
assert 'url' in request_data
7878
assert request_data['url'].startswith('https://example.com/')
7979

80+
async def test_opening_rq_does_not_have_side_effect_on_service_locator(
81+
rq_client: FileSystemRequestQueueClient # noqa: ARG001
82+
) -> None:
83+
service_locator.set_storage_client(MemoryStorageClient())
84+
8085

8186
async def test_drop_removes_directory(rq_client: FileSystemRequestQueueClient) -> None:
8287
"""Test that drop removes the entire RQ directory from disk."""

tests/unit/storages/test_key_value_store.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@
1010

1111
from crawlee import service_locator
1212
from crawlee.configuration import Configuration
13+
from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient, SqlStorageClient, StorageClient
1314
from crawlee.storages import KeyValueStore
1415
from crawlee.storages._storage_instance_manager import StorageInstanceManager
1516

1617
if TYPE_CHECKING:
1718
from collections.abc import AsyncGenerator
19+
from pathlib import Path
20+
1821

19-
from crawlee.storage_clients import StorageClient
2022

2123

2224
@pytest.fixture
@@ -1063,3 +1065,33 @@ async def test_name_default_not_allowed(storage_client: StorageClient) -> None:
10631065
f'it is reserved for default alias.',
10641066
):
10651067
await KeyValueStore.open(name=StorageInstanceManager._DEFAULT_STORAGE_ALIAS, storage_client=storage_client)
1068+
1069+
@pytest.mark.parametrize('tested_storage_client', [
1070+
pytest.param(MemoryStorageClient(), id='tested=MemoryStorageClient'),
1071+
pytest.param(FileSystemStorageClient(), id='tested=FileSystemStorageClient'),
1072+
pytest.param(SqlStorageClient(), id='tested=SqlStorageClient'),
1073+
])
1074+
@pytest.mark.parametrize('global_storage_client', [
1075+
pytest.param(MemoryStorageClient(), id='global=MemoryStorageClient'),
1076+
pytest.param(FileSystemStorageClient(), id='global=FileSystemStorageClient'),
1077+
pytest.param(SqlStorageClient(), id='global=SqlStorageClient'),
1078+
])
1079+
async def test_get_auto_saved_value_various_global_clients(tmp_path: Path, tested_storage_client: StorageClient,
1080+
global_storage_client:StorageClient) -> None:
1081+
"""Ensure that persistence is working for all clients regardless of what is set in the service locator."""
1082+
service_locator.set_configuration(Configuration(
1083+
crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg]
1084+
purge_on_start=True,
1085+
))
1086+
service_locator.set_storage_client(global_storage_client)
1087+
1088+
kvs = await KeyValueStore.open(storage_client=tested_storage_client)
1089+
values_kvs = {'key': 'some_value'}
1090+
test_key = 'test_key'
1091+
1092+
autosaved_value_kvs1 = await kvs.get_auto_saved_value(test_key)
1093+
assert autosaved_value_kvs1 == {}
1094+
autosaved_value_kvs1.update(values_kvs)
1095+
await kvs.persist_autosaved_values()
1096+
1097+
assert await kvs.get_value(test_key) == autosaved_value_kvs1

0 commit comments

Comments
 (0)