9
9
10
10
if TYPE_CHECKING :
11
11
import logging
12
+ from collections .abc import Callable , Coroutine
12
13
13
- from crawlee .storages . _key_value_store import KeyValueStore
14
+ from crawlee .storages import KeyValueStore
14
15
15
16
TStateModel = TypeVar ('TStateModel' , bound = BaseModel )
16
17
@@ -38,7 +39,7 @@ def __init__(
38
39
persistence_enabled : Literal [True , False , 'explicit_only' ] = False ,
39
40
persist_state_kvs_name : str | None = None ,
40
41
persist_state_kvs_id : str | None = None ,
41
- persist_state_kvs : KeyValueStore | None = None ,
42
+ persist_state_kvs_factory : Callable [[], Coroutine [ None , None , KeyValueStore ]] | None = None ,
42
43
logger : logging .Logger ,
43
44
) -> None :
44
45
"""Initialize a new recoverable state object.
@@ -53,28 +54,40 @@ def __init__(
53
54
If neither a name nor and id are supplied, the default store will be used.
54
55
persist_state_kvs_id: The identifier of the KeyValueStore to use for persistence.
55
56
If neither a name nor and id are supplied, the default store will be used.
56
- persist_state_kvs: KeyValueStore to use for persistence. If not provided, a system-wide KeyValueStore will
57
- be used, based on service locator configuration.
57
+ persist_state_kvs_factory: Factory that can be awaited to create KeyValueStore to use for persistence. If
58
+ not provided, a system-wide KeyValueStore will be used, based on service locator configuration.
58
59
logger: A logger instance for logging operations related to state persistence
59
60
"""
60
- raise_if_too_many_kwargs (persist_state_kvs_name = persist_state_kvs_name ,
61
- persist_state_kvs_id = persist_state_kvs_id ,
62
- key_value_store = persist_state_kvs )
63
- if not persist_state_kvs :
61
+ raise_if_too_many_kwargs (
62
+ persist_state_kvs_name = persist_state_kvs_name ,
63
+ persist_state_kvs_id = persist_state_kvs_id ,
64
+ persist_state_kvs_factory = persist_state_kvs_factory ,
65
+ )
66
+ if not persist_state_kvs_factory :
64
67
logger .debug (
65
68
'No explicit key_value_store set for recoverable state. Recovery will use a system-wide KeyValueStore '
66
69
'based on service_locator configuration, potentially calling service_locator.set_storage_client in the '
67
70
'process. It is recommended to initialize RecoverableState with explicit key_value_store to avoid '
68
- 'global side effects.' )
71
+ 'global side effects.'
72
+ )
69
73
70
74
self ._default_state = default_state
71
75
self ._state_type : type [TStateModel ] = self ._default_state .__class__
72
76
self ._state : TStateModel | None = None
73
77
self ._persistence_enabled = persistence_enabled
74
78
self ._persist_state_key = persist_state_key
75
- self ._persist_state_kvs_name = persist_state_kvs_name
76
- self ._persist_state_kvs_id = persist_state_kvs_id
77
- self ._key_value_store : KeyValueStore | None = persist_state_kvs
79
+ if persist_state_kvs_factory is None :
80
+
81
+ async def kvs_factory () -> KeyValueStore :
82
+ from crawlee .storages import KeyValueStore # noqa: PLC0415 avoid circular import
83
+
84
+ return await KeyValueStore .open (name = persist_state_kvs_name , id = persist_state_kvs_id )
85
+
86
+ self ._persist_state_kvs_factory = kvs_factory
87
+ else :
88
+ self ._persist_state_kvs_factory = persist_state_kvs_factory
89
+
90
+ self ._key_value_store : KeyValueStore | None = None
78
91
self ._log = logger
79
92
80
93
async def initialize (self ) -> TStateModel :
@@ -91,12 +104,8 @@ async def initialize(self) -> TStateModel:
91
104
return self .current_value
92
105
93
106
# Import here to avoid circular imports.
94
- from crawlee .storages ._key_value_store import KeyValueStore # noqa: PLC0415
95
107
96
- if not self ._key_value_store :
97
- self ._key_value_store = await KeyValueStore .open (
98
- name = self ._persist_state_kvs_name , id = self ._persist_state_kvs_id
99
- )
108
+ self ._key_value_store = await self ._persist_state_kvs_factory ()
100
109
101
110
await self ._load_saved_state ()
102
111
0 commit comments