|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
3 |
| -from typing import TYPE_CHECKING, Any, TypeVar, overload |
| 3 | +from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload |
4 | 4 |
|
5 | 5 | from typing_extensions import override
|
6 | 6 |
|
| 7 | +from crawlee import service_container |
7 | 8 | from crawlee._utils.docs import docs_group
|
8 | 9 | from crawlee.base_storage_client._models import KeyValueStoreKeyInfo, KeyValueStoreMetadata
|
| 10 | +from crawlee.events._types import Event, EventPersistStateData |
9 | 11 | from crawlee.storages._base_storage import BaseStorage
|
10 | 12 |
|
11 | 13 | if TYPE_CHECKING:
|
12 | 14 | from collections.abc import AsyncIterator
|
13 | 15 |
|
| 16 | + from crawlee._types import JsonSerializable |
14 | 17 | from crawlee.base_storage_client import BaseStorageClient
|
15 | 18 | from crawlee.configuration import Configuration
|
16 | 19 |
|
@@ -52,6 +55,10 @@ class KeyValueStore(BaseStorage):
|
52 | 55 | ```
|
53 | 56 | """
|
54 | 57 |
|
| 58 | + # Cache for persistent (auto-saved) values |
| 59 | + _general_cache: ClassVar[dict[str, dict[str, dict[str, JsonSerializable]]]] = {} |
| 60 | + _persist_state_event_started = False |
| 61 | + |
55 | 62 | def __init__(
|
56 | 63 | self,
|
57 | 64 | id: str,
|
@@ -105,6 +112,7 @@ async def drop(self) -> None:
|
105 | 112 | from crawlee.storages._creation_management import remove_storage_from_cache
|
106 | 113 |
|
107 | 114 | await self._resource_client.delete()
|
| 115 | + self._clear_cache() |
108 | 116 | remove_storage_from_cache(storage_class=self.__class__, id=self._id, name=self._name)
|
109 | 117 |
|
110 | 118 | @overload
|
@@ -175,3 +183,70 @@ async def get_public_url(self, key: str) -> str:
|
175 | 183 | The public URL for the given key.
|
176 | 184 | """
|
177 | 185 | return await self._resource_client.get_public_url(key)
|
| 186 | + |
| 187 | + async def get_auto_saved_value( |
| 188 | + self, key: str, default_value: dict[str, JsonSerializable] | None = None |
| 189 | + ) -> dict[str, JsonSerializable]: |
| 190 | + """Gets a value from KVS that will be automatically saved on changes. |
| 191 | +
|
| 192 | + Args: |
| 193 | + key: Key of the record, to store the value. |
| 194 | + default_value: Value to be used if the record does not exist yet. Should be a dictionary. |
| 195 | +
|
| 196 | + Returns: |
| 197 | + Returns the value of the key. |
| 198 | + """ |
| 199 | + default_value = {} if default_value is None else default_value |
| 200 | + |
| 201 | + if key in self._cache: |
| 202 | + return self._cache[key] |
| 203 | + |
| 204 | + value = await self.get_value(key, default_value) |
| 205 | + |
| 206 | + if not isinstance(value, dict): |
| 207 | + raise TypeError( |
| 208 | + f'Expected dictionary for persist state value at key "{key}, but got {type(value).__name__}' |
| 209 | + ) |
| 210 | + |
| 211 | + self._cache[key] = value |
| 212 | + |
| 213 | + self._ensure_persist_event() |
| 214 | + |
| 215 | + return value |
| 216 | + |
| 217 | + @property |
| 218 | + def _cache(self) -> dict[str, dict[str, JsonSerializable]]: |
| 219 | + """Cache dictionary for storing auto-saved values indexed by store ID.""" |
| 220 | + if self._id not in self._general_cache: |
| 221 | + self._general_cache[self._id] = {} |
| 222 | + return self._general_cache[self._id] |
| 223 | + |
| 224 | + async def _persist_save(self, _event_data: EventPersistStateData | None = None) -> None: |
| 225 | + """Save cache with persistent values. Can be used in Event Manager.""" |
| 226 | + for key, value in self._cache.items(): |
| 227 | + await self.set_value(key, value) |
| 228 | + |
| 229 | + def _ensure_persist_event(self) -> None: |
| 230 | + """Setup persist state event handling if not already done.""" |
| 231 | + if self._persist_state_event_started: |
| 232 | + return |
| 233 | + |
| 234 | + event_manager = service_container.get_event_manager() |
| 235 | + event_manager.on(event=Event.PERSIST_STATE, listener=self._persist_save) |
| 236 | + self._persist_state_event_started = True |
| 237 | + |
| 238 | + def _clear_cache(self) -> None: |
| 239 | + """Clear cache with persistent values.""" |
| 240 | + self._cache.clear() |
| 241 | + |
| 242 | + def _drop_persist_state_event(self) -> None: |
| 243 | + """Off event_manager listener and drop event status.""" |
| 244 | + if self._persist_state_event_started: |
| 245 | + event_manager = service_container.get_event_manager() |
| 246 | + event_manager.off(event=Event.PERSIST_STATE, listener=self._persist_save) |
| 247 | + self._persist_state_event_started = False |
| 248 | + |
| 249 | + async def persist_autosaved_values(self) -> None: |
| 250 | + """Force persistent values to be saved without waiting for an event in Event Manager.""" |
| 251 | + if self._persist_state_event_started: |
| 252 | + await self._persist_save() |
0 commit comments