
Commit 868b41e

Mantisus and vdusek authored
feat: Add use_state context method (#682)
### Description

Add a `use_state` context helper method with auto-save behavior, inspired by the TypeScript implementation of `crawlee`. Adds a `get_auto_saved_value` method to the `KeyValueStore`, which works with an internal cache and saves the data on the `PERSIST_STATE` event. (At this point, it's more of a gut check and synchronization with the team; it will probably require some more refinement.)

### Issues

- Closes: #191

### Testing

Correct testing requires an event manager in the `KeyValueStore` tests, so I have added a new fixture.

### Checklist

- [x] CI passed

---------

Co-authored-by: Vlada Dusek <[email protected]>
1 parent 13bb400 commit 868b41e
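
Based on the description and the tests in this commit, here is a minimal sketch of how a request handler might use the new helper. The import paths are assumed from this commit's file layout, and the URL and state key are illustrative:

import asyncio

from crawlee.basic_crawler import BasicCrawler, BasicCrawlingContext


async def main() -> None:
    crawler = BasicCrawler()

    @crawler.router.default_handler
    async def handler(context: BasicCrawlingContext) -> None:
        # The first call creates the record from the default value; later calls
        # return the same cached dict, so in-place mutations are auto-saved.
        state = await context.use_state('state', {'pages_visited': 0})
        state['pages_visited'] = state['pages_visited'] + 1

    await crawler.run(['https://crawlee.dev/'])


asyncio.run(main())

The crawler flushes the auto-saved values when the run finishes (see the `_save_crawler_state` call added in `_basic_crawler.py` below), so the state survives into subsequent runs.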

File tree

8 files changed, +227 -2 lines changed

src/crawlee/_types.py

Lines changed: 15 additions & 0 deletions
@@ -300,6 +300,20 @@ def __call__(
     ) -> Coroutine[None, None, HttpResponse]: ...


+class UseStateFunction(Protocol):
+    """Type of a function for performing use state.
+
+    Warning:
+        This is an experimental feature. The behavior and interface may change in future versions.
+    """
+
+    def __call__(
+        self,
+        key: str,
+        default_value: dict[str, JsonSerializable] | None = None,
+    ) -> Coroutine[None, None, dict[str, JsonSerializable]]: ...
+
+
 T = TypeVar('T')

@@ -347,6 +361,7 @@ class BasicCrawlingContext:
     send_request: SendRequestFunction
     add_requests: AddRequestsFunction
     push_data: PushDataFunction
+    use_state: UseStateFunction
     get_key_value_store: GetKeyValueStoreFromRequestHandlerFunction
     log: logging.Logger

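Because `UseStateFunction` is a `typing.Protocol` with a `__call__` method, any async callable with a matching signature satisfies it structurally, without inheriting from it; that is what allows the bound method `BasicCrawler._use_state` (added below) to be assigned to the new `use_state` field. A self-contained sketch of the structural check, with hypothetical names and a stand-in alias for `JsonSerializable`:

from __future__ import annotations

from collections.abc import Coroutine
from typing import Protocol

JsonSerializable = object  # stand-in for crawlee's JSON-serializable alias


class UseStateFunction(Protocol):
    def __call__(
        self,
        key: str,
        default_value: dict[str, JsonSerializable] | None = None,
    ) -> Coroutine[None, None, dict[str, JsonSerializable]]: ...


async def fake_use_state(
    key: str, default_value: dict[str, JsonSerializable] | None = None
) -> dict[str, JsonSerializable]:
    # A plain coroutine function with a matching signature satisfies the protocol.
    return default_value or {}


use_state: UseStateFunction = fake_use_state  # accepted by a static type checker
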
src/crawlee/basic_crawler/_basic_crawler.py

Lines changed: 13 additions & 0 deletions
@@ -478,6 +478,8 @@ def sigint_handler() -> None:
         self._running = False
         self._has_finished_before = True

+        await self._save_crawler_state()
+
         final_statistics = self._statistics.calculate()
         self._logger.info(f'Final request statistics:\n{final_statistics.to_table()}')

@@ -532,6 +534,16 @@ async def add_requests(
             wait_for_all_requests_to_be_added_timeout=wait_for_all_requests_to_be_added_timeout,
         )

+    async def _use_state(
+        self, key: str, default_value: dict[str, JsonSerializable] | None = None
+    ) -> dict[str, JsonSerializable]:
+        store = await self.get_key_value_store()
+        return await store.get_auto_saved_value(key, default_value)
+
+    async def _save_crawler_state(self) -> None:
+        store = await self.get_key_value_store()
+        await store.persist_autosaved_values()
+
     async def get_data(
         self,
         dataset_id: str | None = None,

@@ -953,6 +965,7 @@ async def __run_task_function(self) -> None:
            add_requests=result.add_requests,
            push_data=result.push_data,
            get_key_value_store=result.get_key_value_store,
+           use_state=self._use_state,
            log=self._logger,
        )

src/crawlee/playwright_crawler/_playwright_crawler.py

Lines changed: 2 additions & 0 deletions
@@ -136,6 +136,7 @@ async def _open_page(self, context: BasicCrawlingContext) -> AsyncGenerator[Play
            add_requests=context.add_requests,
            send_request=context.send_request,
            push_data=context.push_data,
+           use_state=context.use_state,
            proxy_info=context.proxy_info,
            get_key_value_store=context.get_key_value_store,
            log=context.log,

@@ -225,6 +226,7 @@ async def enqueue_links(
            add_requests=context.add_requests,
            send_request=context.send_request,
            push_data=context.push_data,
+           use_state=context.use_state,
            proxy_info=context.proxy_info,
            get_key_value_store=context.get_key_value_store,
            log=context.log,

src/crawlee/storages/_key_value_store.py

Lines changed: 76 additions & 1 deletion
@@ -1,16 +1,19 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, TypeVar, overload
+from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload

 from typing_extensions import override

+from crawlee import service_container
 from crawlee._utils.docs import docs_group
 from crawlee.base_storage_client._models import KeyValueStoreKeyInfo, KeyValueStoreMetadata
+from crawlee.events._types import Event, EventPersistStateData
 from crawlee.storages._base_storage import BaseStorage

 if TYPE_CHECKING:
     from collections.abc import AsyncIterator

+    from crawlee._types import JsonSerializable
     from crawlee.base_storage_client import BaseStorageClient
     from crawlee.configuration import Configuration

@@ -52,6 +55,10 @@ class KeyValueStore(BaseStorage):
     ```
     """

+    # Cache for persistent (auto-saved) values
+    _general_cache: ClassVar[dict[str, dict[str, dict[str, JsonSerializable]]]] = {}
+    _persist_state_event_started = False
+
     def __init__(
         self,
         id: str,

@@ -105,6 +112,7 @@ async def drop(self) -> None:
         from crawlee.storages._creation_management import remove_storage_from_cache

         await self._resource_client.delete()
+        self._clear_cache()
         remove_storage_from_cache(storage_class=self.__class__, id=self._id, name=self._name)

     @overload

@@ -175,3 +183,70 @@ async def get_public_url(self, key: str) -> str:
             The public URL for the given key.
         """
         return await self._resource_client.get_public_url(key)
+
+    async def get_auto_saved_value(
+        self, key: str, default_value: dict[str, JsonSerializable] | None = None
+    ) -> dict[str, JsonSerializable]:
+        """Gets a value from KVS that will be automatically saved on changes.
+
+        Args:
+            key: Key of the record under which the value is stored.
+            default_value: Value to be used if the record does not exist yet. Should be a dictionary.
+
+        Returns:
+            The value at the given key.
+        """
+        default_value = {} if default_value is None else default_value
+
+        if key in self._cache:
+            return self._cache[key]
+
+        value = await self.get_value(key, default_value)
+
+        if not isinstance(value, dict):
+            raise TypeError(
+                f'Expected dictionary for persist state value at key "{key}", but got {type(value).__name__}'
+            )
+
+        self._cache[key] = value
+
+        self._ensure_persist_event()
+
+        return value
+
+    @property
+    def _cache(self) -> dict[str, dict[str, JsonSerializable]]:
+        """Cache dictionary for storing auto-saved values, indexed by store ID."""
+        if self._id not in self._general_cache:
+            self._general_cache[self._id] = {}
+        return self._general_cache[self._id]
+
+    async def _persist_save(self, _event_data: EventPersistStateData | None = None) -> None:
+        """Save the cached persistent values. Can be used as an Event Manager listener."""
+        for key, value in self._cache.items():
+            await self.set_value(key, value)
+
+    def _ensure_persist_event(self) -> None:
+        """Set up persist state event handling if not already done."""
+        if self._persist_state_event_started:
+            return
+
+        event_manager = service_container.get_event_manager()
+        event_manager.on(event=Event.PERSIST_STATE, listener=self._persist_save)
+        self._persist_state_event_started = True
+
+    def _clear_cache(self) -> None:
+        """Clear the cache of persistent values."""
+        self._cache.clear()
+
+    def _drop_persist_state_event(self) -> None:
+        """Remove the Event Manager listener and reset the event status."""
+        if self._persist_state_event_started:
+            event_manager = service_container.get_event_manager()
+            event_manager.off(event=Event.PERSIST_STATE, listener=self._persist_save)
+            self._persist_state_event_started = False
+
+    async def persist_autosaved_values(self) -> None:
+        """Force persistent values to be saved without waiting for an event from the Event Manager."""
+        if self._persist_state_event_started:
+            await self._persist_save()

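Taken together, these methods define the store's auto-save contract: the dictionary returned by `get_auto_saved_value` is the cached object itself, so mutating it in place is enough for the next persist to pick up the change. A hedged sketch of that flow, using only methods from this file; it assumes an event manager is available from the service container, which the tests below arrange with a fixture:

import asyncio

from crawlee.storages import KeyValueStore


async def main() -> None:
    kvs = await KeyValueStore.open()

    # The first call caches the default value and registers the PERSIST_STATE
    # listener; later calls return the same cached dict.
    state = await kvs.get_auto_saved_value('state', {'counter': 0})
    state['counter'] = 1  # mutate in place; no explicit set_value() call needed

    # Flush the cached values now instead of waiting for a PERSIST_STATE event.
    await kvs.persist_autosaved_values()

    assert await kvs.get_value('state') == {'counter': 1}
    await kvs.drop()


asyncio.run(main())
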
tests/unit/basic_crawler/test_basic_crawler.py

Lines changed: 67 additions & 1 deletion
@@ -24,11 +24,20 @@
 from crawlee.storages import Dataset, KeyValueStore, RequestList, RequestQueue

 if TYPE_CHECKING:
-    from collections.abc import Sequence
+    from collections.abc import AsyncGenerator, Sequence

     import respx
     from yarl import URL

+    from crawlee._types import JsonSerializable
+
+
+@pytest.fixture
+async def key_value_store() -> AsyncGenerator[KeyValueStore, None]:
+    kvs = await KeyValueStore.open()
+    yield kvs
+    await kvs.drop()
+

 async def test_processes_requests() -> None:
     crawler = BasicCrawler(request_provider=RequestList(['http://a.com/', 'http://b.com/', 'http://c.com/']))

@@ -677,6 +686,63 @@ async def handler(context: BasicCrawlingContext) -> None:
     assert (await store.get_value('foo')) == 'bar'


+async def test_context_use_state(key_value_store: KeyValueStore) -> None:
+    crawler = BasicCrawler()
+
+    @crawler.router.default_handler
+    async def handler(context: BasicCrawlingContext) -> None:
+        await context.use_state('state', {'hello': 'world'})
+
+    await crawler.run(['https://hello.world'])
+
+    store = await crawler.get_key_value_store()
+
+    assert (await store.get_value('state')) == {'hello': 'world'}
+
+
+async def test_context_handlers_use_state(key_value_store: KeyValueStore) -> None:
+    state_in_handler_one: dict[str, JsonSerializable] = {}
+    state_in_handler_two: dict[str, JsonSerializable] = {}
+    state_in_handler_three: dict[str, JsonSerializable] = {}
+
+    crawler = BasicCrawler()
+
+    @crawler.router.handler('one')
+    async def handler_one(context: BasicCrawlingContext) -> None:
+        state = await context.use_state('state', {'hello': 'world'})
+        state_in_handler_one.update(state)
+        state['hello'] = 'new_world'
+        await context.add_requests([Request.from_url('https://crawlee.dev/docs/quick-start', label='two')])
+
+    @crawler.router.handler('two')
+    async def handler_two(context: BasicCrawlingContext) -> None:
+        state = await context.use_state('state', {'hello': 'world'})
+        state_in_handler_two.update(state)
+        state['hello'] = 'last_world'
+
+    @crawler.router.handler('three')
+    async def handler_three(context: BasicCrawlingContext) -> None:
+        state = await context.use_state('state', {'hello': 'world'})
+        state_in_handler_three.update(state)
+
+    await crawler.run([Request.from_url('https://crawlee.dev/', label='one')])
+    await crawler.run([Request.from_url('https://crawlee.dev/docs/examples', label='three')])
+
+    # The state in handler_one must match the default state
+    assert state_in_handler_one == {'hello': 'world'}
+
+    # The state in handler_two must match the state updated in handler_one
+    assert state_in_handler_two == {'hello': 'new_world'}
+
+    # The state in handler_three must match the final state updated in the previous run
+    assert state_in_handler_three == {'hello': 'last_world'}
+
+    store = await crawler.get_key_value_store()
+
+    # The state in the KVS must match the last set state
+    assert (await store.get_value('state')) == {'hello': 'last_world'}
+
+
 async def test_max_requests_per_crawl(httpbin: URL) -> None:
     start_urls = [
         str(httpbin / '1'),

tests/unit/basic_crawler/test_context_pipeline.py

Lines changed: 7 additions & 0 deletions
@@ -38,6 +38,7 @@ async def test_calls_consumer_without_middleware() -> None:
        session=Session(),
        proxy_info=AsyncMock(),
        push_data=AsyncMock(),
+       use_state=AsyncMock(),
        get_key_value_store=AsyncMock(),
        log=logging.getLogger(),
    )

@@ -64,6 +65,7 @@ async def middleware_a(context: BasicCrawlingContext) -> AsyncGenerator[Enhanced
        session=context.session,
        proxy_info=AsyncMock(),
        push_data=AsyncMock(),
+       use_state=AsyncMock(),
        get_key_value_store=AsyncMock(),
        log=logging.getLogger(),
    )

@@ -80,6 +82,7 @@ async def middleware_b(context: EnhancedCrawlingContext) -> AsyncGenerator[MoreE
        session=context.session,
        proxy_info=AsyncMock(),
        push_data=AsyncMock(),
+       use_state=AsyncMock(),
        get_key_value_store=AsyncMock(),
        log=logging.getLogger(),
    )

@@ -94,6 +97,7 @@ async def middleware_b(context: EnhancedCrawlingContext) -> AsyncGenerator[MoreE
        session=Session(),
        proxy_info=AsyncMock(),
        push_data=AsyncMock(),
+       use_state=AsyncMock(),
        get_key_value_store=AsyncMock(),
        log=logging.getLogger(),
    )

@@ -119,6 +123,7 @@ async def test_wraps_consumer_errors() -> None:
        session=Session(),
        proxy_info=AsyncMock(),
        push_data=AsyncMock(),
+       use_state=AsyncMock(),
        get_key_value_store=AsyncMock(),
        log=logging.getLogger(),
    )

@@ -147,6 +152,7 @@ async def step_2(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingC
        session=Session(),
        proxy_info=AsyncMock(),
        push_data=AsyncMock(),
+       use_state=AsyncMock(),
        get_key_value_store=AsyncMock(),
        log=logging.getLogger(),
    )

@@ -178,6 +184,7 @@ async def step_2(context: BasicCrawlingContext) -> AsyncGenerator[BasicCrawlingC
        session=Session(),
        proxy_info=AsyncMock(),
        push_data=AsyncMock(),
+       use_state=AsyncMock(),
        get_key_value_store=AsyncMock(),
        log=logging.getLogger(),
    )
