diff --git a/src/crawlee/_utils/recurring_task.py b/src/crawlee/_utils/recurring_task.py
index 8ffc71d144..d0c20249e9 100644
--- a/src/crawlee/_utils/recurring_task.py
+++ b/src/crawlee/_utils/recurring_task.py
@@ -7,6 +7,9 @@
 if TYPE_CHECKING:
     from collections.abc import Callable
     from datetime import timedelta
+    from types import TracebackType
+
+    from typing_extensions import Self
 
 logger = getLogger(__name__)
 
@@ -26,6 +29,18 @@ def __init__(self, func: Callable, delay: timedelta) -> None:
         self.delay = delay
         self.task: asyncio.Task | None = None
 
+    async def __aenter__(self) -> Self:
+        self.start()
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_value: BaseException | None,
+        exc_traceback: TracebackType | None,
+    ) -> None:
+        await self.stop()
+
     async def _wrapper(self) -> None:
         """Continuously execute the provided function with the specified delay.
diff --git a/src/crawlee/crawlers/_basic/_basic_crawler.py b/src/crawlee/crawlers/_basic/_basic_crawler.py
index cdf7e527e2..28544ad607 100644
--- a/src/crawlee/crawlers/_basic/_basic_crawler.py
+++ b/src/crawlee/crawlers/_basic/_basic_crawler.py
@@ -56,7 +56,7 @@
     SessionError,
     UserDefinedErrorHandlerError,
 )
-from crawlee.events._types import Event, EventCrawlerStatusData
+from crawlee.events._types import Event, EventCrawlerStatusData, EventPersistStateData
 from crawlee.http_clients import ImpitHttpClient
 from crawlee.router import Router
 from crawlee.sessions import SessionPool
@@ -437,14 +437,23 @@ def __init__(
         self._statistics_log_format = statistics_log_format
 
         # Statistics
-        self._statistics = statistics or cast(
-            'Statistics[TStatisticsState]',
-            Statistics.with_default_state(
-                periodic_message_logger=self._logger,
-                statistics_log_format=self._statistics_log_format,
-                log_message='Current request statistics:',
-            ),
-        )
+        if statistics:
+            self._statistics = statistics
+        else:
+
+            async def persist_state_factory() -> KeyValueStore:
+                return await self.get_key_value_store()
+
+            self._statistics = cast(
+                'Statistics[TStatisticsState]',
+                Statistics.with_default_state(
+                    persistence_enabled=True,
+                    periodic_message_logger=self._logger,
+                    statistics_log_format=self._statistics_log_format,
+                    log_message='Current request statistics:',
+                    persist_state_kvs_factory=persist_state_factory,
+                ),
+            )
 
         # Additional context managers to enter and exit
         self._additional_context_managers = _additional_context_managers or []
@@ -689,7 +698,6 @@ def sigint_handler() -> None:
         except CancelledError:
             pass
         finally:
-            await self._crawler_state_rec_task.stop()
             if threading.current_thread() is threading.main_thread():
                 with suppress(NotImplementedError):
                     asyncio.get_running_loop().remove_signal_handler(signal.SIGINT)
@@ -721,8 +729,6 @@ def sigint_handler() -> None:
     async def _run_crawler(self) -> None:
         event_manager = self._service_locator.get_event_manager()
 
-        self._crawler_state_rec_task.start()
-
         # Collect the context managers to be entered. Context managers that are already active are excluded,
         # as they were likely entered by the caller, who will also be responsible for exiting them.
         contexts_to_enter = [
@@ -733,6 +739,7 @@ async def _run_crawler(self) -> None:
                 self._statistics,
                 self._session_pool if self._use_session_pool else None,
                 self._http_client,
+                self._crawler_state_rec_task,
                 *self._additional_context_managers,
             )
             if cm and getattr(cm, 'active', False) is False
@@ -744,6 +751,9 @@ async def _run_crawler(self) -> None:
 
         await self._autoscaled_pool.run()
 
+        # Emit PERSIST_STATE event when crawler is finishing to allow listeners to persist their state if needed
+        event_manager.emit(event=Event.PERSIST_STATE, event_data=EventPersistStateData(is_migrating=False))
+
     async def add_requests(
         self,
         requests: Sequence[str | Request],
diff --git a/src/crawlee/statistics/_statistics.py b/src/crawlee/statistics/_statistics.py
index 3e95932c2a..68b4ff6551 100644
--- a/src/crawlee/statistics/_statistics.py
+++ b/src/crawlee/statistics/_statistics.py
@@ -96,7 +96,7 @@ def __init__(
 
         self._state = RecoverableState(
             default_state=state_model(stats_id=self._id),
-            persist_state_key=persist_state_key or f'SDK_CRAWLER_STATISTICS_{self._id}',
+            persist_state_key=persist_state_key or f'__CRAWLER_STATISTICS_{self._id}',
             persistence_enabled=persistence_enabled,
             persist_state_kvs_name=persist_state_kvs_name,
             persist_state_kvs_factory=persist_state_kvs_factory,
@@ -130,6 +130,7 @@ def with_default_state(
         persistence_enabled: bool = False,
         persist_state_kvs_name: str | None = None,
         persist_state_key: str | None = None,
+        persist_state_kvs_factory: Callable[[], Coroutine[None, None, KeyValueStore]] | None = None,
         log_message: str = 'Statistics',
         periodic_message_logger: Logger | None = None,
         log_interval: timedelta = timedelta(minutes=1),
@@ -141,6 +142,7 @@ def with_default_state(
             persistence_enabled=persistence_enabled,
             persist_state_kvs_name=persist_state_kvs_name,
             persist_state_key=persist_state_key,
+            persist_state_kvs_factory=persist_state_kvs_factory,
             log_message=log_message,
             periodic_message_logger=periodic_message_logger,
             log_interval=log_interval,
@@ -187,7 +189,10 @@ async def __aexit__(
         if not self._active:
             raise RuntimeError(f'The {self.__class__.__name__} is not active.')
 
-        self._state.current_value.crawler_finished_at = datetime.now(timezone.utc)
+        if not self.state.crawler_last_started_at:
+            raise RuntimeError('Statistics.state.crawler_last_started_at not set.')
+        self.state.crawler_finished_at = datetime.now(timezone.utc)
+        self.state.crawler_runtime += self.state.crawler_finished_at - self.state.crawler_last_started_at
 
         await self._state.teardown()
@@ -255,8 +260,7 @@ def calculate(self) -> FinalStatistics:
         if self._instance_start is None:
            raise RuntimeError('The Statistics object is not initialized')
 
-        crawler_runtime = datetime.now(timezone.utc) - self._instance_start
-        total_minutes = crawler_runtime.total_seconds() / 60
+        total_minutes = self.state.crawler_runtime.total_seconds() / 60
         state = self._state.current_value
         serialized_state = state.model_dump(by_alias=False)
@@ -267,7 +271,7 @@ def calculate(self) -> FinalStatistics:
             requests_failed_per_minute=math.floor(state.requests_failed / total_minutes) if total_minutes else 0,
             request_total_duration=state.request_total_finished_duration + state.request_total_failed_duration,
             requests_total=state.requests_failed + state.requests_finished,
-            crawler_runtime=crawler_runtime,
+            crawler_runtime=state.crawler_runtime,
             requests_finished=state.requests_finished,
             requests_failed=state.requests_failed,
             retry_histogram=serialized_state['request_retry_histogram'],
diff --git a/tests/unit/crawlers/_basic/test_basic_crawler.py b/tests/unit/crawlers/_basic/test_basic_crawler.py
index 09b951ca35..b2b75e50f7 100644
--- a/tests/unit/crawlers/_basic/test_basic_crawler.py
+++ b/tests/unit/crawlers/_basic/test_basic_crawler.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import asyncio
+import concurrent.futures
 import json
 import logging
 import os
@@ -1643,3 +1644,60 @@ async def handler(context: BasicCrawlingContext) -> None:
 
     # Crawler should not fall back to the default storage after the purge
     assert await unrelated_rq.fetch_next_request() == unrelated_request
+
+
+async def _run_crawler(requests: list[str], storage_dir: str) -> StatisticsState:
+    """Run a crawler and return its statistics state.
+
+    Must be defined at module level so that it is picklable for ProcessPoolExecutor."""
+    service_locator.set_configuration(
+        Configuration(
+            crawlee_storage_dir=storage_dir,  # type: ignore[call-arg]
+            purge_on_start=False,
+        )
+    )
+
+    async def request_handler(context: BasicCrawlingContext) -> None:
+        context.log.info(f'Processing {context.request.url} ...')
+
+    crawler = BasicCrawler(
+        request_handler=request_handler,
+        concurrency_settings=ConcurrencySettings(max_concurrency=1, desired_concurrency=1),
+    )
+
+    await crawler.run(requests)
+    return crawler.statistics.state
+
+
+def _process_run_crawler(requests: list[str], storage_dir: str) -> StatisticsState:
+    return asyncio.run(_run_crawler(requests=requests, storage_dir=storage_dir))
+
+
+async def test_crawler_statistics_persistence(tmp_path: Path) -> None:
+    """Test that crawler statistics are persisted and loaded correctly.
+
+    This test simulates starting the crawler process twice and checks that the statistics include the first run."""
+
+    with concurrent.futures.ProcessPoolExecutor() as executor:
+        # Crawl 2 requests in the first run and automatically persist the state.
+        first_run_state = executor.submit(
+            _process_run_crawler,
+            requests=['https://a.placeholder.com', 'https://b.placeholder.com'],
+            storage_dir=str(tmp_path),
+        ).result()
+        assert first_run_state.requests_finished == 2
+
+    # Do not reuse the executor, to simulate a fresh process and avoid modified class attributes.
+    with concurrent.futures.ProcessPoolExecutor() as executor:
+        # Crawl 1 additional request in the second run, but use the previously auto-persisted state.
+        second_run_state = executor.submit(
+            _process_run_crawler, requests=['https://c.placeholder.com'], storage_dir=str(tmp_path)
+        ).result()
+        assert second_run_state.requests_finished == 3
+
+    assert first_run_state.crawler_started_at == second_run_state.crawler_started_at
+    assert first_run_state.crawler_finished_at
+    assert second_run_state.crawler_finished_at
+
+    assert first_run_state.crawler_finished_at < second_run_state.crawler_finished_at
+    assert first_run_state.crawler_runtime < second_run_state.crawler_runtime
diff --git a/tests/unit/test_configuration.py b/tests/unit/test_configuration.py
index f89401e5be..4be4309247 100644
--- a/tests/unit/test_configuration.py
+++ b/tests/unit/test_configuration.py
@@ -8,6 +8,7 @@
 from crawlee import service_locator
 from crawlee.configuration import Configuration
 from crawlee.crawlers import HttpCrawler, HttpCrawlingContext
+from crawlee.statistics import Statistics
 from crawlee.storage_clients import MemoryStorageClient
 from crawlee.storage_clients._file_system._storage_client import FileSystemStorageClient
@@ -35,15 +36,40 @@ def test_global_configuration_works_reversed() -> None:
     )
 
 
-async def test_storage_not_persisted_when_disabled(tmp_path: Path, server_url: URL) -> None:
+async def test_storage_not_persisted_when_non_persistable_storage_used(tmp_path: Path, server_url: URL) -> None:
+    """Make the Crawler use MemoryStorageClient, which can't persist state."""
+    service_locator.set_configuration(
+        Configuration(
+            crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
+        )
+    )
+    crawler = HttpCrawler(storage_client=MemoryStorageClient())
+
+    @crawler.router.default_handler
+    async def default_handler(context: HttpCrawlingContext) -> None:
+        await context.push_data({'url': context.request.url})
+
+    await crawler.run([str(server_url)])
+
+    # Verify that no files were created in the storage directory.
+    content = list(tmp_path.iterdir())
+    assert content == [], 'Expected the storage directory to be empty, but it is not.'
+
+
+async def test_storage_persisted_with_explicit_statistics_with_persistable_storage(
+    tmp_path: Path, server_url: URL
+) -> None:
+    """Make the Crawler use MemoryStorageClient, which can't persist state,
+    but pass it explicit statistics backed by the global FileSystemStorageClient(), which can persist state."""
+
     configuration = Configuration(
         crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
     )
-    storage_client = MemoryStorageClient()
+    service_locator.set_configuration(configuration)
+    service_locator.set_storage_client(FileSystemStorageClient())
     crawler = HttpCrawler(
-        configuration=configuration,
-        storage_client=storage_client,
+        storage_client=MemoryStorageClient(), statistics=Statistics.with_default_state(persistence_enabled=True)
     )
 
     @crawler.router.default_handler
     async def default_handler(context: HttpCrawlingContext) -> None:
@@ -52,9 +78,9 @@ async def default_handler(context: HttpCrawlingContext) -> None:
 
     await crawler.run([str(server_url)])
 
-    # Verify that no files were created in the storage directory.
+    # Verify that files were created in the storage directory.
     content = list(tmp_path.iterdir())
-    assert content == [], 'Expected the storage directory to be empty, but it is not.'
+    assert content != [], 'Expected the storage directory to contain files, but it does not.'
 
 
 async def test_storage_persisted_when_enabled(tmp_path: Path, server_url: URL) -> None:
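
Below is a minimal usage sketch (not part of the patch) of the two building blocks the diff introduces: RecurringTask as an async context manager, and Statistics.with_default_state() with a caller-supplied persist_state_kvs_factory. The heartbeat callback, the one-second delay, and the KeyValueStore.open() call are illustrative assumptions; the class name RecurringTask and the public KeyValueStore API come from the crawlee codebase rather than from this diff.

```python
# Sketch only, assuming crawlee's public KeyValueStore API and the RecurringTask
# class edited above; the callback and delay are arbitrary examples.
import asyncio
from datetime import timedelta

from crawlee._utils.recurring_task import RecurringTask
from crawlee.statistics import Statistics
from crawlee.storages import KeyValueStore


async def heartbeat() -> None:
    print('tick')  # stand-in for periodic work, e.g. emitting a crawler status event


async def main() -> None:
    # New in this patch: RecurringTask supports `async with`, calling start() on
    # enter and awaiting stop() on exit.
    async with RecurringTask(func=heartbeat, delay=timedelta(seconds=1)):
        await asyncio.sleep(3)

    # Statistics persisted through a caller-supplied key-value store factory,
    # the same mechanism BasicCrawler now uses for its default statistics.
    statistics = Statistics.with_default_state(
        persistence_enabled=True,
        persist_state_kvs_factory=lambda: KeyValueStore.open(),
    )
    async with statistics:
        ...  # request handling would be recorded here

    # After exit, crawler_runtime reflects the accumulated runtime across runs.
    print(statistics.state.crawler_runtime)


if __name__ == '__main__':
    asyncio.run(main())
```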