Commit faa5974

Refactor storages to allow force_cloud feature
1 parent 9edd205 commit faa5974

File tree

5 files changed: +99 -40 lines

src/crawlee/storages/_base_storage.py
src/crawlee/storages/_creation_management.py
src/crawlee/storages/_dataset.py
src/crawlee/storages/_key_value_store.py
src/crawlee/storages/_request_queue.py

src/crawlee/storages/_base_storage.py

Lines changed: 17 additions & 2 deletions
@@ -1,6 +1,11 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from crawlee.base_storage_client import BaseStorageClient
+    from crawlee.configuration import Configuration
 
 
 class BaseStorage(ABC):
@@ -18,14 +23,24 @@ def name(self) -> str | None:
 
     @classmethod
     @abstractmethod
-    async def open(cls, *, id: str | None = None, name: str | None = None) -> BaseStorage:
+    async def open(
+        cls,
+        *,
+        id: str | None = None,
+        name: str | None = None,
+        configuration: Configuration | None = None,
+        storage_client: BaseStorageClient | None = None,
+    ) -> BaseStorage:
         """Open a storage, either restore existing or create a new one.
 
         Args:
             id: The storage ID.
             name: The storage name.
+            configuration: Configuration object used during the storage creation or restoration process.
+            storage_client: Underlying storage client to use. If not provided, the default global storage client
+                from the service locator will be used.
         """
 
     @abstractmethod
     async def drop(self) -> None:
-        """Drop the storage. Remove it from underlying storage and delete from cache."""
+        """Drops the storage, removing it from the underlying storage client and clearing the cache."""

src/crawlee/storages/_creation_management.py

Lines changed: 5 additions & 7 deletions
@@ -3,7 +3,6 @@
 import asyncio
 from typing import TYPE_CHECKING, TypeVar
 
-from crawlee import service_container
 from crawlee.memory_storage_client import MemoryStorageClient
 from crawlee.storages import Dataset, KeyValueStore, RequestQueue
 
@@ -124,18 +123,17 @@ async def open_storage(
     storage_class: type[TResource],
     id: str | None = None,
     name: str | None = None,
+    configuration: Configuration,
+    storage_client: BaseStorageClient,
 ) -> TResource:
     """Open either a new storage or restore an existing one and return it."""
-    config = service_container.get_configuration()
-    storage_client = service_container.get_storage_client()
-
     # Try to restore the storage from cache by name
     if name:
         cached_storage = _get_from_cache_by_name(storage_class=storage_class, name=name)
         if cached_storage:
             return cached_storage
 
-    default_id = _get_default_storage_id(config, storage_class)
+    default_id = _get_default_storage_id(configuration, storage_class)
 
     if not id and not name:
         id = default_id
@@ -150,7 +148,7 @@ async def open_storage(
             return cached_storage
 
     # Purge on start if configured
-    if config.purge_on_start:
+    if configuration.purge_on_start:
         await storage_client.purge_on_start()
 
     # Lock and create new storage
@@ -169,7 +167,7 @@ async def open_storage(
         resource_collection_client = _get_resource_collection_client(storage_class, storage_client)
         storage_info = await resource_collection_client.get_or_create(name=name)
 
-        storage = storage_class(id=storage_info.id, name=storage_info.name)
+        storage = storage_class(id=storage_info.id, name=storage_info.name, storage_client=storage_client)
 
         # Cache the storage by ID and name
         _add_to_cache_by_id(storage.id, storage)

src/crawlee/storages/_dataset.py

Lines changed: 24 additions & 9 deletions
@@ -20,8 +20,9 @@
     from collections.abc import AsyncIterator, Callable
 
     from crawlee._types import JsonSerializable, PushDataKwargs
+    from crawlee.base_storage_client import BaseStorageClient
     from crawlee.base_storage_client._models import DatasetItemsListPage
-
+    from crawlee.configuration import Configuration
 
 logger = logging.getLogger(__name__)
 
@@ -192,13 +193,11 @@ class Dataset(BaseStorage):
     _EFFECTIVE_LIMIT_SIZE = _MAX_PAYLOAD_SIZE - (_MAX_PAYLOAD_SIZE * _SAFETY_BUFFER_PERCENT)
     """Calculated payload limit considering safety buffer."""
 
-    def __init__(self, id: str, name: str | None) -> None:
-        storage_client = service_container.get_storage_client()
-
+    def __init__(self, id: str, name: str | None, storage_client: BaseStorageClient) -> None:
         self._id = id
         self._name = name
 
-        # Get resource clients from storage client
+        # Get resource clients from the storage client.
         self._resource_client = storage_client.dataset(self._id)
         self._resource_collection_client = storage_client.datasets()
 
@@ -214,10 +213,26 @@ def name(self) -> str | None:
 
     @override
     @classmethod
-    async def open(cls, *, id: str | None = None, name: str | None = None) -> Dataset:
+    async def open(
+        cls,
+        *,
+        id: str | None = None,
+        name: str | None = None,
+        configuration: Configuration | None = None,
+        storage_client: BaseStorageClient | None = None,
+    ) -> Dataset:
         from crawlee.storages._creation_management import open_storage
 
-        return await open_storage(storage_class=cls, id=id, name=name)
+        configuration = configuration or service_container.get_configuration()
+        storage_client = storage_client or service_container.get_storage_client()
+
+        return await open_storage(
+            storage_class=cls,
+            id=id,
+            name=name,
+            configuration=configuration,
+            storage_client=storage_client,
+        )
 
     @override
     async def drop(self) -> None:
@@ -241,7 +256,7 @@ async def push_data(self, data: JsonSerializable, **kwargs: Unpack[PushDataKwarg
         # Handle singular items
         if not isinstance(data, list):
             items = await self.check_and_serialize(data)
-            return await self._resource_client.push_items(items, **kwargs)  # type: ignore[no-any-return]  # Mypy is broken
+            return await self._resource_client.push_items(items, **kwargs)
 
         # Handle lists
         payloads_generator = (await self.check_and_serialize(item, index) for index, item in enumerate(data))
@@ -264,7 +279,7 @@ async def get_data(self, **kwargs: Unpack[GetDataKwargs]) -> DatasetItemsListPag
         Returns:
             List page containing filtered and paginated dataset items.
         """
-        return await self._resource_client.list_items(**kwargs)  # type: ignore[no-any-return]  # Mypy is broken
+        return await self._resource_client.list_items(**kwargs)
 
     async def write_to_csv(self, destination: TextIO, **kwargs: Unpack[ExportDataCsvKwargs]) -> None:
         """Exports the entire dataset into an arbitrary stream.

src/crawlee/storages/_key_value_store.py

Lines changed: 26 additions & 10 deletions
@@ -7,12 +7,14 @@
 
 from crawlee import service_container
 from crawlee._utils.docs import docs_group
-from crawlee.base_storage_client._models import KeyValueStoreKeyInfo, KeyValueStoreMetadata
+from crawlee.base_storage_client import BaseStorageClient, KeyValueStoreKeyInfo, KeyValueStoreMetadata
 from crawlee.storages._base_storage import BaseStorage
 
 if TYPE_CHECKING:
     from collections.abc import AsyncIterator
 
+    from crawlee.configuration import Configuration
+
 T = TypeVar('T')
 
 
@@ -51,9 +53,7 @@ class KeyValueStore(BaseStorage):
     ```
     """
 
-    def __init__(self, id: str, name: str | None) -> None:
-        storage_client = service_container.get_storage_client()
-
+    def __init__(self, id: str, name: str | None, storage_client: BaseStorageClient) -> None:
         self._id = id
         self._name = name
 
@@ -72,14 +72,30 @@ def name(self) -> str | None:
 
     async def get_info(self) -> KeyValueStoreMetadata | None:
         """Get an object containing general information about the key value store."""
-        return await self._resource_client.get()  # type: ignore[no-any-return]  # Mypy is broken
+        return await self._resource_client.get()
 
     @override
     @classmethod
-    async def open(cls, *, id: str | None = None, name: str | None = None) -> KeyValueStore:
+    async def open(
+        cls,
+        *,
+        id: str | None = None,
+        name: str | None = None,
+        configuration: Configuration | None = None,
+        storage_client: BaseStorageClient | None = None,
+    ) -> KeyValueStore:
         from crawlee.storages._creation_management import open_storage
 
-        return await open_storage(storage_class=cls, id=id, name=name)
+        configuration = configuration or service_container.get_configuration()
+        storage_client = storage_client or service_container.get_storage_client()
+
+        return await open_storage(
+            storage_class=cls,
+            id=id,
+            name=name,
+            configuration=configuration,
+            storage_client=storage_client,
+        )
 
     @override
     async def drop(self) -> None:
@@ -142,9 +158,9 @@ async def set_value(
             content_type: Content type of the record.
         """
         if value is None:
-            return await self._resource_client.delete_record(key)  # type: ignore[no-any-return]  # Mypy is broken
+            return await self._resource_client.delete_record(key)
 
-        return await self._resource_client.set_record(key, value, content_type)  # type: ignore[no-any-return]  # Mypy is broken
+        return await self._resource_client.set_record(key, value, content_type)
 
     async def get_public_url(self, key: str) -> str:
         """Get the public URL for the given key.
@@ -155,4 +171,4 @@ async def get_public_url(self, key: str) -> str:
         Returns:
             The public URL for the given key.
         """
-        return await self._resource_client.get_public_url(key)  # type: ignore[no-any-return]  # Mypy is broken
+        return await self._resource_client.get_public_url(key)
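
The key-value store follows the same pattern: open() resolves any missing arguments from the service container and passes both through to open_storage(). A short sketch under the same assumptions as the Dataset example above; get_value() is an existing KeyValueStore method not shown in this diff, and the key and value are illustrative:

from crawlee.memory_storage_client import MemoryStorageClient
from crawlee.storages import KeyValueStore


async def save_state() -> None:
    # Pin the storage backend explicitly; configuration is omitted, so the
    # default from the service container is used.
    kvs = await KeyValueStore.open(storage_client=MemoryStorageClient())
    await kvs.set_value('state', {'page': 1})
    state = await kvs.get_value('state')
    print(state)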

src/crawlee/storages/_request_queue.py

Lines changed: 27 additions & 12 deletions
@@ -15,7 +15,7 @@
 from crawlee._utils.lru_cache import LRUCache
 from crawlee._utils.requests import unique_key_to_request_id
 from crawlee._utils.wait import wait_for_all_tasks_for_finish
-from crawlee.base_storage_client._models import ProcessedRequest, RequestQueueMetadata
+from crawlee.base_storage_client import BaseStorageClient, ProcessedRequest, RequestQueueMetadata
 from crawlee.events._types import Event
 from crawlee.storages._base_storage import BaseStorage
 from crawlee.storages._request_provider import RequestProvider
@@ -24,7 +24,7 @@
     from collections.abc import Sequence
 
     from crawlee._request import Request
-
+    from crawlee.configuration import Configuration
 
 logger = getLogger(__name__)
 
@@ -104,10 +104,9 @@ class RequestQueue(BaseStorage, RequestProvider):
     _STORAGE_CONSISTENCY_DELAY = timedelta(seconds=3)
     """Expected delay for storage to achieve consistency, guiding the timing of subsequent read operations."""
 
-    def __init__(self, id: str, name: str | None) -> None:
+    def __init__(self, id: str, name: str | None, storage_client: BaseStorageClient) -> None:
         config = service_container.get_configuration()
         event_manager = service_container.get_event_manager()
-        storage_client = service_container.get_storage_client()
 
         self._id = id
         self._name = name
@@ -148,10 +147,26 @@ def name(self) -> str | None:
 
     @override
     @classmethod
-    async def open(cls, *, id: str | None = None, name: str | None = None) -> RequestQueue:
+    async def open(
+        cls,
+        *,
+        id: str | None = None,
+        name: str | None = None,
+        configuration: Configuration | None = None,
+        storage_client: BaseStorageClient | None = None,
+    ) -> RequestQueue:
         from crawlee.storages._creation_management import open_storage
 
-        return await open_storage(storage_class=cls, id=id, name=name)
+        configuration = configuration or service_container.get_configuration()
+        storage_client = storage_client or service_container.get_storage_client()
+
+        return await open_storage(
+            storage_class=cls,
+            id=id,
+            name=name,
+            configuration=configuration,
+            storage_client=storage_client,
+        )
 
     @override
     async def drop(self, *, timeout: timedelta | None = None) -> None:
@@ -204,7 +219,7 @@ async def add_request(
         ):
             self._assumed_total_count += 1
 
-        return processed_request  # type: ignore[no-any-return]  # Mypy is broken
+        return processed_request
 
     @override
     async def add_requests_batched(
@@ -260,7 +275,7 @@ async def get_request(self, request_id: str) -> Request | None:
         Returns:
             The retrieved request, or `None`, if it does not exist.
         """
-        return await self._resource_client.get_request(request_id)  # type: ignore[no-any-return]  # Mypy is broken
+        return await self._resource_client.get_request(request_id)
 
     async def fetch_next_request(self) -> Request | None:
         """Return the next request in the queue to be processed.
@@ -373,7 +388,7 @@ async def mark_request_as_handled(self, request: Request) -> ProcessedRequest |
             self._assumed_handled_count += 1
 
         self._cache_request(unique_key_to_request_id(request.unique_key), processed_request)
-        return processed_request  # type: ignore[no-any-return]  # Mypy is broken
+        return processed_request
 
     async def reclaim_request(
         self,
@@ -417,7 +432,7 @@ async def reclaim_request(
             except Exception as err:
                 logger.debug(f'Failed to delete request lock for request {request.id}', exc_info=err)
 
-        return processed_request  # type: ignore[no-any-return]  # Mypy is broken
+        return processed_request
 
     async def is_empty(self) -> bool:
         """Check whether the queue is empty.
@@ -483,7 +498,7 @@ async def is_finished(self) -> bool:
 
     async def get_info(self) -> RequestQueueMetadata | None:
         """Get an object containing general information about the request queue."""
-        return await self._resource_client.get()  # type: ignore[no-any-return]  # Mypy is broken
+        return await self._resource_client.get()
 
     @override
     async def get_handled_count(self) -> int:
@@ -658,7 +673,7 @@ async def _prolong_request_lock(self, request_id: str) -> datetime | None:
             )
             return None
         else:
-            return res.lock_expires_at  # type: ignore[no-any-return]  # Mypy is broken
+            return res.lock_expires_at
 
     async def _clear_possible_locks(self) -> None:
         self._queue_paused_for_migration = True
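
RequestQueue.open() gains the same keyword arguments, while __init__ still pulls the configuration and event manager from the service container. A sketch of opening a queue against an explicitly provided client and draining it with the methods touched by this diff; Request.from_url() and the public import of Request from crawlee are assumptions, not part of this commit:

from crawlee import Request
from crawlee.memory_storage_client import MemoryStorageClient
from crawlee.storages import RequestQueue


async def drain_queue() -> None:
    # Open the default queue against an explicitly provided storage client.
    rq = await RequestQueue.open(storage_client=MemoryStorageClient())
    await rq.add_request(Request.from_url('https://crawlee.dev'))

    while not await rq.is_finished():
        request = await rq.fetch_next_request()
        if request is None:
            break
        # ... process the request here ...
        await rq.mark_request_as_handled(request)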
