Skip to content

Commit 537ed1a

Browse files
committed
Use factory method instead of the explicit kvs
1 parent ac9c95f commit 537ed1a

File tree

8 files changed

+88
-53
lines changed

8 files changed

+88
-53
lines changed

src/crawlee/_utils/recoverable_state.py

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,9 @@
99

1010
if TYPE_CHECKING:
1111
import logging
12+
from collections.abc import Callable, Coroutine
1213

13-
from crawlee.storages._key_value_store import KeyValueStore
14+
from crawlee.storages import KeyValueStore
1415

1516
TStateModel = TypeVar('TStateModel', bound=BaseModel)
1617

@@ -38,7 +39,7 @@ def __init__(
3839
persistence_enabled: Literal[True, False, 'explicit_only'] = False,
3940
persist_state_kvs_name: str | None = None,
4041
persist_state_kvs_id: str | None = None,
41-
persist_state_kvs: KeyValueStore | None = None,
42+
persist_state_kvs_factory: Callable[[], Coroutine[None, None, KeyValueStore]] | None = None,
4243
logger: logging.Logger,
4344
) -> None:
4445
"""Initialize a new recoverable state object.
@@ -53,28 +54,40 @@ def __init__(
5354
If neither a name nor and id are supplied, the default store will be used.
5455
persist_state_kvs_id: The identifier of the KeyValueStore to use for persistence.
5556
If neither a name nor and id are supplied, the default store will be used.
56-
persist_state_kvs: KeyValueStore to use for persistence. If not provided, a system-wide KeyValueStore will
57-
be used, based on service locator configuration.
57+
persist_state_kvs_factory: Factory that can be awaited to create KeyValueStore to use for persistence. If
58+
not provided, a system-wide KeyValueStore will be used, based on service locator configuration.
5859
logger: A logger instance for logging operations related to state persistence
5960
"""
60-
raise_if_too_many_kwargs(persist_state_kvs_name=persist_state_kvs_name,
61-
persist_state_kvs_id=persist_state_kvs_id,
62-
key_value_store=persist_state_kvs)
63-
if not persist_state_kvs:
61+
raise_if_too_many_kwargs(
62+
persist_state_kvs_name=persist_state_kvs_name,
63+
persist_state_kvs_id=persist_state_kvs_id,
64+
persist_state_kvs_factory=persist_state_kvs_factory,
65+
)
66+
if not persist_state_kvs_factory:
6467
logger.debug(
6568
'No explicit key_value_store set for recoverable state. Recovery will use a system-wide KeyValueStore '
6669
'based on service_locator configuration, potentially calling service_locator.set_storage_client in the '
6770
'process. It is recommended to initialize RecoverableState with explicit key_value_store to avoid '
68-
'global side effects.')
71+
'global side effects.'
72+
)
6973

7074
self._default_state = default_state
7175
self._state_type: type[TStateModel] = self._default_state.__class__
7276
self._state: TStateModel | None = None
7377
self._persistence_enabled = persistence_enabled
7478
self._persist_state_key = persist_state_key
75-
self._persist_state_kvs_name = persist_state_kvs_name
76-
self._persist_state_kvs_id = persist_state_kvs_id
77-
self._key_value_store: KeyValueStore | None = persist_state_kvs
79+
if persist_state_kvs_factory is None:
80+
81+
async def kvs_factory() -> KeyValueStore:
82+
from crawlee.storages import KeyValueStore # noqa: PLC0415 avoid circular import
83+
84+
return await KeyValueStore.open(name=persist_state_kvs_name, id=persist_state_kvs_id)
85+
86+
self._persist_state_kvs_factory = kvs_factory
87+
else:
88+
self._persist_state_kvs_factory = persist_state_kvs_factory
89+
90+
self._key_value_store: KeyValueStore | None = None
7891
self._log = logger
7992

8093
async def initialize(self) -> TStateModel:
@@ -91,12 +104,8 @@ async def initialize(self) -> TStateModel:
91104
return self.current_value
92105

93106
# Import here to avoid circular imports.
94-
from crawlee.storages._key_value_store import KeyValueStore # noqa: PLC0415
95107

96-
if not self._key_value_store:
97-
self._key_value_store = await KeyValueStore.open(
98-
name=self._persist_state_kvs_name, id=self._persist_state_kvs_id
99-
)
108+
self._key_value_store = await self._persist_state_kvs_factory()
100109

101110
await self._load_saved_state()
102111

src/crawlee/statistics/_statistics.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,11 @@
1717
from crawlee.statistics._error_tracker import ErrorTracker
1818

1919
if TYPE_CHECKING:
20+
from collections.abc import Callable, Coroutine
2021
from types import TracebackType
2122

23+
from crawlee.storages import KeyValueStore
24+
2225
TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState)
2326
TNewStatisticsState = TypeVar('TNewStatisticsState', bound=StatisticsState, default=StatisticsState)
2427
logger = getLogger(__name__)
@@ -70,6 +73,7 @@ def __init__(
7073
persistence_enabled: bool | Literal['explicit_only'] = False,
7174
persist_state_kvs_name: str | None = None,
7275
persist_state_key: str | None = None,
76+
persist_state_kvs_factory: Callable[[], Coroutine[None, None, KeyValueStore]] | None = None,
7377
log_message: str = 'Statistics',
7478
periodic_message_logger: Logger | None = None,
7579
log_interval: timedelta = timedelta(minutes=1),
@@ -95,6 +99,7 @@ def __init__(
9599
persist_state_key=persist_state_key or f'SDK_CRAWLER_STATISTICS_{self._id}',
96100
persistence_enabled=persistence_enabled,
97101
persist_state_kvs_name=persist_state_kvs_name,
102+
persist_state_kvs_factory=persist_state_kvs_factory,
98103
logger=logger,
99104
)
100105

@@ -110,8 +115,8 @@ def replace_state_model(self, state_model: type[TNewStatisticsState]) -> Statist
110115
"""Create near copy of the `Statistics` with replaced `state_model`."""
111116
new_statistics: Statistics[TNewStatisticsState] = Statistics(
112117
persistence_enabled=self._state._persistence_enabled, # noqa: SLF001
113-
persist_state_kvs_name=self._state._persist_state_kvs_name, # noqa: SLF001
114118
persist_state_key=self._state._persist_state_key, # noqa: SLF001
119+
persist_state_kvs_factory=self._state._persist_state_kvs_factory, # noqa: SLF001
115120
log_message=self._log_message,
116121
periodic_message_logger=self._periodic_message_logger,
117122
state_model=state_model,

src/crawlee/storage_clients/_file_system/_request_queue_client.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from collections.abc import Sequence
3232

3333
from crawlee.configuration import Configuration
34+
from crawlee.storages import KeyValueStore
3435

3536
logger = getLogger(__name__)
3637

@@ -134,13 +135,16 @@ def path_to_metadata(self) -> Path:
134135

135136
@classmethod
136137
async def _create_recoverable_state(cls, id: str, configuration: Configuration) -> RecoverableState:
137-
from crawlee.storage_clients import FileSystemStorageClient
138-
from crawlee.storages import KeyValueStore
139-
kvs = await KeyValueStore.open(storage_client=FileSystemStorageClient(),configuration=configuration)
138+
async def kvs_factory() -> KeyValueStore:
139+
from crawlee.storage_clients import FileSystemStorageClient # noqa: PLC0415 avoid circular import
140+
from crawlee.storages import KeyValueStore # noqa: PLC0415 avoid circular import
141+
142+
return await KeyValueStore.open(storage_client=FileSystemStorageClient(), configuration=configuration)
143+
140144
return RecoverableState[RequestQueueState](
141145
default_state=RequestQueueState(),
142146
persist_state_key=f'__RQ_STATE_{id}',
143-
persist_state_kvs=kvs,
147+
persist_state_kvs_factory=kvs_factory,
144148
persistence_enabled=True,
145149
logger=logger,
146150
)
@@ -203,9 +207,9 @@ async def open(
203207
metadata=metadata,
204208
path_to_rq=rq_base_path / rq_dir,
205209
lock=asyncio.Lock(),
206-
recoverable_state=await cls._create_recoverable_state(id=id,
207-
configuration=configuration),
208-
210+
recoverable_state=await cls._create_recoverable_state(
211+
id=id, configuration=configuration
212+
),
209213
)
210214
await client._state.initialize()
211215
await client._discover_existing_requests()

src/crawlee/storages/_key_value_store.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,9 +278,16 @@ async def get_auto_saved_value(
278278
if key in cache:
279279
return cache[key].current_value.root
280280

281-
cache[key] = recoverable_state = RecoverableState(default_state=AutosavedValue(default_value),
282-
persist_state_key=key, persistence_enabled=True,
283-
persist_state_kvs=self, logger=logger)
281+
async def kvs_factory() -> KeyValueStore:
282+
return self
283+
284+
cache[key] = recoverable_state = RecoverableState(
285+
default_state=AutosavedValue(default_value),
286+
persist_state_key=key,
287+
persistence_enabled=True,
288+
persist_state_kvs_factory=kvs_factory,
289+
logger=logger,
290+
)
284291

285292
await recoverable_state.initialize()
286293

tests/unit/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from crawlee.fingerprint_suite._browserforge_adapter import get_available_header_network
1818
from crawlee.http_clients import CurlImpersonateHttpClient, HttpxHttpClient, ImpitHttpClient
1919
from crawlee.proxy_configuration import ProxyInfo
20+
from crawlee.statistics import Statistics
2021
from crawlee.storages import KeyValueStore
2122
from tests.unit.server import TestServer, app, serve_in_thread
2223

@@ -69,8 +70,9 @@ def _prepare_test_env() -> None:
6970
# Verify that the test environment was set up correctly.
7071
assert os.environ.get('CRAWLEE_STORAGE_DIR') == str(tmp_path)
7172

72-
# Clear global cache of autosaved values
73+
# Reset global class variables to ensure test isolation.
7374
KeyValueStore._autosaved_values = {}
75+
Statistics._Statistics__next_id = 0 # type:ignore[attr-defined] # Mangled attribute
7476

7577
return _prepare_test_env
7678

tests/unit/crawlers/_adaptive_playwright/test_adaptive_playwright_crawler.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -493,7 +493,6 @@ async def test_adaptive_playwright_crawler_statistics_in_init() -> None:
493493
assert type(crawler._statistics.state) is AdaptivePlaywrightCrawlerStatisticState
494494

495495
assert crawler._statistics._state._persistence_enabled == persistence_enabled
496-
assert crawler._statistics._state._persist_state_kvs_name == persist_state_kvs_name
497496
assert crawler._statistics._state._persist_state_key == persist_state_key
498497

499498
assert crawler._statistics._log_message == log_message

tests/unit/storage_clients/_file_system/test_fs_rq_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,9 @@ async def test_request_file_persistence(rq_client: FileSystemRequestQueueClient)
7777
assert 'url' in request_data
7878
assert request_data['url'].startswith('https://example.com/')
7979

80+
8081
async def test_opening_rq_does_not_have_side_effect_on_service_locator(
81-
rq_client: FileSystemRequestQueueClient # noqa: ARG001
82+
rq_client: FileSystemRequestQueueClient, # noqa: ARG001
8283
) -> None:
8384
service_locator.set_storage_client(MemoryStorageClient())
8485

tests/unit/storages/test_key_value_store.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,6 @@
1919
from pathlib import Path
2020

2121

22-
23-
2422
@pytest.fixture
2523
async def kvs(
2624
storage_client: StorageClient,
@@ -1066,32 +1064,42 @@ async def test_name_default_not_allowed(storage_client: StorageClient) -> None:
10661064
):
10671065
await KeyValueStore.open(name=StorageInstanceManager._DEFAULT_STORAGE_ALIAS, storage_client=storage_client)
10681066

1069-
@pytest.mark.parametrize('tested_storage_client', [
1070-
pytest.param(MemoryStorageClient(), id='tested=MemoryStorageClient'),
1071-
pytest.param(FileSystemStorageClient(), id='tested=FileSystemStorageClient'),
1072-
pytest.param(SqlStorageClient(), id='tested=SqlStorageClient'),
1073-
])
1074-
@pytest.mark.parametrize('global_storage_client', [
1075-
pytest.param(MemoryStorageClient(), id='global=MemoryStorageClient'),
1076-
pytest.param(FileSystemStorageClient(), id='global=FileSystemStorageClient'),
1077-
pytest.param(SqlStorageClient(), id='global=SqlStorageClient'),
1078-
])
1079-
async def test_get_auto_saved_value_various_global_clients(tmp_path: Path, tested_storage_client: StorageClient,
1080-
global_storage_client:StorageClient) -> None:
1067+
1068+
@pytest.mark.parametrize(
1069+
'tested_storage_client',
1070+
[
1071+
pytest.param(MemoryStorageClient(), id='tested=MemoryStorageClient'),
1072+
pytest.param(FileSystemStorageClient(), id='tested=FileSystemStorageClient'),
1073+
pytest.param(SqlStorageClient(), id='tested=SqlStorageClient'),
1074+
],
1075+
)
1076+
@pytest.mark.parametrize(
1077+
'global_storage_client',
1078+
[
1079+
pytest.param(MemoryStorageClient(), id='global=MemoryStorageClient'),
1080+
pytest.param(FileSystemStorageClient(), id='global=FileSystemStorageClient'),
1081+
pytest.param(SqlStorageClient(), id='global=SqlStorageClient'),
1082+
],
1083+
)
1084+
async def test_get_auto_saved_value_various_global_clients(
1085+
tmp_path: Path, tested_storage_client: StorageClient, global_storage_client: StorageClient
1086+
) -> None:
10811087
"""Ensure that persistence is working for all clients regardless of what is set in service locator."""
1082-
service_locator.set_configuration(Configuration(
1083-
crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg]
1084-
purge_on_start=True,
1085-
))
1088+
service_locator.set_configuration(
1089+
Configuration(
1090+
crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg]
1091+
purge_on_start=True,
1092+
)
1093+
)
10861094
service_locator.set_storage_client(global_storage_client)
10871095

10881096
kvs = await KeyValueStore.open(storage_client=tested_storage_client)
10891097
values_kvs = {'key': 'some_value'}
10901098
test_key = 'test_key'
10911099

1092-
autosaved_value_kvs1 = await kvs.get_auto_saved_value(test_key)
1093-
assert autosaved_value_kvs1 == {}
1094-
autosaved_value_kvs1.update(values_kvs)
1100+
autosaved_value_kvs = await kvs.get_auto_saved_value(test_key)
1101+
assert autosaved_value_kvs == {}
1102+
autosaved_value_kvs.update(values_kvs)
10951103
await kvs.persist_autosaved_values()
10961104

1097-
assert await kvs.get_value(test_key) == autosaved_value_kvs1
1105+
assert await kvs.get_value(test_key) == autosaved_value_kvs

0 commit comments

Comments
 (0)