diff --git a/.gitignore b/.gitignore
index 43fc7bc882..c46192cf4f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -13,6 +13,9 @@ __pycache__
 # Poetry
 poetry.toml
 
+# Other Python tools
+.ropeproject
+
 # Mise
 mise.toml
 .mise.toml
diff --git a/docs/guides/code/request_storage/purge_explicitly_example.py b/docs/guides/code/request_storage/purge_explicitly_example.py
index 69d7a9ef97..22e2ae7769 100644
--- a/docs/guides/code/request_storage/purge_explicitly_example.py
+++ b/docs/guides/code/request_storage/purge_explicitly_example.py
@@ -4,7 +4,7 @@
 
 
 async def main() -> None:
-    storage_client = MemoryStorageClient()
+    storage_client = MemoryStorageClient.from_config()
     # highlight-next-line
     await storage_client.purge_on_start()
 
diff --git a/docs/upgrading/upgrading_to_v0x.md b/docs/upgrading/upgrading_to_v0x.md
index d5f861cf45..92f9a31c16 100644
--- a/docs/upgrading/upgrading_to_v0x.md
+++ b/docs/upgrading/upgrading_to_v0x.md
@@ -10,9 +10,18 @@ This page summarizes the breaking changes between Crawlee for Python zero-based
 This section summarizes the breaking changes between v0.4.x and v0.5.0.
 
 ### BeautifulSoupParser
+
 - Renamed `BeautifulSoupParser` to `BeautifulSoupParserType`. Probably used only in type hints. Please replace previous usages of `BeautifulSoupParser` by `BeautifulSoupParserType`.
 - `BeautifulSoupParser` is now a new class that is used in refactored class `BeautifulSoupCrawler`.
 
+### Service locator
+
+- The `crawlee.service_container` module was completely refactored and renamed to `crawlee.service_locator`.
+
+### Statistics
+
+- The `crawlee.statistics.Statistics` class no longer accepts an event manager as an input argument. It uses the default, global one.
+
 ## Upgrading to v0.4
 
 This section summarizes the breaking changes between v0.3.x and v0.4.0.
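To make the two v0.5.0 breaking changes above concrete, here is a minimal, untested migration sketch based only on the APIs visible in this diff (the `Configuration()` arguments and `my_event_manager` name are illustrative):

```python
from crawlee import service_locator
from crawlee.configuration import Configuration
from crawlee.statistics import Statistics

# v0.4.x (removed):
#     from crawlee import service_container
#     service_container.set_configuration(Configuration())
#     statistics = Statistics(event_manager=my_event_manager)

# v0.5.0: register services once through the global service locator...
service_locator.set_configuration(Configuration())

# ...and let `Statistics` resolve the default, global event manager itself.
statistics = Statistics(persistence_enabled=True)
```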
diff --git a/poetry.lock b/poetry.lock
index 83ab86638c..ae49822ed3 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand.
 
 [[package]]
 name = "annotated-types"
@@ -2123,18 +2123,18 @@ type = ["mypy (>=1.11.2)"]
 
 [[package]]
 name = "playwright"
-version = "1.49.0"
+version = "1.49.1"
 description = "A high-level API to automate web browsers"
 optional = true
 python-versions = ">=3.9"
 files = [
-    {file = "playwright-1.49.0-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:704532a2d8ba580ec9e1895bfeafddce2e3d52320d4eb8aa38e80376acc5cbb0"},
-    {file = "playwright-1.49.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:e453f02c4e5cc2db7e9759c47e7425f32e50ac76c76b7eb17c69eed72f01c4d8"},
-    {file = "playwright-1.49.0-py3-none-macosx_11_0_universal2.whl", hash = "sha256:37ae985309184472946a6eb1a237e5d93c9e58a781fa73b75c8751325002a5d4"},
-    {file = "playwright-1.49.0-py3-none-manylinux1_x86_64.whl", hash = "sha256:68d94beffb3c9213e3ceaafa66171affd9a5d9162e0c8a3eed1b1132c2e57598"},
-    {file = "playwright-1.49.0-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7f12d2aecdb41fc25a624cb15f3e8391c252ebd81985e3d5c1c261fe93779345"},
-    {file = "playwright-1.49.0-py3-none-win32.whl", hash = "sha256:91103de52d470594ad375b512d7143fa95d6039111ae11a93eb4fe2f2b4a4858"},
-    {file = "playwright-1.49.0-py3-none-win_amd64.whl", hash = "sha256:34d28a2c2d46403368610be4339898dc9c34eb9f7c578207b4715c49743a072a"},
+    {file = "playwright-1.49.1-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:1041ffb45a0d0bc44d698d3a5aa3ac4b67c9bd03540da43a0b70616ad52592b8"},
+    {file = "playwright-1.49.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:9f38ed3d0c1f4e0a6d1c92e73dd9a61f8855133249d6f0cec28648d38a7137be"},
+    {file = "playwright-1.49.1-py3-none-macosx_11_0_universal2.whl", hash = "sha256:3be48c6d26dc819ca0a26567c1ae36a980a0303dcd4249feb6f59e115aaddfb8"},
+    {file = "playwright-1.49.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:753ca90ee31b4b03d165cfd36e477309ebf2b4381953f2a982ff612d85b147d2"},
+    {file = "playwright-1.49.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cd9bc8dab37aa25198a01f555f0a2e2c3813fe200fef018ac34dfe86b34994b9"},
+    {file = "playwright-1.49.1-py3-none-win32.whl", hash = "sha256:43b304be67f096058e587dac453ece550eff87b8fbed28de30f4f022cc1745bb"},
+    {file = "playwright-1.49.1-py3-none-win_amd64.whl", hash = "sha256:47b23cb346283278f5b4d1e1990bcb6d6302f80c0aa0ca93dd0601a1400191df"},
 ]
 
 [package.dependencies]
@@ -2877,29 +2877,29 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"]
 
 [[package]]
 name = "ruff"
-version = "0.8.2"
+version = "0.8.3"
 description = "An extremely fast Python linter and code formatter, written in Rust."
 optional = false
 python-versions = ">=3.7"
 files = [
-    {file = "ruff-0.8.2-py3-none-linux_armv6l.whl", hash = "sha256:c49ab4da37e7c457105aadfd2725e24305ff9bc908487a9bf8d548c6dad8bb3d"},
-    {file = "ruff-0.8.2-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:ec016beb69ac16be416c435828be702ee694c0d722505f9c1f35e1b9c0cc1bf5"},
-    {file = "ruff-0.8.2-py3-none-macosx_11_0_arm64.whl", hash = "sha256:f05cdf8d050b30e2ba55c9b09330b51f9f97d36d4673213679b965d25a785f3c"},
-    {file = "ruff-0.8.2-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:60f578c11feb1d3d257b2fb043ddb47501ab4816e7e221fbb0077f0d5d4e7b6f"},
-    {file = "ruff-0.8.2-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:cbd5cf9b0ae8f30eebc7b360171bd50f59ab29d39f06a670b3e4501a36ba5897"},
-    {file = "ruff-0.8.2-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b402ddee3d777683de60ff76da801fa7e5e8a71038f57ee53e903afbcefdaa58"},
-    {file = "ruff-0.8.2-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:705832cd7d85605cb7858d8a13d75993c8f3ef1397b0831289109e953d833d29"},
-    {file = "ruff-0.8.2-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:32096b41aaf7a5cc095fa45b4167b890e4c8d3fd217603f3634c92a541de7248"},
-    {file = "ruff-0.8.2-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e769083da9439508833cfc7c23e351e1809e67f47c50248250ce1ac52c21fb93"},
-    {file = "ruff-0.8.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fe716592ae8a376c2673fdfc1f5c0c193a6d0411f90a496863c99cd9e2ae25d"},
-    {file = "ruff-0.8.2-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:81c148825277e737493242b44c5388a300584d73d5774defa9245aaef55448b0"},
-    {file = "ruff-0.8.2-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:d261d7850c8367704874847d95febc698a950bf061c9475d4a8b7689adc4f7fa"},
-    {file = "ruff-0.8.2-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1ca4e3a87496dc07d2427b7dd7ffa88a1e597c28dad65ae6433ecb9f2e4f022f"},
-    {file = "ruff-0.8.2-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:729850feed82ef2440aa27946ab39c18cb4a8889c1128a6d589ffa028ddcfc22"},
-    {file = "ruff-0.8.2-py3-none-win32.whl", hash = "sha256:ac42caaa0411d6a7d9594363294416e0e48fc1279e1b0e948391695db2b3d5b1"},
-    {file = "ruff-0.8.2-py3-none-win_amd64.whl", hash = "sha256:2aae99ec70abf43372612a838d97bfe77d45146254568d94926e8ed5bbb409ea"},
-    {file = "ruff-0.8.2-py3-none-win_arm64.whl", hash = "sha256:fb88e2a506b70cfbc2de6fae6681c4f944f7dd5f2fe87233a7233d888bad73e8"},
-    {file = "ruff-0.8.2.tar.gz", hash = "sha256:b84f4f414dda8ac7f75075c1fa0b905ac0ff25361f42e6d5da681a465e0f78e5"},
+    {file = "ruff-0.8.3-py3-none-linux_armv6l.whl", hash = "sha256:8d5d273ffffff0acd3db5bf626d4b131aa5a5ada1276126231c4174543ce20d6"},
+    {file = "ruff-0.8.3-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e4d66a21de39f15c9757d00c50c8cdd20ac84f55684ca56def7891a025d7e939"},
+    {file = "ruff-0.8.3-py3-none-macosx_11_0_arm64.whl", hash = "sha256:c356e770811858bd20832af696ff6c7e884701115094f427b64b25093d6d932d"},
+    {file = "ruff-0.8.3-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c0a60a825e3e177116c84009d5ebaa90cf40dfab56e1358d1df4e29a9a14b13"},
+    {file = "ruff-0.8.3-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:75fb782f4db39501210ac093c79c3de581d306624575eddd7e4e13747e61ba18"},
+    {file = "ruff-0.8.3-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f26bc76a133ecb09a38b7868737eded6941b70a6d34ef53a4027e83913b6502"},
"sha256:7f26bc76a133ecb09a38b7868737eded6941b70a6d34ef53a4027e83913b6502"}, + {file = "ruff-0.8.3-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:01b14b2f72a37390c1b13477c1c02d53184f728be2f3ffc3ace5b44e9e87b90d"}, + {file = "ruff-0.8.3-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:53babd6e63e31f4e96ec95ea0d962298f9f0d9cc5990a1bbb023a6baf2503a82"}, + {file = "ruff-0.8.3-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1ae441ce4cf925b7f363d33cd6570c51435972d697e3e58928973994e56e1452"}, + {file = "ruff-0.8.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d7c65bc0cadce32255e93c57d57ecc2cca23149edd52714c0c5d6fa11ec328cd"}, + {file = "ruff-0.8.3-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5be450bb18f23f0edc5a4e5585c17a56ba88920d598f04a06bd9fd76d324cb20"}, + {file = "ruff-0.8.3-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8faeae3827eaa77f5721f09b9472a18c749139c891dbc17f45e72d8f2ca1f8fc"}, + {file = "ruff-0.8.3-py3-none-musllinux_1_2_i686.whl", hash = "sha256:db503486e1cf074b9808403991663e4277f5c664d3fe237ee0d994d1305bb060"}, + {file = "ruff-0.8.3-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:6567be9fb62fbd7a099209257fef4ad2c3153b60579818b31a23c886ed4147ea"}, + {file = "ruff-0.8.3-py3-none-win32.whl", hash = "sha256:19048f2f878f3ee4583fc6cb23fb636e48c2635e30fb2022b3a1cd293402f964"}, + {file = "ruff-0.8.3-py3-none-win_amd64.whl", hash = "sha256:f7df94f57d7418fa7c3ffb650757e0c2b96cf2501a0b192c18e4fb5571dfada9"}, + {file = "ruff-0.8.3-py3-none-win_arm64.whl", hash = "sha256:fe2756edf68ea79707c8d68b78ca9a58ed9af22e430430491ee03e718b5e4936"}, + {file = "ruff-0.8.3.tar.gz", hash = "sha256:5e7558304353b84279042fc584a4f4cb8a07ae79b2bf3da1a7551d960b5626d3"}, ] [[package]] @@ -3670,4 +3670,4 @@ playwright = ["playwright"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "2ef3de5474613439f8aad7c8dcb3e553a0d9250e2f192e14a310684c966fa5cf" +content-hash = "ba84ac0e96fc33b777b9c097e3e905b6493fdf1288cb38d56ff775984b9817e3" diff --git a/pyproject.toml b/pyproject.toml index 2648144e5f..44ffef1327 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,13 +60,15 @@ parsel = { version = ">=1.9.0", optional = true } playwright = { version = ">=1.27.0", optional = true } psutil = ">=6.0.0" pydantic = ">=2.8.1, !=2.10.0, !=2.10.1, !=2.10.2" -pydantic-settings = ">=2.2.0" +# TODO: relax the upper bound once the issue is resolved: +# https://github.com/apify/crawlee-python/issues/814 +pydantic-settings = ">=2.2.0 <2.7.0" pyee = ">=9.0.0" sortedcollections = ">=2.1.0" tldextract = ">=5.1.0" typer = ">=0.12.0" typing-extensions = ">=4.1.0" -yarl = "^1.18.0" +yarl = ">=1.18.0" [tool.poetry.group.dev.dependencies] build = "~1.2.0" @@ -206,9 +208,9 @@ warn_unused_ignores = true [[tool.mypy.overrides]] # Example codes are sometimes showing integration of crawlee with external tool, which is not dependency of crawlee. -module =[ - "apify", # Example code shows integration of apify and crawlee. - "camoufox" # Example code shows integration of camoufox and crawlee. +module = [ + "apify", # Example code shows integration of apify and crawlee. + "camoufox", # Example code shows integration of camoufox and crawlee. 
 ]
 ignore_missing_imports = true
diff --git a/src/crawlee/__init__.py b/src/crawlee/__init__.py
index 11ec50612c..6b7081dc50 100644
--- a/src/crawlee/__init__.py
+++ b/src/crawlee/__init__.py
@@ -1,9 +1,10 @@
 from importlib import metadata
 
 from ._request import Request
+from ._service_locator import service_locator
 from ._types import ConcurrencySettings, EnqueueStrategy, HttpHeaders
 from ._utils.globs import Glob
 
 __version__ = metadata.version('crawlee')
 
-__all__ = ['ConcurrencySettings', 'EnqueueStrategy', 'Glob', 'HttpHeaders', 'Request']
+__all__ = ['ConcurrencySettings', 'EnqueueStrategy', 'Glob', 'HttpHeaders', 'Request', 'service_locator']
diff --git a/src/crawlee/_autoscaling/snapshotter.py b/src/crawlee/_autoscaling/snapshotter.py
index 9e95254353..62e2c4d2df 100644
--- a/src/crawlee/_autoscaling/snapshotter.py
+++ b/src/crawlee/_autoscaling/snapshotter.py
@@ -10,13 +10,8 @@
 import psutil
 from sortedcontainers import SortedList
 
-from crawlee._autoscaling.types import (
-    ClientSnapshot,
-    CpuSnapshot,
-    EventLoopSnapshot,
-    MemorySnapshot,
-    Snapshot,
-)
+from crawlee import service_locator
+from crawlee._autoscaling.types import ClientSnapshot, CpuSnapshot, EventLoopSnapshot, MemorySnapshot, Snapshot
 from crawlee._utils.byte_size import ByteSize
 from crawlee._utils.context import ensure_context
 from crawlee._utils.docs import docs_group
@@ -26,8 +21,6 @@
 if TYPE_CHECKING:
     from types import TracebackType
 
-    from crawlee.events import EventManager
-
 logger = getLogger(__name__)
 
 T = TypeVar('T')
@@ -45,7 +38,6 @@ class Snapshotter:
 
     def __init__(
         self,
-        event_manager: EventManager,
         *,
         event_loop_snapshot_interval: timedelta = timedelta(milliseconds=500),
         client_snapshot_interval: timedelta = timedelta(milliseconds=1000),
@@ -63,8 +55,6 @@ def __init__(
         """A default constructor.
 
         Args:
-            event_manager: The event manager used to emit system info events. From data provided by this event
-                the CPU and memory usage are read.
             event_loop_snapshot_interval: The interval at which the event loop is sampled.
             client_snapshot_interval: The interval at which the client is sampled.
             max_used_cpu_ratio: Sets the ratio, defining the maximum CPU usage. When the CPU usage is higher than
@@ -90,7 +80,6 @@
         if available_memory_ratio is None and max_memory_size is None:
             raise ValueError('At least one of `available_memory_ratio` or `max_memory_size` must be specified')
 
-        self._event_manager = event_manager
         self._event_loop_snapshot_interval = event_loop_snapshot_interval
         self._client_snapshot_interval = client_snapshot_interval
         self._max_event_loop_delay = max_event_loop_delay
@@ -145,8 +134,9 @@ async def __aenter__(self) -> Snapshotter:
             raise RuntimeError(f'The {self.__class__.__name__} is already active.')
 
         self._active = True
-        self._event_manager.on(event=Event.SYSTEM_INFO, listener=self._snapshot_cpu)
-        self._event_manager.on(event=Event.SYSTEM_INFO, listener=self._snapshot_memory)
+        event_manager = service_locator.get_event_manager()
+        event_manager.on(event=Event.SYSTEM_INFO, listener=self._snapshot_cpu)
+        event_manager.on(event=Event.SYSTEM_INFO, listener=self._snapshot_memory)
         self._snapshot_event_loop_task.start()
         self._snapshot_client_task.start()
         return self
@@ -168,8 +158,9 @@ async def __aexit__(
         if not self._active:
             raise RuntimeError(f'The {self.__class__.__name__} is not active.')
 
-        self._event_manager.off(event=Event.SYSTEM_INFO, listener=self._snapshot_cpu)
-        self._event_manager.off(event=Event.SYSTEM_INFO, listener=self._snapshot_memory)
+        event_manager = service_locator.get_event_manager()
+        event_manager.off(event=Event.SYSTEM_INFO, listener=self._snapshot_cpu)
+        event_manager.off(event=Event.SYSTEM_INFO, listener=self._snapshot_memory)
         await self._snapshot_event_loop_task.stop()
         await self._snapshot_client_task.stop()
         self._active = False
diff --git a/src/crawlee/_log_config.py b/src/crawlee/_log_config.py
index 762a4e6e28..a77d9409a9 100644
--- a/src/crawlee/_log_config.py
+++ b/src/crawlee/_log_config.py
@@ -4,13 +4,12 @@
 import logging
 import sys
 import textwrap
-from typing import TYPE_CHECKING, Any
+from typing import Any
 
 from colorama import Fore, Style, just_fix_windows_console
 from typing_extensions import assert_never
 
-if TYPE_CHECKING:
-    from crawlee.configuration import Configuration
+from crawlee import service_locator
 
 just_fix_windows_console()
@@ -35,22 +34,24 @@
 _LOG_MESSAGE_INDENT = ' ' * 6
 
 
-def get_configured_log_level(configuration: Configuration) -> int:
-    verbose_logging_requested = 'verbose_log' in configuration.model_fields_set and configuration.verbose_log
+def get_configured_log_level() -> int:
+    config = service_locator.get_configuration()
 
-    if 'log_level' in configuration.model_fields_set:
-        if configuration.log_level == 'DEBUG':
+    verbose_logging_requested = 'verbose_log' in config.model_fields_set and config.verbose_log
+
+    if 'log_level' in config.model_fields_set:
+        if config.log_level == 'DEBUG':
             return logging.DEBUG
-        if configuration.log_level == 'INFO':
+        if config.log_level == 'INFO':
             return logging.INFO
-        if configuration.log_level == 'WARNING':
+        if config.log_level == 'WARNING':
             return logging.WARNING
-        if configuration.log_level == 'ERROR':
+        if config.log_level == 'ERROR':
             return logging.ERROR
-        if configuration.log_level == 'CRITICAL':
+        if config.log_level == 'CRITICAL':
             return logging.CRITICAL
 
-        assert_never(configuration.log_level)
+        assert_never(config.log_level)
 
     if sys.flags.dev_mode or verbose_logging_requested:
         return logging.DEBUG
@@ -58,12 +59,7 @@ def get_configured_log_level(configuration: Configuration) -> int:
     return logging.INFO
 
 
-def configure_logger(
-    logger: logging.Logger,
-    configuration: Configuration,
-    *,
-    remove_old_handlers: bool = False,
-) -> None:
+def configure_logger(logger: logging.Logger, *, remove_old_handlers: bool = False) -> None:
     handler = logging.StreamHandler()
     handler.setFormatter(CrawleeLogFormatter())
 
@@ -72,7 +68,7 @@ def configure_logger(
             logger.removeHandler(old_handler)
 
     logger.addHandler(handler)
-    logger.setLevel(get_configured_log_level(configuration))
+    logger.setLevel(get_configured_log_level())
 
 
 class CrawleeLogFormatter(logging.Formatter):
diff --git a/src/crawlee/_service_locator.py b/src/crawlee/_service_locator.py
new file mode 100644
index 0000000000..2cc8ea2d9f
--- /dev/null
+++ b/src/crawlee/_service_locator.py
@@ -0,0 +1,98 @@
+from __future__ import annotations
+
+from crawlee._utils.docs import docs_group
+from crawlee.base_storage_client._base_storage_client import BaseStorageClient
+from crawlee.configuration import Configuration
+from crawlee.errors import ServiceConflictError
+from crawlee.events._event_manager import EventManager
+
+
+@docs_group('Classes')
+class ServiceLocator:
+    """Service locator for managing the services used by Crawlee.
+
+    All services are lazily initialized to their default values.
+    """
+
+    def __init__(self) -> None:
+        self._configuration: Configuration | None = None
+        self._event_manager: EventManager | None = None
+        self._storage_client: BaseStorageClient | None = None
+
+        # Flags to check if the services were already set.
+        self._configuration_was_set = False
+        self._event_manager_was_set = False
+        self._storage_client_was_set = False
+
+    def get_configuration(self) -> Configuration:
+        """Get the configuration."""
+        if self._configuration is None:
+            self._configuration = Configuration()
+
+        return self._configuration
+
+    def set_configuration(self, configuration: Configuration) -> None:
+        """Set the configuration.
+
+        Args:
+            configuration: The configuration to set.
+
+        Raises:
+            ServiceConflictError: If the configuration was already set.
+        """
+        if self._configuration_was_set:
+            raise ServiceConflictError(Configuration, configuration, self._configuration)
+
+        self._configuration = configuration
+        self._configuration_was_set = True
+
+    def get_event_manager(self) -> EventManager:
+        """Get the event manager."""
+        if self._event_manager is None:
+            from crawlee.events import LocalEventManager
+
+            self._event_manager = LocalEventManager()
+
+        return self._event_manager
+
+    def set_event_manager(self, event_manager: EventManager) -> None:
+        """Set the event manager.
+
+        Args:
+            event_manager: The event manager to set.
+
+        Raises:
+            ServiceConflictError: If the event manager was already set.
+        """
+        if self._event_manager_was_set:
+            raise ServiceConflictError(EventManager, event_manager, self._event_manager)
+
+        self._event_manager = event_manager
+        self._event_manager_was_set = True
+
+    def get_storage_client(self) -> BaseStorageClient:
+        """Get the storage client."""
+        if self._storage_client is None:
+            from crawlee.memory_storage_client import MemoryStorageClient
+
+            self._storage_client = MemoryStorageClient.from_config()
+
+        return self._storage_client
+
+    def set_storage_client(self, storage_client: BaseStorageClient) -> None:
+        """Set the storage client.
+
+        Args:
+            storage_client: The storage client to set.
+
+        Raises:
+            ServiceConflictError: If the storage client was already set.
+        """
+        if self._storage_client_was_set:
+            raise ServiceConflictError(BaseStorageClient, storage_client, self._storage_client)
+
+        self._storage_client = storage_client
+        self._storage_client_was_set = True
+
+
+service_locator = ServiceLocator()
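A quick sketch of how the locator defined above behaves, grounded in the code in this diff: getters lazily create and cache defaults, while setters are guarded by the `*_was_set` flags.

```python
from crawlee import service_locator
from crawlee.configuration import Configuration

# First access lazily creates and caches the default service...
config = service_locator.get_configuration()
# ...so repeated gets return the same instance.
assert service_locator.get_configuration() is config

# An explicit set is allowed once; a second call would raise
# ServiceConflictError (see the refactored error in errors.py below).
service_locator.set_configuration(Configuration())
```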
+ """ + if self._storage_client_was_set: + raise ServiceConflictError(BaseStorageClient, storage_client, self._storage_client) + + self._storage_client = storage_client + self._storage_client_was_set = True + + +service_locator = ServiceLocator() diff --git a/src/crawlee/basic_crawler/_basic_crawler.py b/src/crawlee/basic_crawler/_basic_crawler.py index aeec0031e9..8d971681df 100644 --- a/src/crawlee/basic_crawler/_basic_crawler.py +++ b/src/crawlee/basic_crawler/_basic_crawler.py @@ -18,7 +18,7 @@ from tldextract import TLDExtract from typing_extensions import NotRequired, TypedDict, TypeVar, Unpack, assert_never -from crawlee import EnqueueStrategy, Glob, service_container +from crawlee import EnqueueStrategy, Glob, service_locator from crawlee._autoscaling import AutoscaledPool from crawlee._autoscaling.snapshotter import Snapshotter from crawlee._autoscaling.system_status import SystemStatus @@ -50,6 +50,7 @@ from contextlib import AbstractAsyncContextManager from crawlee._types import ConcurrencySettings, HttpMethod, JsonSerializable + from crawlee.base_storage_client import BaseStorageClient from crawlee.base_storage_client._models import DatasetItemsListPage from crawlee.configuration import Configuration from crawlee.events._event_manager import EventManager @@ -72,17 +73,29 @@ class BasicCrawlerOptions(TypedDict, Generic[TCrawlingContext]): It is intended for typing forwarded `__init__` arguments in the subclasses. """ + configuration: NotRequired[Configuration] + """The configuration object. Some of its properties are used as defaults for the crawler.""" + + event_manager: NotRequired[EventManager] + """The event manager for managing events for the crawler and all its components.""" + + storage_client: NotRequired[BaseStorageClient] + """The storage client for managing storages for the crawler and all its components.""" + request_provider: NotRequired[RequestProvider] """Provider for requests to be processed by the crawler.""" - request_handler: NotRequired[Callable[[TCrawlingContext], Awaitable[None]]] - """A callable responsible for handling requests.""" + session_pool: NotRequired[SessionPool] + """A custom `SessionPool` instance, allowing the use of non-default configuration.""" + + proxy_configuration: NotRequired[ProxyConfiguration] + """HTTP proxy configuration used when making requests.""" http_client: NotRequired[BaseHttpClient] - """HTTP client used by `BasicCrawlingContext.send_request` and the HTTP-based crawling.""" + """HTTP client used by `BasicCrawlingContext.send_request` method.""" - concurrency_settings: NotRequired[ConcurrencySettings] - """Settings to fine-tune concurrency levels.""" + request_handler: NotRequired[Callable[[TCrawlingContext], Awaitable[None]]] + """A callable responsible for handling requests.""" max_request_retries: NotRequired[int] """Maximum number of attempts to process a single request.""" @@ -96,49 +109,45 @@ class BasicCrawlerOptions(TypedDict, Generic[TCrawlingContext]): """Maximum number of session rotations per request. The crawler rotates the session if a proxy error occurs or if the website blocks the request.""" - configuration: NotRequired[Configuration] - """Crawler configuration.""" - - request_handler_timeout: NotRequired[timedelta] - """Maximum duration allowed for a single request handler to run.""" + max_crawl_depth: NotRequired[int | None] + """Specifies the maximum crawl depth. If set, the crawler will stop processing links beyond this depth. 
+ The crawl depth starts at 0 for initial requests and increases with each subsequent level of links. + Requests at the maximum depth will still be processed, but no new links will be enqueued from those requests. + If not set, crawling continues without depth restrictions. + """ use_session_pool: NotRequired[bool] """Enable the use of a session pool for managing sessions during crawling.""" - session_pool: NotRequired[SessionPool] - """A custom `SessionPool` instance, allowing the use of non-default configuration.""" - retry_on_blocked: NotRequired[bool] """If True, the crawler attempts to bypass bot protections automatically.""" - proxy_configuration: NotRequired[ProxyConfiguration] - """HTTP proxy configuration used when making requests.""" + concurrency_settings: NotRequired[ConcurrencySettings] + """Settings to fine-tune concurrency levels.""" + + request_handler_timeout: NotRequired[timedelta] + """Maximum duration allowed for a single request handler to run.""" statistics: NotRequired[Statistics[StatisticsState]] """A custom `Statistics` instance, allowing the use of non-default configuration.""" - event_manager: NotRequired[EventManager] - """A custom `EventManager` instance, allowing the use of non-default configuration.""" + abort_on_error: NotRequired[bool] + """If True, the crawler stops immediately when any request handler error occurs.""" configure_logging: NotRequired[bool] """If True, the crawler will set up logging infrastructure automatically.""" - max_crawl_depth: NotRequired[int | None] - """Limits crawl depth from 0 (initial requests) up to the specified `max_crawl_depth`. - Requests at the maximum depth are processed, but no further links are enqueued.""" - - abort_on_error: NotRequired[bool] - """If True, the crawler stops immediately when any request handler error occurs.""" - _context_pipeline: NotRequired[ContextPipeline[TCrawlingContext]] """Enables extending the request lifecycle and modifying the crawling context. Intended for use by subclasses rather than direct instantiation of `BasicCrawler`.""" _additional_context_managers: NotRequired[Sequence[AbstractAsyncContextManager]] - """Additional context managers used throughout the crawler lifecycle.""" + """Additional context managers used throughout the crawler lifecycle. Intended for use by + subclasses rather than direct instantiation of `BasicCrawler`.""" _logger: NotRequired[logging.Logger] - """A logger instance, typically provided by a subclass, for consistent logging labels.""" + """A logger instance, typically provided by a subclass, for consistent logging labels. 
Intended for use by + subclasses rather than direct instantiation of `BasicCrawler`.""" @docs_group('Classes') @@ -169,24 +178,25 @@ class BasicCrawler(Generic[TCrawlingContext]): def __init__( self, *, + configuration: Configuration | None = None, + event_manager: EventManager | None = None, + storage_client: BaseStorageClient | None = None, request_provider: RequestProvider | None = None, - request_handler: Callable[[TCrawlingContext], Awaitable[None]] | None = None, + session_pool: SessionPool | None = None, + proxy_configuration: ProxyConfiguration | None = None, http_client: BaseHttpClient | None = None, - concurrency_settings: ConcurrencySettings | None = None, + request_handler: Callable[[TCrawlingContext], Awaitable[None]] | None = None, max_request_retries: int = 3, max_requests_per_crawl: int | None = None, max_session_rotations: int = 10, - configuration: Configuration | None = None, - request_handler_timeout: timedelta = timedelta(minutes=1), - session_pool: SessionPool | None = None, + max_crawl_depth: int | None = None, use_session_pool: bool = True, retry_on_blocked: bool = True, - proxy_configuration: ProxyConfiguration | None = None, + concurrency_settings: ConcurrencySettings | None = None, + request_handler_timeout: timedelta = timedelta(minutes=1), statistics: Statistics | None = None, - event_manager: EventManager | None = None, - configure_logging: bool = True, - max_crawl_depth: int | None = None, abort_on_error: bool = False, + configure_logging: bool = True, _context_pipeline: ContextPipeline[TCrawlingContext] | None = None, _additional_context_managers: Sequence[AbstractAsyncContextManager] | None = None, _logger: logging.Logger | None = None, @@ -194,10 +204,14 @@ def __init__( """A default constructor. Args: + configuration: The configuration object. Some of its properties are used as defaults for the crawler. + event_manager: The event manager for managing events for the crawler and all its components. + storage_client: The storage client for managing storages for the crawler and all its components. request_provider: Provider for requests to be processed by the crawler. + session_pool: A custom `SessionPool` instance, allowing the use of non-default configuration. + proxy_configuration: HTTP proxy configuration used when making requests. + http_client: HTTP client used by `BasicCrawlingContext.send_request` method. request_handler: A callable responsible for handling requests. - http_client: HTTP client used by `BasicCrawlingContext.send_request` and the HTTP-based crawling. - concurrency_settings: Settings to fine-tune concurrency levels. max_request_retries: Maximum number of attempts to process a single request. max_requests_per_crawl: Maximum number of pages to open during a crawl. The crawl stops upon reaching this limit. Setting this value can help avoid infinite loops in misconfigured crawlers. `None` means @@ -205,103 +219,109 @@ def __init__( this value. max_session_rotations: Maximum number of session rotations per request. The crawler rotates the session if a proxy error occurs or if the website blocks the request. - configuration: Crawler configuration. - request_handler_timeout: Maximum duration allowed for a single request handler to run. + max_crawl_depth: Specifies the maximum crawl depth. If set, the crawler will stop processing links beyond + this depth. The crawl depth starts at 0 for initial requests and increases with each subsequent level + of links. 
Requests at the maximum depth will still be processed, but no new links will be enqueued + from those requests. If not set, crawling continues without depth restrictions. use_session_pool: Enable the use of a session pool for managing sessions during crawling. - session_pool: A custom `SessionPool` instance, allowing the use of non-default configuration. retry_on_blocked: If True, the crawler attempts to bypass bot protections automatically. - proxy_configuration: HTTP proxy configuration used when making requests. + concurrency_settings: Settings to fine-tune concurrency levels. + request_handler_timeout: Maximum duration allowed for a single request handler to run. statistics: A custom `Statistics` instance, allowing the use of non-default configuration. - event_manager: A custom `EventManager` instance, allowing the use of non-default configuration. - configure_logging: If True, the crawler will set up logging infrastructure automatically. - max_crawl_depth: Maximum crawl depth. If set, the crawler will stop crawling after reaching this depth. abort_on_error: If True, the crawler stops immediately when any request handler error occurs. + configure_logging: If True, the crawler will set up logging infrastructure automatically. _context_pipeline: Enables extending the request lifecycle and modifying the crawling context. Intended for use by subclasses rather than direct instantiation of `BasicCrawler`. _additional_context_managers: Additional context managers used throughout the crawler lifecycle. + Intended for use by subclasses rather than direct instantiation of `BasicCrawler`. _logger: A logger instance, typically provided by a subclass, for consistent logging labels. + Intended for use by subclasses rather than direct instantiation of `BasicCrawler`. 
""" - self._router: Router[TCrawlingContext] | None = None + if configuration: + service_locator.set_configuration(configuration) + if storage_client: + service_locator.set_storage_client(storage_client) + if event_manager: + service_locator.set_event_manager(event_manager) + + config = service_locator.get_configuration() + # Core components + self._request_provider = request_provider + self._session_pool = session_pool or SessionPool() + self._proxy_configuration = proxy_configuration + self._http_client = http_client or HttpxHttpClient() + + # Request router setup + self._router: Router[TCrawlingContext] | None = None if isinstance(cast(Router, request_handler), Router): self._router = cast(Router[TCrawlingContext], request_handler) elif request_handler is not None: self._router = None self.router.default_handler(request_handler) - self._http_client = http_client or HttpxHttpClient() - - self._context_pipeline = (_context_pipeline or ContextPipeline()).compose(self._check_url_after_redirects) - + # Error & failed request handlers self._error_handler: ErrorHandler[TCrawlingContext | BasicCrawlingContext] | None = None self._failed_request_handler: FailedRequestHandler[TCrawlingContext | BasicCrawlingContext] | None = None + self._abort_on_error = abort_on_error + # Context pipeline + self._context_pipeline = (_context_pipeline or ContextPipeline()).compose(self._check_url_after_redirects) + + # Crawl settings self._max_request_retries = max_request_retries self._max_requests_per_crawl = max_requests_per_crawl self._max_session_rotations = max_session_rotations + self._max_crawl_depth = max_crawl_depth - self._request_provider = request_provider - self._configuration = configuration or service_container.get_configuration() - + # Timeouts self._request_handler_timeout = request_handler_timeout self._internal_timeout = ( - self._configuration.internal_timeout - if self._configuration.internal_timeout is not None + config.internal_timeout + if config.internal_timeout is not None else max(2 * request_handler_timeout, timedelta(minutes=5)) ) - self._tld_extractor = TLDExtract(cache_dir=tempfile.TemporaryDirectory().name) - - self._event_manager = event_manager or service_container.get_event_manager() - self._snapshotter = Snapshotter( - self._event_manager, - max_memory_size=ByteSize.from_mb(self._configuration.memory_mbytes) - if self._configuration.memory_mbytes - else None, - available_memory_ratio=self._configuration.available_memory_ratio, - ) - self._autoscaled_pool = AutoscaledPool( - system_status=SystemStatus(self._snapshotter), - is_finished_function=self.__is_finished_function, - is_task_ready_function=self.__is_task_ready_function, - run_task_function=self.__run_task_function, - concurrency_settings=concurrency_settings, - ) - + # Retry and session settings self._use_session_pool = use_session_pool - self._session_pool = session_pool or SessionPool() - self._retry_on_blocked = retry_on_blocked + # Logging setup if configure_logging: root_logger = logging.getLogger() - configure_logger(root_logger, self._configuration, remove_old_handlers=True) - - # Silence HTTPX logger - httpx_logger = logging.getLogger('httpx') - httpx_logger.setLevel( - logging.DEBUG if get_configured_log_level(self._configuration) <= logging.DEBUG else logging.WARNING - ) - - if not _logger: - _logger = logging.getLogger(__name__) + configure_logger(root_logger, remove_old_handlers=True) + httpx_logger = logging.getLogger('httpx') # Silence HTTPX logger + httpx_logger.setLevel(logging.DEBUG if 
get_configured_log_level() <= logging.DEBUG else logging.WARNING) + self._logger = _logger or logging.getLogger(__name__) - self._logger = _logger - - self._proxy_configuration = proxy_configuration + # Statistics self._statistics = statistics or Statistics( - event_manager=self._event_manager, periodic_message_logger=self._logger, log_message='Current request statistics:', ) + + # Additional context managers to enter and exit self._additional_context_managers = _additional_context_managers or [] + # Internal, not explicitly configurable components + self._tld_extractor = TLDExtract(cache_dir=tempfile.TemporaryDirectory().name) + self._snapshotter = Snapshotter( + max_memory_size=ByteSize.from_mb(config.memory_mbytes) if config.memory_mbytes else None, + available_memory_ratio=config.available_memory_ratio, + ) + self._autoscaled_pool = AutoscaledPool( + system_status=SystemStatus(self._snapshotter), + is_finished_function=self.__is_finished_function, + is_task_ready_function=self.__is_task_ready_function, + run_task_function=self.__run_task_function, + concurrency_settings=concurrency_settings, + ) + + # State flags self._running = False self._has_finished_before = False - self._max_crawl_depth = max_crawl_depth self._failed = False - self._abort_on_error = abort_on_error @property def log(self) -> logging.Logger: @@ -369,7 +389,7 @@ async def get_request_provider( ) -> RequestProvider: """Return the configured request provider. If none is configured, open and return the default request queue.""" if not self._request_provider: - self._request_provider = await RequestQueue.open(id=id, name=name, configuration=self._configuration) + self._request_provider = await RequestQueue.open(id=id, name=name) return self._request_provider @@ -380,7 +400,7 @@ async def get_dataset( name: str | None = None, ) -> Dataset: """Return the dataset with the given ID or name. If none is provided, return the default dataset.""" - return await Dataset.open(id=id, name=name, configuration=self._configuration) + return await Dataset.open(id=id, name=name) async def get_key_value_store( self, @@ -389,7 +409,7 @@ async def get_key_value_store( name: str | None = None, ) -> KeyValueStore: """Return the key-value store with the given ID or name. If none is provided, return the default KVS.""" - return await KeyValueStore.open(id=id, name=name, configuration=self._configuration) + return await KeyValueStore.open(id=id, name=name) def error_handler( self, handler: ErrorHandler[TCrawlingContext | BasicCrawlingContext] @@ -434,7 +454,7 @@ async def run( request_provider = await self.get_request_provider() if purge_request_queue and isinstance(request_provider, RequestQueue): await request_provider.drop() - self._request_provider = await RequestQueue.open(configuration=self._configuration) + self._request_provider = await RequestQueue.open() if requests is not None: await self.add_requests(requests) @@ -486,12 +506,14 @@ def sigint_handler() -> None: return final_statistics async def _run_crawler(self) -> None: + event_manager = service_locator.get_event_manager() + # Collect the context managers to be entered. Context managers that are already active are excluded, # as they were likely entered by the caller, who will also be responsible for exiting them. 
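Since the crawler now forwards `configuration`, `event_manager`, and `storage_client` to the service locator instead of storing them on the instance, a minimal usage sketch looks like this (untested; import paths follow the pre-v0.6 `crawlee.basic_crawler` layout used in this diff, and the URL is a placeholder):

```python
import asyncio

from crawlee.basic_crawler import BasicCrawler, BasicCrawlingContext
from crawlee.configuration import Configuration


async def main() -> None:
    # Passing `configuration` here calls `service_locator.set_configuration()`,
    # so storages, logging, and the snapshotter all see the same instance.
    crawler = BasicCrawler(
        configuration=Configuration(),
        max_crawl_depth=2,  # links beyond depth 2 are not enqueued
    )

    @crawler.router.default_handler
    async def handler(context: BasicCrawlingContext) -> None:
        context.log.info(f'Processing {context.request.url} ...')

    await crawler.run(['https://crawlee.dev'])


if __name__ == '__main__':
    asyncio.run(main())
```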
diff --git a/src/crawlee/configuration.py b/src/crawlee/configuration.py
index c640c9c7e5..fdd14c1953 100644
--- a/src/crawlee/configuration.py
+++ b/src/crawlee/configuration.py
@@ -233,17 +233,16 @@ class Configuration(BaseSettings):
 
     @classmethod
     def get_global_configuration(cls) -> Self:
-        """Retrieve the global instance of the configuration."""
-        from crawlee import service_container
+        """Retrieve the global instance of the configuration.
 
-        if service_container.get_configuration_if_set() is None:
-            service_container.set_configuration(cls())
+        Mostly kept for backwards compatibility. It is recommended to use `service_locator.get_configuration()`
+        instead.
+        """
+        from crawlee import service_locator
 
-        global_instance = service_container.get_configuration()
+        config = service_locator.get_configuration()
 
-        if not isinstance(global_instance, cls):
-            raise TypeError(
-                f'Requested global configuration object of type {cls}, but {global_instance.__class__} was found'
-            )
+        if not isinstance(config, cls):
+            raise TypeError(f'Requested global configuration object of type {cls}, but {config.__class__} was found')
 
-        return global_instance
+        return config
diff --git a/src/crawlee/errors.py b/src/crawlee/errors.py
index ff6348842e..40f61c20a8 100644
--- a/src/crawlee/errors.py
+++ b/src/crawlee/errors.py
@@ -35,6 +35,17 @@ class SessionError(Exception):
     """
 
 
+@docs_group('Errors')
+class ServiceConflictError(Exception):
+    """Raised when attempting to reassign a service in the service locator that was already configured."""
+
+    def __init__(self, service: type, new_value: object, existing_value: object) -> None:
+        super().__init__(
+            f'Service {service.__name__} has already been set. Existing value: {existing_value}, '
+            f'attempted new value: {new_value}.'
+        )
+
+
 @docs_group('Errors')
 class ProxyError(SessionError):
     """Raised when a proxy is being blocked or malfunctions."""
@@ -89,13 +100,3 @@ def __init__(self, wrapped_exception: Exception, crawling_context: BasicCrawling
 @docs_group('Errors')
 class ContextPipelineInterruptedError(Exception):
     """May be thrown in the initialization phase of a middleware to signal that the request should not be processed."""
-
-
-@docs_group('Errors')
-class ServiceConflictError(RuntimeError):
-    """Thrown when a service container is getting reconfigured."""
-
-    def __init__(self, service_name: str, new_value: object, old_value: object) -> None:
-        super().__init__(
-            f"Service '{service_name}' was already set (existing value is '{old_value}', new value is '{new_value}')."
-        )
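A small sketch of when the refactored error fires, following the one-shot setter logic shown above:

```python
from crawlee import service_locator
from crawlee.errors import ServiceConflictError
from crawlee.events import LocalEventManager

service_locator.set_event_manager(LocalEventManager())

try:
    # A second assignment of the same service kind is rejected.
    service_locator.set_event_manager(LocalEventManager())
except ServiceConflictError as exc:
    print(exc)  # 'Service EventManager has already been set. ...'
```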
diff --git a/src/crawlee/events/_event_manager.py b/src/crawlee/events/_event_manager.py
index ad2f3e82ab..b7fb461619 100644
--- a/src/crawlee/events/_event_manager.py
+++ b/src/crawlee/events/_event_manager.py
@@ -149,9 +149,9 @@ async def listener_wrapper(event_data: EventData) -> None:
                 self._listener_tasks.add(listener_task)
 
                 try:
-                    logger.debug('LocalEventManager.on.listener_wrapper(): Awaiting listener task...')
+                    logger.debug('EventManager.on.listener_wrapper(): Awaiting listener task...')
                     await listener_task
-                    logger.debug('LocalEventManager.on.listener_wrapper(): Listener task completed.')
+                    logger.debug('EventManager.on.listener_wrapper(): Listener task completed.')
                 except Exception:
                     # We need to swallow the exception and just log it here, otherwise it could break the event emitter
                     logger.exception(
@@ -159,7 +159,7 @@ async def listener_wrapper(event_data: EventData) -> None:
                         extra={'event_name': event.value, 'listener_name': listener.__name__},
                     )
                 finally:
-                    logger.debug('LocalEventManager.on.listener_wrapper(): Removing listener task from the set...')
+                    logger.debug('EventManager.on.listener_wrapper(): Removing listener task from the set...')
                     self._listener_tasks.remove(listener_task)
 
         self._listeners_to_wrappers[event][listener].append(listener_wrapper)
diff --git a/src/crawlee/memory_storage_client/_creation_management.py b/src/crawlee/memory_storage_client/_creation_management.py
index 093ca7a0ea..f1581f4d2f 100644
--- a/src/crawlee/memory_storage_client/_creation_management.py
+++ b/src/crawlee/memory_storage_client/_creation_management.py
@@ -21,9 +21,6 @@
     Request,
     RequestQueueMetadata,
 )
-from crawlee.storages._dataset import Dataset
-from crawlee.storages._key_value_store import KeyValueStore
-from crawlee.storages._request_queue import RequestQueue
 
 if TYPE_CHECKING:
     from crawlee.memory_storage_client._dataset_client import DatasetClient
@@ -400,24 +397,8 @@ def _determine_storage_path(
     id: str | None = None,
     name: str | None = None,
 ) -> str | None:
-    from crawlee.memory_storage_client._dataset_client import DatasetClient
-    from crawlee.memory_storage_client._key_value_store_client import KeyValueStoreClient
-    from crawlee.memory_storage_client._request_queue_client import RequestQueueClient
-    from crawlee.storages._creation_management import _get_default_storage_id
-
-    configuration = memory_storage_client._configuration  # noqa: SLF001
-
-    if issubclass(resource_client_class, DatasetClient):
-        storages_dir = memory_storage_client.datasets_directory
-        default_id = _get_default_storage_id(configuration, Dataset)
-    elif issubclass(resource_client_class, KeyValueStoreClient):
-        storages_dir = memory_storage_client.key_value_stores_directory
-        default_id = _get_default_storage_id(configuration, KeyValueStore)
-    elif issubclass(resource_client_class, RequestQueueClient):
-        storages_dir = memory_storage_client.request_queues_directory
-        default_id = _get_default_storage_id(configuration, RequestQueue)
-    else:
-        raise TypeError('Invalid resource client class.')
+    storages_dir = memory_storage_client._get_storage_dir(resource_client_class)  # noqa: SLF001
+    default_id = memory_storage_client._get_default_storage_id(resource_client_class)  # noqa: SLF001
 
     # Try to find by name directly from directories
     if name:
diff --git a/src/crawlee/memory_storage_client/_memory_storage_client.py b/src/crawlee/memory_storage_client/_memory_storage_client.py
index 70dc932557..03daa78341 100644
--- a/src/crawlee/memory_storage_client/_memory_storage_client.py
+++ b/src/crawlee/memory_storage_client/_memory_storage_client.py
@@ -23,6 +23,7 @@
 if TYPE_CHECKING:
     from crawlee.base_storage_client._types import ResourceClient
 
+
 TResourceClient = TypeVar('TResourceClient', DatasetClient, KeyValueStoreClient, RequestQueueClient)
 
 logger = getLogger(__name__)
@@ -45,13 +46,42 @@ class MemoryStorageClient(BaseStorageClient):
     _TEMPORARY_DIR_NAME = '__CRAWLEE_TEMPORARY'
     """Name of the directory used to temporarily store files during purges."""
 
-    def __init__(self, configuration: Configuration | None = None) -> None:
+    _DATASETS_DIR_NAME = 'datasets'
+    """Name of the directory containing datasets."""
+
+    _KEY_VALUE_STORES_DIR_NAME = 'key_value_stores'
+    """Name of the directory containing key-value stores."""
+
+    _REQUEST_QUEUES_DIR_NAME = 'request_queues'
+    """Name of the directory containing request queues."""
+
+    def __init__(
+        self,
+        *,
+        write_metadata: bool,
+        persist_storage: bool,
+        storage_dir: str,
+        default_request_queue_id: str,
+        default_key_value_store_id: str,
+        default_dataset_id: str,
+    ) -> None:
         """A default constructor.
 
         Args:
-            configuration: Configuration object to use. If None, a default instance will be created.
+            write_metadata: Whether to write metadata to the storage.
+            persist_storage: Whether to persist the storage.
+            storage_dir: Path to the storage directory.
+            default_request_queue_id: The default request queue ID.
+            default_key_value_store_id: The default key-value store ID.
+            default_dataset_id: The default dataset ID.
         """
-        self._explicit_configuration = configuration
+        # Set the internal attributes.
+        self._write_metadata = write_metadata
+        self._persist_storage = persist_storage
+        self._storage_dir = storage_dir
+        self._default_request_queue_id = default_request_queue_id
+        self._default_key_value_store_id = default_key_value_store_id
+        self._default_dataset_id = default_dataset_id
 
         self.datasets_handled: list[DatasetClient] = []
         self.key_value_stores_handled: list[KeyValueStoreClient] = []
@@ -60,78 +90,79 @@ def __init__(self, configuration: Configuration | None = None) -> None:
         self._purged_on_start = False  # Indicates whether a purge was already performed on this instance.
         self._purge_lock = asyncio.Lock()
 
-    @property
-    def _configuration(self) -> Configuration:
-        return self._explicit_configuration or Configuration.get_global_configuration()
+    @classmethod
+    def from_config(cls, config: Configuration | None = None) -> MemoryStorageClient:
+        """Create a new instance based on the provided configuration.
+
+        All the memory storage client parameters are taken from the configuration object.
+
+        Args:
+            config: The configuration object. If None, the global configuration will be used.
+        """
+        config = config or Configuration.get_global_configuration()
+
+        return cls(
+            write_metadata=config.write_metadata,
+            persist_storage=config.persist_storage,
+            storage_dir=config.storage_dir,
+            default_request_queue_id=config.default_request_queue_id,
+            default_key_value_store_id=config.default_key_value_store_id,
+            default_dataset_id=config.default_dataset_id,
+        )
 
     @property
     def write_metadata(self) -> bool:
         """Whether to write metadata to the storage."""
-        return self._configuration.write_metadata
+        return self._write_metadata
 
     @property
     def persist_storage(self) -> bool:
         """Whether to persist the storage."""
-        return self._configuration.persist_storage
+        return self._persist_storage
 
     @property
     def storage_dir(self) -> str:
         """Path to the storage directory."""
-        return self._configuration.storage_dir
+        return self._storage_dir
 
     @property
     def datasets_directory(self) -> str:
         """Path to the directory containing datasets."""
-        return os.path.join(self.storage_dir, 'datasets')
+        return os.path.join(self.storage_dir, self._DATASETS_DIR_NAME)
 
     @property
     def key_value_stores_directory(self) -> str:
         """Path to the directory containing key-value stores."""
-        return os.path.join(self.storage_dir, 'key_value_stores')
+        return os.path.join(self.storage_dir, self._KEY_VALUE_STORES_DIR_NAME)
 
     @property
     def request_queues_directory(self) -> str:
         """Path to the directory containing request queues."""
-        return os.path.join(self.storage_dir, 'request_queues')
+        return os.path.join(self.storage_dir, self._REQUEST_QUEUES_DIR_NAME)
 
     @override
     def dataset(self, id: str) -> DatasetClient:
-        return DatasetClient(
-            memory_storage_client=self,
-            id=id,
-        )
+        return DatasetClient(memory_storage_client=self, id=id)
 
     @override
     def datasets(self) -> DatasetCollectionClient:
-        return DatasetCollectionClient(
-            memory_storage_client=self,
-        )
+        return DatasetCollectionClient(memory_storage_client=self)
 
     @override
     def key_value_store(self, id: str) -> KeyValueStoreClient:
-        return KeyValueStoreClient(
-            memory_storage_client=self,
-            id=id,
-        )
+        return KeyValueStoreClient(memory_storage_client=self, id=id)
 
     @override
     def key_value_stores(self) -> KeyValueStoreCollectionClient:
-        return KeyValueStoreCollectionClient(
-            memory_storage_client=self,
-        )
+        return KeyValueStoreCollectionClient(memory_storage_client=self)
 
     @override
     def request_queue(self, id: str) -> RequestQueueClient:
-        return RequestQueueClient(
-            memory_storage_client=self,
-            id=id,
-        )
+        return RequestQueueClient(memory_storage_client=self, id=id)
 
     @override
     def request_queues(self) -> RequestQueueCollectionClient:
-        return RequestQueueCollectionClient(
-            memory_storage_client=self,
-        )
+        return RequestQueueCollectionClient(memory_storage_client=self)
 
     @override
     async def purge_on_start(self) -> None:
@@ -150,7 +181,10 @@ async def purge_on_start(self) -> None:
             self._purged_on_start = True
 
     def get_cached_resource_client(
-        self, resource_client_class: type[TResourceClient], id: str | None, name: str | None
+        self,
+        resource_client_class: type[TResourceClient],
+        id: str | None,
+        name: str | None,
     ) -> TResourceClient | None:
         """Try to return a resource client from the internal cache."""
         if issubclass(resource_client_class, DatasetClient):
@@ -197,14 +231,14 @@ async def _purge_default_storages(self) -> None:
                     self._TEMPORARY_DIR_NAME
                 ) or key_value_store_folder.name.startswith('__OLD'):
                     await self._batch_remove_files(key_value_store_folder.path)
-                elif key_value_store_folder.name == self._configuration.default_key_value_store_id:
+                elif key_value_store_folder.name == self._default_key_value_store_id:
                     await self._handle_default_key_value_store(key_value_store_folder.path)
 
         # Datasets
         if await asyncio.to_thread(os.path.exists, self.datasets_directory):
             dataset_folders = await asyncio.to_thread(os.scandir, self.datasets_directory)
             for dataset_folder in dataset_folders:
-                if dataset_folder.name == self._configuration.default_dataset_id or dataset_folder.name.startswith(
+                if dataset_folder.name == self._default_dataset_id or dataset_folder.name.startswith(
                     self._TEMPORARY_DIR_NAME
                 ):
                     await self._batch_remove_files(dataset_folder.path)
@@ -213,9 +247,8 @@ async def _purge_default_storages(self) -> None:
         if await asyncio.to_thread(os.path.exists, self.request_queues_directory):
             request_queue_folders = await asyncio.to_thread(os.scandir, self.request_queues_directory)
             for request_queue_folder in request_queue_folders:
-                if (
-                    request_queue_folder.name == self._configuration.default_request_queue_id
-                    or request_queue_folder.name.startswith(self._TEMPORARY_DIR_NAME)
+                if request_queue_folder.name == self._default_request_queue_id or request_queue_folder.name.startswith(
+                    self._TEMPORARY_DIR_NAME
                 ):
                     await self._batch_remove_files(request_queue_folder.path)
 
@@ -295,3 +328,29 @@ async def _batch_remove_files(self, folder: str, counter: int = 0) -> None:
             await asyncio.to_thread(shutil.rmtree, temporary_folder, ignore_errors=True)
 
         return None
+
+    def _get_default_storage_id(self, storage_client_class: type[TResourceClient]) -> str:
+        """Get the default storage ID based on the storage class."""
+        if issubclass(storage_client_class, DatasetClient):
+            return self._default_dataset_id
+
+        if issubclass(storage_client_class, KeyValueStoreClient):
+            return self._default_key_value_store_id
+
+        if issubclass(storage_client_class, RequestQueueClient):
+            return self._default_request_queue_id
+
+        raise ValueError(f'Invalid storage class: {storage_client_class.__name__}')
+
+    def _get_storage_dir(self, storage_client_class: type[TResourceClient]) -> str:
+        """Get the storage directory based on the storage class."""
+        if issubclass(storage_client_class, DatasetClient):
+            return self.datasets_directory
+
+        if issubclass(storage_client_class, KeyValueStoreClient):
+            return self.key_value_stores_directory
+
+        if issubclass(storage_client_class, RequestQueueClient):
+            return self.request_queues_directory
+
+        raise ValueError(f'Invalid storage class: {storage_client_class.__name__}')
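The new `from_config()` factory replaces the old configuration-holding constructor: all parameters are snapshotted from the `Configuration` up front rather than read lazily through the removed `_configuration` property. A hedged usage sketch (the configuration fields shown are the ones `from_config()` reads in this diff):

```python
import asyncio

from crawlee.configuration import Configuration
from crawlee.memory_storage_client import MemoryStorageClient


async def main() -> None:
    # Build the client from an explicit configuration; omitting the argument
    # would fall back to `Configuration.get_global_configuration()`.
    storage_client = MemoryStorageClient.from_config(
        Configuration(persist_storage=False, write_metadata=False)
    )
    await storage_client.purge_on_start()


asyncio.run(main())
```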
_Services(TypedDict): - local_storage_client: NotRequired[BaseStorageClient] - cloud_storage_client: NotRequired[BaseStorageClient] - configuration: NotRequired[Configuration] - event_manager: NotRequired[EventManager] - - -_services = _Services() -_default_storage_client_type: StorageClientType = 'local' - - -@docs_group('Functions') -def get_storage_client(*, client_type: StorageClientType | None = None) -> BaseStorageClient: - """Get the storage client instance for the current environment. - - Args: - client_type: Allows retrieving a specific storage client type, regardless of where we are running. - - Returns: - The current storage client instance. - """ - if client_type is None: - client_type = _default_storage_client_type - - if client_type == 'cloud': - if 'cloud_storage_client' not in _services: - raise RuntimeError('Cloud client was not provided.') - return _services['cloud_storage_client'] - - if 'local_storage_client' not in _services: - _services['local_storage_client'] = MemoryStorageClient() - - return _services['local_storage_client'] - - -@docs_group('Functions') -def set_local_storage_client(local_client: BaseStorageClient) -> None: - """Set the local storage client instance. - - Args: - local_client: The local storage client instance. - """ - if (existing_service := _services.get('local_storage_client')) and existing_service is not local_client: - raise ServiceConflictError('local_storage_client', local_client, existing_service) - - _services['local_storage_client'] = local_client - - -@docs_group('Functions') -def set_cloud_storage_client(cloud_client: BaseStorageClient) -> None: - """Set the cloud storage client instance. - - Args: - cloud_client: The cloud storage client instance. - """ - if (existing_service := _services.get('cloud_storage_client')) and existing_service is not cloud_client: - raise ServiceConflictError('cloud_storage_client', cloud_client, existing_service) - - _services['cloud_storage_client'] = cloud_client - - -@docs_group('Functions') -def set_default_storage_client_type(client_type: StorageClientType) -> None: - """Set the default storage client type.""" - global _default_storage_client_type # noqa: PLW0603 - _default_storage_client_type = client_type - - -@docs_group('Functions') -def get_configuration() -> Configuration: - """Get the configuration object.""" - if 'configuration' not in _services: - _services['configuration'] = Configuration() - - return _services['configuration'] - - -@docs_group('Functions') -def get_configuration_if_set() -> Configuration | None: - """Get the configuration object, or None if it hasn't been set yet.""" - return _services.get('configuration') - - -@docs_group('Functions') -def set_configuration(configuration: Configuration) -> None: - """Set the configuration object.""" - if (existing_service := _services.get('configuration')) and existing_service is not configuration: - raise ServiceConflictError('configuration', configuration, existing_service) - - _services['configuration'] = configuration - - -@docs_group('Functions') -def get_event_manager() -> EventManager: - """Get the event manager.""" - if 'event_manager' not in _services: - _services['event_manager'] = LocalEventManager() - - return _services['event_manager'] - - -@docs_group('Functions') -def set_event_manager(event_manager: EventManager) -> None: - """Set the event manager.""" - if (existing_service := _services.get('event_manager')) and existing_service is not event_manager: - raise ServiceConflictError('event_manager', event_manager, 
diff --git a/src/crawlee/sessions/_session_pool.py b/src/crawlee/sessions/_session_pool.py
index 612b848c73..c5dd99c066 100644
--- a/src/crawlee/sessions/_session_pool.py
+++ b/src/crawlee/sessions/_session_pool.py
@@ -6,6 +6,7 @@
from logging import getLogger
from typing import TYPE_CHECKING, Callable, Literal, overload

+from crawlee import service_locator
from crawlee._utils.context import ensure_context
from crawlee._utils.docs import docs_group
from crawlee.events._types import Event, EventPersistStateData
@@ -53,10 +54,12 @@ def __init__(
            persist_state_kvs_name: The name of the `KeyValueStore` used for state persistence.
            persist_state_key: The key under which the session pool's state is stored in the `KeyValueStore`.
        """
+        if event_manager:
+            service_locator.set_event_manager(event_manager)
+
        self._max_pool_size = max_pool_size
        self._session_settings = create_session_settings or {}
        self._create_session_function = create_session_function
-        self._event_manager = event_manager
        self._persistence_enabled = persistence_enabled
        self._persist_state_kvs_name = persist_state_kvs_name
        self._persist_state_key = persist_state_key
@@ -64,9 +67,6 @@ def __init__(
        if self._create_session_function and self._session_settings:
            raise ValueError('Both `create_session_settings` and `create_session_function` cannot be provided.')

-        if self._persistence_enabled and not self._event_manager:
-            raise ValueError('Persistence is enabled, but no event manager was provided.')
-
        # Internal non-configurable attributes
        self._kvs: KeyValueStore | None = None
        self._sessions: dict[str, Session] = {}
@@ -109,7 +109,8 @@ async def __aenter__(self) -> SessionPool:

        self._active = True

-        if self._persistence_enabled and self._event_manager:
+        if self._persistence_enabled:
+            event_manager = service_locator.get_event_manager()
            self._kvs = await KeyValueStore.open(name=self._persist_state_kvs_name)

            # Attempt to restore the previously persisted state.
@@ -120,7 +121,7 @@ async def __aenter__(self) -> SessionPool:
            await self._fill_sessions_to_max()

            # Register an event listener for persisting the session pool state.
-            self._event_manager.on(event=Event.PERSIST_STATE, listener=self._persist_state)
+            event_manager.on(event=Event.PERSIST_STATE, listener=self._persist_state)
        # If persistence is disabled, just fill the pool with sessions.
        else:
            await self._fill_sessions_to_max()
@@ -141,9 +142,10 @@ async def __aexit__(
        if not self._active:
            raise RuntimeError(f'The {self.__class__.__name__} is not active.')

-        if self._persistence_enabled and self._event_manager:
+        if self._persistence_enabled:
+            event_manager = service_locator.get_event_manager()
            # Remove the event listener for state persistence.
-            self._event_manager.off(event=Event.PERSIST_STATE, listener=self._persist_state)
+            event_manager.off(event=Event.PERSIST_STATE, listener=self._persist_state)

            # Persist the final state of the session pool.
            await self._persist_state(event_data=EventPersistStateData(is_migrating=False))
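With this change, a pool that persists its state no longer carries the event manager on the instance; passing one into the constructor merely registers it in the service locator. A usage sketch consistent with the new code (the KVS name is illustrative):

    import asyncio

    from crawlee import service_locator
    from crawlee.events import LocalEventManager
    from crawlee.sessions import SessionPool

    async def main() -> None:
        # Optional: register a custom event manager up front; otherwise the
        # pool resolves the default one from the service locator on __aenter__.
        service_locator.set_event_manager(LocalEventManager())

        async with SessionPool(persistence_enabled=True, persist_state_kvs_name='sessions') as pool:
            session = await pool.get_session()
            print(session.id)

    asyncio.run(main())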
diff --git a/src/crawlee/statistics/_statistics.py b/src/crawlee/statistics/_statistics.py
index 704cfc9ab8..00612f4753 100644
--- a/src/crawlee/statistics/_statistics.py
+++ b/src/crawlee/statistics/_statistics.py
@@ -8,7 +8,7 @@

from typing_extensions import Self, TypeVar

-import crawlee.service_container
+from crawlee import service_locator
from crawlee._utils.context import ensure_context
from crawlee._utils.docs import docs_group
from crawlee._utils.recurring_task import RecurringTask
@@ -20,8 +20,6 @@
if TYPE_CHECKING:
    from types import TracebackType

-    from crawlee.events import EventManager
-
TStatisticsState = TypeVar('TStatisticsState', bound=StatisticsState, default=StatisticsState)

logger = getLogger(__name__)
@@ -67,7 +65,6 @@ class Statistics(Generic[TStatisticsState]):
    def __init__(
        self,
        *,
-        event_manager: EventManager | None = None,
        persistence_enabled: bool = False,
        persist_state_kvs_name: str = 'default',
        persist_state_key: str | None = None,
@@ -88,8 +85,6 @@ def __init__(
        self.error_tracker = ErrorTracker()
        self.error_tracker_retry = ErrorTracker()

-        self._events = event_manager or crawlee.service_container.get_event_manager()
-
        self._requests_in_progress = dict[str, RequestProcessingRecord]()

        if persist_state_key is None:
@@ -131,7 +126,8 @@ async def __aenter__(self) -> Self:
        self._key_value_store = await KeyValueStore.open(name=self._persist_state_kvs_name)
        await self._maybe_load_statistics()

-        self._events.on(event=Event.PERSIST_STATE, listener=self._persist_state)
+        event_manager = service_locator.get_event_manager()
+        event_manager.on(event=Event.PERSIST_STATE, listener=self._persist_state)
        self._periodic_logger.start()
        return self
@@ -151,7 +147,8 @@ async def __aexit__(
            raise RuntimeError(f'The {self.__class__.__name__} is not active.')

        self.state.crawler_finished_at = datetime.now(timezone.utc)
-        self._events.off(event=Event.PERSIST_STATE, listener=self._persist_state)
+        event_manager = service_locator.get_event_manager()
+        event_manager.off(event=Event.PERSIST_STATE, listener=self._persist_state)
        await self._periodic_logger.stop()
        await self._persist_state(event_data=EventPersistStateData(is_migrating=False))
        self._active = False
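Callers that previously injected an event manager into Statistics must now set it globally instead. A hedged before/after sketch of the calling code:

    from crawlee import service_locator
    from crawlee.events import LocalEventManager
    from crawlee.statistics import Statistics

    # Before (v0.4.x):
    # statistics = Statistics(event_manager=my_event_manager, persistence_enabled=True)

    # After (v0.5.0): the argument is gone; register a custom event manager
    # globally if needed, and Statistics resolves it on __aenter__.
    service_locator.set_event_manager(LocalEventManager())
    statistics = Statistics(persistence_enabled=True)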
diff --git a/src/crawlee/storages/_base_storage.py b/src/crawlee/storages/_base_storage.py
index 46976810e7..bc2388a2eb 100644
--- a/src/crawlee/storages/_base_storage.py
+++ b/src/crawlee/storages/_base_storage.py
@@ -4,6 +4,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
+    from crawlee.base_storage_client import BaseStorageClient
    from crawlee.configuration import Configuration
@@ -28,15 +29,18 @@ async def open(
        id: str | None = None,
        name: str | None = None,
        configuration: Configuration | None = None,
+        storage_client: BaseStorageClient | None = None,
    ) -> BaseStorage:
        """Open a storage, either restore existing or create a new one.

        Args:
            id: The storage ID.
            name: The storage name.
-            configuration: The configuration to use.
+            configuration: Configuration object used during the storage creation or restoration process.
+            storage_client: Underlying storage client to use. If not provided, the default global storage client
+                from the service locator will be used.
        """

    @abstractmethod
    async def drop(self) -> None:
-        """Drop the storage. Remove it from underlying storage and delete from cache."""
+        """Drop the storage, removing it from the underlying storage client and clearing the cache."""
diff --git a/src/crawlee/storages/_creation_management.py b/src/crawlee/storages/_creation_management.py
index 4e6a636f14..adbde8c583 100644
--- a/src/crawlee/storages/_creation_management.py
+++ b/src/crawlee/storages/_creation_management.py
@@ -3,14 +3,13 @@
import asyncio
from typing import TYPE_CHECKING, TypeVar

-from crawlee import service_container
-from crawlee.configuration import Configuration
from crawlee.memory_storage_client import MemoryStorageClient
from crawlee.storages import Dataset, KeyValueStore, RequestQueue

if TYPE_CHECKING:
    from crawlee.base_storage_client import BaseStorageClient
    from crawlee.base_storage_client._types import ResourceClient, ResourceCollectionClient
+    from crawlee.configuration import Configuration

TResource = TypeVar('TResource', Dataset, KeyValueStore, RequestQueue)
@@ -122,15 +121,12 @@ def _get_default_storage_id(configuration: Configuration, storage_class: type[TR
async def open_storage(
    *,
    storage_class: type[TResource],
-    storage_client: BaseStorageClient | None = None,
-    configuration: Configuration | None = None,
-    id: str | None = None,
-    name: str | None = None,
+    id: str | None,
+    name: str | None,
+    configuration: Configuration,
+    storage_client: BaseStorageClient,
) -> TResource:
    """Open either a new storage or restore an existing one and return it."""
-    configuration = configuration or Configuration.get_global_configuration()
-    storage_client = storage_client or service_container.get_storage_client()
-
    # Try to restore the storage from cache by name
    if name:
        cached_storage = _get_from_cache_by_name(storage_class=storage_class, name=name)
@@ -171,21 +167,7 @@ async def open_storage(
    resource_collection_client = _get_resource_collection_client(storage_class, storage_client)
    storage_info = await resource_collection_client.get_or_create(name=name)

-    if issubclass(storage_class, RequestQueue):
-        storage = storage_class(
-            id=storage_info.id,
-            name=storage_info.name,
-            configuration=configuration,
-            client=storage_client,
-            event_manager=service_container.get_event_manager(),
-        )
-    else:
-        storage = storage_class(
-            id=storage_info.id,
-            name=storage_info.name,
-            configuration=configuration,
-            client=storage_client,
-        )
+    storage = storage_class(id=storage_info.id, name=storage_info.name, storage_client=storage_client)

    # Cache the storage by ID and name
    _add_to_cache_by_id(storage.id, storage)
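After this hunk, `open_storage` is a purely internal helper with all four arguments required; the public `open` classmethods resolve defaults before delegating. A sketch of the calling convention implied by the new signature (this touches a private module, so it is for illustration only):

    from crawlee import service_locator
    from crawlee.storages import Dataset
    from crawlee.storages._creation_management import open_storage

    async def open_default_dataset() -> Dataset:
        # The public `open` classmethods perform exactly this resolution.
        configuration = service_locator.get_configuration()
        storage_client = service_locator.get_storage_client()
        return await open_storage(
            storage_class=Dataset,
            id=None,
            name=None,
            configuration=configuration,
            storage_client=storage_client,
        )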
diff --git a/src/crawlee/storages/_dataset.py b/src/crawlee/storages/_dataset.py
index 457d8ba43d..4e8b718b3a 100644
--- a/src/crawlee/storages/_dataset.py
+++ b/src/crawlee/storages/_dataset.py
@@ -8,6 +8,7 @@

from typing_extensions import NotRequired, Required, Unpack, override

+from crawlee import service_locator
from crawlee._utils.byte_size import ByteSize
from crawlee._utils.docs import docs_group
from crawlee._utils.file import json_dumps
@@ -23,7 +24,6 @@
    from crawlee.base_storage_client._models import DatasetItemsListPage
    from crawlee.configuration import Configuration

-
logger = logging.getLogger(__name__)
@@ -193,20 +193,13 @@ class Dataset(BaseStorage):
    _EFFECTIVE_LIMIT_SIZE = _MAX_PAYLOAD_SIZE - (_MAX_PAYLOAD_SIZE * _SAFETY_BUFFER_PERCENT)
    """Calculated payload limit considering safety buffer."""

-    def __init__(
-        self,
-        id: str,
-        name: str | None,
-        configuration: Configuration,
-        client: BaseStorageClient,
-    ) -> None:
+    def __init__(self, id: str, name: str | None, storage_client: BaseStorageClient) -> None:
        self._id = id
        self._name = name
-        self._configuration = configuration

-        # Get resource clients from storage client
-        self._resource_client = client.dataset(self._id)
-        self._resource_collection_client = client.datasets()
+        # Get resource clients from the storage client.
+        self._resource_client = storage_client.dataset(self._id)
+        self._resource_collection_client = storage_client.datasets()

    @property
    @override
@@ -230,6 +223,9 @@ async def open(
    ) -> Dataset:
        from crawlee.storages._creation_management import open_storage

+        configuration = configuration or service_locator.get_configuration()
+        storage_client = storage_client or service_locator.get_storage_client()
+
        return await open_storage(
            storage_class=cls,
            id=id,
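The public `Dataset.open` now accepts an explicit storage client and falls back to the service locator only when none is given. A usage sketch (the storage name and configuration values are illustrative):

    import asyncio

    from crawlee.configuration import Configuration
    from crawlee.memory_storage_client import MemoryStorageClient
    from crawlee.storages import Dataset

    async def main() -> None:
        # Explicit client for this storage only; other storages keep using
        # whatever the service locator provides.
        config = Configuration(persist_storage=False)
        storage_client = MemoryStorageClient.from_config(config)

        dataset = await Dataset.open(name='results', storage_client=storage_client)
        await dataset.push_data({'url': 'https://example.com'})

    asyncio.run(main())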
diff --git a/src/crawlee/storages/_key_value_store.py b/src/crawlee/storages/_key_value_store.py
index f10834a869..7e36a3d576 100644
--- a/src/crawlee/storages/_key_value_store.py
+++ b/src/crawlee/storages/_key_value_store.py
@@ -1,11 +1,13 @@
from __future__ import annotations

+from collections.abc import AsyncIterator
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, overload

from typing_extensions import override

+from crawlee import service_locator
from crawlee._utils.docs import docs_group
-from crawlee.base_storage_client._models import KeyValueStoreKeyInfo, KeyValueStoreMetadata
+from crawlee.base_storage_client import BaseStorageClient, KeyValueStoreKeyInfo, KeyValueStoreMetadata
from crawlee.events._types import Event, EventPersistStateData
from crawlee.storages._base_storage import BaseStorage
@@ -15,7 +17,6 @@
    from crawlee._types import JsonSerializable
    from crawlee.base_storage_client import BaseStorageClient
    from crawlee.configuration import Configuration
-    from crawlee.events._event_manager import EventManager

T = TypeVar('T')
@@ -59,19 +60,12 @@ class KeyValueStore(BaseStorage):
    _general_cache: ClassVar[dict[str, dict[str, dict[str, JsonSerializable]]]] = {}
    _persist_state_event_started = False

-    def __init__(
-        self,
-        id: str,
-        name: str | None,
-        configuration: Configuration,
-        client: BaseStorageClient,
-    ) -> None:
+    def __init__(self, id: str, name: str | None, storage_client: BaseStorageClient) -> None:
        self._id = id
        self._name = name
-        self._configuration = configuration

        # Get resource clients from storage client
-        self._resource_client = client.key_value_store(self._id)
+        self._resource_client = storage_client.key_value_store(self._id)

    @property
    @override
@@ -99,6 +93,9 @@ async def open(
    ) -> KeyValueStore:
        from crawlee.storages._creation_management import open_storage

+        configuration = configuration or service_locator.get_configuration()
+        storage_client = storage_client or service_locator.get_storage_client()
+
        return await open_storage(
            storage_class=cls,
            id=id,
@@ -185,7 +182,9 @@ async def get_public_url(self, key: str) -> str:
        return await self._resource_client.get_public_url(key)

    async def get_auto_saved_value(
-        self, key: str, default_value: dict[str, JsonSerializable] | None = None
+        self,
+        key: str,
+        default_value: dict[str, JsonSerializable] | None = None,
    ) -> dict[str, JsonSerializable]:
        """Gets a value from KVS that will be automatically saved on changes.
@@ -226,18 +225,12 @@ async def _persist_save(self, _event_data: EventPersistStateData | None = None)
        for key, value in self._cache.items():
            await self.set_value(key, value)

-    def _get_event_manager(self) -> EventManager:
-        """Get event manager from crawlee services."""
-        from crawlee.service_container import get_event_manager
-
-        return get_event_manager()  # type: ignore[no-any-return]  # Mypy is broken
-
    def _ensure_persist_event(self) -> None:
        """Setup persist state event handling if not already done."""
        if self._persist_state_event_started:
            return

-        event_manager = self._get_event_manager()
+        event_manager = service_locator.get_event_manager()
        event_manager.on(event=Event.PERSIST_STATE, listener=self._persist_save)
        self._persist_state_event_started = True
@@ -248,7 +241,7 @@ def _clear_cache(self) -> None:
    def _drop_persist_state_event(self) -> None:
        """Off event_manager listener and drop event status."""
        if self._persist_state_event_started:
-            event_manager = self._get_event_manager()
+            event_manager = service_locator.get_event_manager()
            event_manager.off(event=Event.PERSIST_STATE, listener=self._persist_save)
            self._persist_state_event_started = False
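With the private `_get_event_manager` shim removed, auto-saved values register their persist listener straight from the locator. A short usage sketch of the auto-save path (the key name is illustrative):

    import asyncio

    from crawlee.storages import KeyValueStore

    async def main() -> None:
        kvs = await KeyValueStore.open()

        # The returned dict is cached and written back on PERSIST_STATE events,
        # which the store now wires up via service_locator.get_event_manager().
        state = await kvs.get_auto_saved_value('crawl-state', default_value={'processed': 0})
        state['processed'] += 1

    asyncio.run(main())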
diff --git a/src/crawlee/storages/_request_queue.py b/src/crawlee/storages/_request_queue.py
index 8d26880311..4f0d746c50 100644
--- a/src/crawlee/storages/_request_queue.py
+++ b/src/crawlee/storages/_request_queue.py
@@ -9,12 +9,13 @@

from typing_extensions import override

+from crawlee import service_locator
from crawlee._utils.crypto import crypto_random_object_id
from crawlee._utils.docs import docs_group
from crawlee._utils.lru_cache import LRUCache
from crawlee._utils.requests import unique_key_to_request_id
from crawlee._utils.wait import wait_for_all_tasks_for_finish
-from crawlee.base_storage_client._models import ProcessedRequest, RequestQueueMetadata
+from crawlee.base_storage_client import BaseStorageClient, ProcessedRequest, RequestQueueMetadata
from crawlee.events._types import Event
from crawlee.storages._base_storage import BaseStorage
from crawlee.storages._request_provider import RequestProvider
@@ -23,9 +24,7 @@
    from collections.abc import Sequence

    from crawlee._request import Request
-    from crawlee.base_storage_client import BaseStorageClient
    from crawlee.configuration import Configuration
-    from crawlee.events import EventManager

logger = getLogger(__name__)
@@ -105,21 +104,16 @@ class RequestQueue(BaseStorage, RequestProvider):
    _STORAGE_CONSISTENCY_DELAY = timedelta(seconds=3)
    """Expected delay for storage to achieve consistency, guiding the timing of subsequent read operations."""

-    def __init__(
-        self,
-        id: str,
-        name: str | None,
-        configuration: Configuration,
-        client: BaseStorageClient,
-        event_manager: EventManager,
-    ) -> None:
+    def __init__(self, id: str, name: str | None, storage_client: BaseStorageClient) -> None:
+        config = service_locator.get_configuration()
+        event_manager = service_locator.get_event_manager()
+
        self._id = id
        self._name = name
-        self._configuration = configuration

        # Get resource clients from storage client
-        self._resource_client = client.request_queue(self._id)
-        self._resource_collection_client = client.request_queues()
+        self._resource_client = storage_client.request_queue(self._id)
+        self._resource_collection_client = storage_client.request_queues()

        self._request_lock_time = timedelta(minutes=3)
        self._queue_paused_for_migration = False
@@ -131,7 +125,7 @@ def __init__(
        # Other internal attributes
        self._tasks = list[asyncio.Task]()
        self._client_key = crypto_random_object_id()
-        self._internal_timeout = configuration.internal_timeout or timedelta(minutes=5)
+        self._internal_timeout = config.internal_timeout or timedelta(minutes=5)
        self._assumed_total_count = 0
        self._assumed_handled_count = 0
        self._queue_head_dict: OrderedDict[str, str] = OrderedDict()
@@ -163,6 +157,9 @@ async def open(
    ) -> RequestQueue:
        from crawlee.storages._creation_management import open_storage

+        configuration = configuration or service_locator.get_configuration()
+        storage_client = storage_client or service_locator.get_storage_client()
+
        return await open_storage(
            storage_class=cls,
            id=id,
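RequestQueue now resolves both the configuration and the event manager inside `__init__`, so constructing one requires only its identity plus a storage client. A minimal sketch of the resulting caller experience:

    import asyncio

    from crawlee import Request
    from crawlee.storages import RequestQueue

    async def main() -> None:
        # `open` falls back to the service locator for configuration and
        # storage client; `__init__` does the same for the event manager.
        rq = await RequestQueue.open()
        await rq.add_request(Request.from_url('https://example.com'))
        request = await rq.fetch_next_request()
        print(request.url if request else 'queue empty')

    asyncio.run(main())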
diff --git a/tests/unit/_autoscaling/test_snapshotter.py b/tests/unit/_autoscaling/test_snapshotter.py
index ce2c71f029..3c491c5cd2 100644
--- a/tests/unit/_autoscaling/test_snapshotter.py
+++ b/tests/unit/_autoscaling/test_snapshotter.py
@@ -3,22 +3,21 @@
from datetime import datetime, timedelta, timezone
from logging import getLogger
from typing import cast
-from unittest.mock import AsyncMock, MagicMock
+from unittest.mock import MagicMock

import pytest

+from crawlee import service_locator
from crawlee._autoscaling import Snapshotter
from crawlee._autoscaling.types import CpuSnapshot, EventLoopSnapshot, Snapshot
from crawlee._utils.byte_size import ByteSize
from crawlee._utils.system import CpuInfo, MemoryInfo
-from crawlee.events import EventManager, LocalEventManager
from crawlee.events._types import Event, EventSystemInfoData


@pytest.fixture
def snapshotter() -> Snapshotter:
-    mocked_event_manager = AsyncMock(spec=EventManager)
-    return Snapshotter(mocked_event_manager, available_memory_ratio=0.25)
+    return Snapshotter(available_memory_ratio=0.25)


@pytest.fixture
@@ -33,7 +32,7 @@ def event_system_data_info() -> EventSystemInfoData:


async def test_start_stop_lifecycle() -> None:
-    async with LocalEventManager() as event_manager, Snapshotter(event_manager, available_memory_ratio=0.25):
+    async with Snapshotter(available_memory_ratio=0.25):
        pass

@@ -94,8 +93,7 @@ async def test_get_cpu_sample(snapshotter: Snapshotter) -> None:

async def test_methods_raise_error_when_not_active() -> None:
-    event_manager = AsyncMock(spec=EventManager)
-    snapshotter = Snapshotter(event_manager, available_memory_ratio=0.25)
+    snapshotter = Snapshotter(available_memory_ratio=0.25)

    assert snapshotter.active is False

@@ -194,7 +192,7 @@ def test_snapshot_pruning_keeps_recent_records_unaffected(snapshotter: Snapshott

def test_memory_load_evaluation_logs_warning_on_high_usage(caplog: pytest.LogCaptureFixture) -> None:
-    snapshotter = Snapshotter(AsyncMock(spec=EventManager), max_memory_size=ByteSize.from_gb(8))
+    snapshotter = Snapshotter(max_memory_size=ByteSize.from_gb(8))

    high_memory_usage = ByteSize.from_gb(8) * 0.95  # 95% of 8 GB

@@ -250,8 +248,8 @@ def create_event_data(creation_time: datetime) -> EventSystemInfoData:
        )

    async with (
-        LocalEventManager() as event_manager,
-        Snapshotter(event_manager, available_memory_ratio=0.25) as snapshotter,
+        service_locator.get_event_manager() as event_manager,
+        Snapshotter(available_memory_ratio=0.25) as snapshotter,
    ):
        event_manager.emit(event=Event.SYSTEM_INFO, event_data=create_event_data(time_new))
        await event_manager.wait_for_all_listeners_to_complete()
diff --git a/tests/unit/_autoscaling/test_system_status.py b/tests/unit/_autoscaling/test_system_status.py
index a79e6be53b..acb6e35314 100644
--- a/tests/unit/_autoscaling/test_system_status.py
+++ b/tests/unit/_autoscaling/test_system_status.py
@@ -15,7 +15,6 @@
    SystemInfo,
)
from crawlee._utils.byte_size import ByteSize
-from crawlee.events import LocalEventManager

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator
@@ -23,10 +22,7 @@

@pytest.fixture
async def snapshotter() -> AsyncGenerator[Snapshotter, None]:
-    async with (
-        LocalEventManager() as event_manager,
-        Snapshotter(event_manager, available_memory_ratio=0.25) as snapshotter,
-    ):
+    async with Snapshotter(available_memory_ratio=0.25) as snapshotter:
        yield snapshotter

@@ -36,10 +32,7 @@ def now() -> datetime:

async def test_start_stop_lifecycle() -> None:
-    async with (
-        LocalEventManager() as event_manager,
-        Snapshotter(event_manager, available_memory_ratio=0.25) as snapshotter,
-    ):
+    async with Snapshotter(available_memory_ratio=0.25) as snapshotter:
        system_status = SystemStatus(snapshotter)
        system_status.get_current_system_info()
        system_status.get_historical_system_info()
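As these test changes show, Snapshotter's constructor no longer takes an event manager at all; it subscribes through the locator internally. A hedged lifecycle sketch:

    import asyncio

    from crawlee._autoscaling import Snapshotter

    async def main() -> None:
        # Only tuning parameters are passed now; the event manager is obtained
        # from the service locator when the snapshotter starts.
        async with Snapshotter(available_memory_ratio=0.25) as snapshotter:
            print(snapshotter.active)  # True while the context is open

    asyncio.run(main())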
diff --git a/tests/unit/_memory_storage_client/test_memory_storage_client.py b/tests/unit/_memory_storage_client/test_memory_storage_client.py
index 9ad606427b..882aca7fab 100644
--- a/tests/unit/_memory_storage_client/test_memory_storage_client.py
+++ b/tests/unit/_memory_storage_client/test_memory_storage_client.py
@@ -8,7 +8,7 @@

import pytest

-from crawlee import Request
+from crawlee import Request, service_locator
from crawlee._consts import METADATA_FILENAME
from crawlee.configuration import Configuration
from crawlee.memory_storage_client import MemoryStorageClient
@@ -17,13 +17,13 @@ async def test_write_metadata(tmp_path: Path) -> None:
    dataset_name = 'test'
    dataset_no_metadata_name = 'test-no-metadata'
-    ms = MemoryStorageClient(
+    ms = MemoryStorageClient.from_config(
        Configuration(
            crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
            write_metadata=True,
        ),
    )
-    ms_no_metadata = MemoryStorageClient(
+    ms_no_metadata = MemoryStorageClient.from_config(
        Configuration(
            crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
            write_metadata=False,
@@ -48,7 +48,7 @@ async def test_write_metadata(tmp_path: Path) -> None:
    ],
)
async def test_persist_storage(persist_storage: bool, tmp_path: Path) -> None:  # noqa: FBT001
-    ms = MemoryStorageClient(
+    ms = MemoryStorageClient.from_config(
        Configuration(
            crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
            persist_storage=persist_storage,
@@ -82,18 +82,20 @@ async def test_persist_storage(persist_storage: bool, tmp_path: Path) -> None:

def test_persist_storage_set_to_false_via_string_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    monkeypatch.setenv('CRAWLEE_PERSIST_STORAGE', 'false')
-    ms = MemoryStorageClient(Configuration(crawlee_storage_dir=str(tmp_path)))  # type: ignore[call-arg]
+    ms = MemoryStorageClient.from_config(
+        Configuration(crawlee_storage_dir=str(tmp_path)),  # type: ignore[call-arg]
+    )
    assert ms.persist_storage is False


def test_persist_storage_set_to_false_via_numeric_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    monkeypatch.setenv('CRAWLEE_PERSIST_STORAGE', '0')
-    ms = MemoryStorageClient(Configuration(crawlee_storage_dir=str(tmp_path)))  # type: ignore[call-arg]
+    ms = MemoryStorageClient.from_config(Configuration(crawlee_storage_dir=str(tmp_path)))  # type: ignore[call-arg]
    assert ms.persist_storage is False


def test_persist_storage_true_via_constructor_arg(tmp_path: Path) -> None:
-    ms = MemoryStorageClient(
+    ms = MemoryStorageClient.from_config(
        Configuration(
            crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
            persist_storage=True,
@@ -104,20 +106,24 @@

def test_default_write_metadata_behavior(tmp_path: Path) -> None:
    # Default behavior
-    ms = MemoryStorageClient(Configuration(crawlee_storage_dir=str(tmp_path)))  # type: ignore[call-arg]
+    ms = MemoryStorageClient.from_config(
+        Configuration(crawlee_storage_dir=str(tmp_path)),  # type: ignore[call-arg]
+    )
    assert ms.write_metadata is True


def test_write_metadata_set_to_false_via_env_var(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
    # Test if env var changes write_metadata to False
    monkeypatch.setenv('CRAWLEE_WRITE_METADATA', 'false')
-    ms = MemoryStorageClient(Configuration(crawlee_storage_dir=str(tmp_path)))  # type: ignore[call-arg]
+    ms = MemoryStorageClient.from_config(
+        Configuration(crawlee_storage_dir=str(tmp_path)),  # type: ignore[call-arg]
+    )
    assert ms.write_metadata is False


def test_write_metadata_false_via_constructor_arg_overrides_env_var(tmp_path: Path) -> None:
    # Test if constructor arg takes precedence over env var value
-    ms = MemoryStorageClient(
+    ms = MemoryStorageClient.from_config(
        Configuration(
            write_metadata=False,
            crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
@@ -127,7 +133,7 @@ def test_write_metadata_false_via_constructor_arg_overrides_env_var(tmp_path: Pa

async def test_purge_datasets(tmp_path: Path) -> None:
-    ms = MemoryStorageClient(
+    ms = MemoryStorageClient.from_config(
        Configuration(
            write_metadata=True,
            crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
@@ -150,7 +156,7 @@ async def test_purge_datasets(tmp_path: Path) -> None:

async def test_purge_key_value_stores(tmp_path: Path) -> None:
-    ms = MemoryStorageClient(
+    ms = MemoryStorageClient.from_config(
        Configuration(
            write_metadata=True,
            crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
@@ -185,7 +191,7 @@ async def test_purge_key_value_stores(tmp_path: Path) -> None:

async def test_purge_request_queues(tmp_path: Path) -> None:
-    ms = MemoryStorageClient(
+    ms = MemoryStorageClient.from_config(
        Configuration(
            write_metadata=True,
            crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
@@ -207,7 +213,7 @@ async def test_purge_request_queues(tmp_path: Path) -> None:

async def test_not_implemented_method(tmp_path: Path) -> None:
-    ms = MemoryStorageClient(
+    ms = MemoryStorageClient.from_config(
        Configuration(
            write_metadata=True,
            crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
@@ -222,22 +228,29 @@ async def test_not_implemented_method(tmp_path: Path) -> None:

async def test_default_storage_path_used(monkeypatch: pytest.MonkeyPatch) -> None:
-    # We expect the default value to be used
+    # Reset the configuration in service locator
+    service_locator._configuration = None
+    service_locator._configuration_was_set = False
+
+    # Remove the env var for setting the storage directory
    monkeypatch.delenv('CRAWLEE_STORAGE_DIR', raising=False)
-    ms = MemoryStorageClient()
-    assert ms.storage_dir == './storage'
+
+    # Initialize the service locator with default configuration
+    msc = MemoryStorageClient.from_config()
+    assert msc.storage_dir == './storage'


async def test_storage_path_from_env_var_overrides_default(monkeypatch: pytest.MonkeyPatch) -> None:
    # We expect the env var to override the default value
    monkeypatch.setenv('CRAWLEE_STORAGE_DIR', './env_var_storage_dir')
-    ms = MemoryStorageClient()
+    service_locator.set_configuration(Configuration())
+    ms = MemoryStorageClient.from_config()
    assert ms.storage_dir == './env_var_storage_dir'


async def test_parametrized_storage_path_overrides_env_var() -> None:
    # We expect the parametrized value to be used
-    ms = MemoryStorageClient(
+    ms = MemoryStorageClient.from_config(
        Configuration(crawlee_storage_dir='./parametrized_storage_dir'),  # type: ignore[call-arg]
    )
    assert ms.storage_dir == './parametrized_storage_dir'
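Direct instantiation of MemoryStorageClient with a Configuration argument is replaced throughout by the from_config factory. A compact sketch of the two construction paths exercised above:

    from crawlee.configuration import Configuration
    from crawlee.memory_storage_client import MemoryStorageClient

    # Default: pull the configuration from the service locator.
    default_client = MemoryStorageClient.from_config()

    # Explicit: build the client from a specific configuration.
    custom_client = MemoryStorageClient.from_config(
        Configuration(persist_storage=False, write_metadata=False),
    )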
diff --git a/tests/unit/_memory_storage_client/test_memory_storage_e2e.py b/tests/unit/_memory_storage_client/test_memory_storage_e2e.py
index 17db0e95f5..7bf3f3a8a3 100644
--- a/tests/unit/_memory_storage_client/test_memory_storage_e2e.py
+++ b/tests/unit/_memory_storage_client/test_memory_storage_e2e.py
@@ -5,7 +5,7 @@

import pytest

-from crawlee import Request, service_container
+from crawlee import Request, service_locator
from crawlee.storages._key_value_store import KeyValueStore
from crawlee.storages._request_queue import RequestQueue
@@ -14,7 +14,7 @@ async def test_actor_memory_storage_client_key_value_store_e2e(
    monkeypatch: pytest.MonkeyPatch,
    purge_on_start: bool,  # noqa: FBT001
-    reset_globals: Callable[[], None],
+    prepare_test_env: Callable[[], None],
) -> None:
    """This test simulates two clean runs using memory storage.
    The second run attempts to access data created by the first one.
@@ -22,7 +22,7 @@ async def test_actor_memory_storage_client_key_value_store_e2e(
    # Configure purging env var
    monkeypatch.setenv('CRAWLEE_PURGE_ON_START', f'{int(purge_on_start)}')
    # Store old storage client so we have the object reference for comparison
-    old_client = service_container.get_storage_client()
+    old_client = service_locator.get_storage_client()

    old_default_kvs = await KeyValueStore.open()
    old_non_default_kvs = await KeyValueStore.open(name='non-default')
@@ -32,10 +32,10 @@ async def test_actor_memory_storage_client_key_value_store_e2e(

    # We simulate another clean run, we expect the memory storage to read from the local data directory
    # Default storages are purged based on purge_on_start parameter.
-    reset_globals()
+    prepare_test_env()

    # Check if we're using a different memory storage instance
-    assert old_client is not service_container.get_storage_client()
+    assert old_client is not service_locator.get_storage_client()
    default_kvs = await KeyValueStore.open()
    assert default_kvs is not old_default_kvs
    non_default_kvs = await KeyValueStore.open(name='non-default')
@@ -54,7 +54,7 @@ async def test_actor_memory_storage_client_request_queue_e2e(
    monkeypatch: pytest.MonkeyPatch,
    purge_on_start: bool,  # noqa: FBT001
-    reset_globals: Callable[[], None],
+    prepare_test_env: Callable[[], None],
) -> None:
    """This test simulates two clean runs using memory storage.
    The second run attempts to access data created by the first one.
@@ -82,7 +82,7 @@ async def test_actor_memory_storage_client_request_queue_e2e(

    # We simulate another clean run, we expect the memory storage to read from the local data directory
    # Default storages are purged based on purge_on_start parameter.
-    reset_globals()
+    prepare_test_env()

    # Add some more requests to the default queue
    default_queue = await RequestQueue.open()
diff --git a/tests/unit/basic_crawler/test_basic_crawler.py b/tests/unit/basic_crawler/test_basic_crawler.py
index 5c9399c699..5eeb5ad939 100644
--- a/tests/unit/basic_crawler/test_basic_crawler.py
+++ b/tests/unit/basic_crawler/test_basic_crawler.py
@@ -863,22 +863,6 @@ async def handler(context: BasicCrawlingContext) -> None:
    }


-async def test_passes_configuration_to_storages() -> None:
-    configuration = Configuration(persist_storage=False, purge_on_start=True)
-
-    crawler = BasicCrawler(configuration=configuration)
-
-    dataset = await crawler.get_dataset()
-    assert dataset._configuration is configuration
-
-    key_value_store = await crawler.get_key_value_store()
-    assert key_value_store._configuration is configuration
-
-    request_provider = await crawler.get_request_provider()
-    assert isinstance(request_provider, RequestQueue)
-    assert request_provider._configuration is configuration
-
-
async def test_respects_no_persist_storage() -> None:
    configuration = Configuration(persist_storage=False)
    crawler = BasicCrawler(configuration=configuration)
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index cab47bc202..771d7355f9 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -10,7 +10,7 @@
from proxy import Proxy
from yarl import URL

-from crawlee import service_container
+from crawlee import service_locator
from crawlee.configuration import Configuration
from crawlee.memory_storage_client import MemoryStorageClient
from crawlee.proxy_configuration import ProxyInfo
@@ -22,15 +22,36 @@

@pytest.fixture
-def reset_globals(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Callable[[], None]:
-    def reset() -> None:
-        # Set the environment variable for the local storage directory to the temporary path
+def prepare_test_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Callable[[], None]:
+    """Prepare the testing environment by resetting the global state before each test.
+
+    This fixture ensures that the global state of the package is reset to a known baseline before each test runs.
+    It also configures a temporary storage directory for test isolation.
+
+    Args:
+        monkeypatch: Test utility provided by pytest for patching.
+        tmp_path: A unique temporary directory path provided by pytest for test isolation.
+
+    Returns:
+        A callable that prepares the test environment.
+    """
+
+    def _prepare_test_env() -> None:
+        # Set the environment variable for the local storage directory to the temporary path.
        monkeypatch.setenv('CRAWLEE_STORAGE_DIR', str(tmp_path))

-        # Reset services in crawlee.service_container
-        cast(dict, service_container._services).clear()
+        # Reset the flags in the service locator to indicate that no services are explicitly set. This ensures
+        # a clean state, as services might have been set during a previous test and not reset properly.
+        service_locator._configuration_was_set = False
+        service_locator._storage_client_was_set = False
+        service_locator._event_manager_was_set = False

-        # Clear creation-related caches to ensure no state is carried over between tests
+        # Reset the services in the service locator.
+        service_locator._configuration = None
+        service_locator._event_manager = None
+        service_locator._storage_client = None
+
+        # Clear creation-related caches to ensure no state is carried over between tests.
        monkeypatch.setattr(_creation_management, '_cache_dataset_by_id', {})
        monkeypatch.setattr(_creation_management, '_cache_dataset_by_name', {})
        monkeypatch.setattr(_creation_management, '_cache_kvs_by_id', {})
@@ -38,34 +59,26 @@ def reset() -> None:
        monkeypatch.setattr(_creation_management, '_cache_rq_by_id', {})
        monkeypatch.setattr(_creation_management, '_cache_rq_by_name', {})

-        # Verify that the environment variable is set correctly
+        # Verify that the test environment was set up correctly.
        assert os.environ.get('CRAWLEE_STORAGE_DIR') == str(tmp_path)
+        assert service_locator._configuration_was_set is False
+        assert service_locator._storage_client_was_set is False
+        assert service_locator._event_manager_was_set is False

-    return reset
+    return _prepare_test_env


@pytest.fixture(autouse=True)
-def _isolate_test_environment(reset_globals: Callable[[], None]) -> None:
-    """Isolate tests by resetting the storage clients, clearing caches, and setting the environment variables.
+def _isolate_test_environment(prepare_test_env: Callable[[], None]) -> None:
+    """Isolate the testing environment by resetting global state before and after each test.

-    The fixture is applied automatically to all test cases.
+    This fixture ensures that each test starts with a clean slate and that any modifications during the test
+    do not affect subsequent tests. It runs automatically for all tests.

    Args:
-        monkeypatch: Test utility provided by pytest.
-        tmp_path: A unique temporary directory path provided by pytest for test isolation.
+        prepare_test_env: Fixture to prepare the environment before each test.
    """
-
-    reset_globals()
-
-
-@pytest.fixture
-def memory_storage_client(tmp_path: Path) -> MemoryStorageClient:
-    cfg = Configuration(
-        write_metadata=True,
-        persist_storage=True,
-        crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
-    )
-    return MemoryStorageClient(cfg)
+    prepare_test_env()


@pytest.fixture
@@ -147,3 +160,15 @@ async def disabled_proxy(proxy_info: ProxyInfo) -> AsyncGenerator[ProxyInfo, Non
        ]
    ):
        yield proxy_info
+
+
+@pytest.fixture
+def memory_storage_client(tmp_path: Path) -> MemoryStorageClient:
+    """A fixture for testing the memory storage client and its resource clients."""
+    config = Configuration(
+        persist_storage=True,
+        write_metadata=True,
+        crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
+    )
+
+    return MemoryStorageClient.from_config(config)
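The conftest fixture above resets the locator's private state directly. A trimmed sketch of the same isolation pattern for a downstream test suite (the fixture name is illustrative, and the private attributes mirror what the conftest touches):

    import pytest

    from crawlee import service_locator

    @pytest.fixture(autouse=True)
    def _reset_service_locator() -> None:
        # Mirrors the approach in the conftest above: clear the cached services
        # and the "was set" flags so each test starts from a clean baseline.
        service_locator._configuration = None
        service_locator._event_manager = None
        service_locator._storage_client = None
        service_locator._configuration_was_set = False
        service_locator._event_manager_was_set = False
        service_locator._storage_client_was_set = False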
"""Check session pool's ability to persist its state and then restore it accurately after reset.""" + service_locator.set_event_manager(event_manager) + async with SessionPool( max_pool_size=MAX_POOL_SIZE, - event_manager=event_manager, persistence_enabled=True, persist_state_kvs_name=KVS_NAME, persist_state_key=PERSIST_STATE_KEY, - ) as _: + ): # Emit persist state event and wait for the persistence to complete event_manager.emit(event=Event.PERSIST_STATE, event_data=EventPersistStateData(is_migrating=False)) await event_manager.wait_for_all_listeners_to_complete() async with SessionPool( max_pool_size=MAX_POOL_SIZE, - event_manager=event_manager, persistence_enabled=True, persist_state_kvs_name=KVS_NAME, persist_state_key=PERSIST_STATE_KEY, diff --git a/tests/unit/storages/test_key_value_store.py b/tests/unit/storages/test_key_value_store.py index b6549dde65..b1ba14088b 100644 --- a/tests/unit/storages/test_key_value_store.py +++ b/tests/unit/storages/test_key_value_store.py @@ -27,7 +27,7 @@ async def key_value_store() -> AsyncGenerator[KeyValueStore, None]: @pytest.fixture async def mock_event_manager() -> AsyncGenerator[EventManager, None]: async with EventManager(persist_state_interval=timedelta(milliseconds=50)) as event_manager: - with patch('crawlee.service_container.get_event_manager', return_value=event_manager): + with patch('crawlee.service_locator.get_event_manager', return_value=event_manager): yield event_manager diff --git a/tests/unit/test_configuration.py b/tests/unit/test_configuration.py index bfcf185a5e..308115d3be 100644 --- a/tests/unit/test_configuration.py +++ b/tests/unit/test_configuration.py @@ -1,7 +1,76 @@ +# TODO: Update crawlee_storage_dir args once the Pydantic bug is fixed +# https://github.com/apify/crawlee-python/issues/146 + from __future__ import annotations +from typing import TYPE_CHECKING + +from crawlee import service_locator from crawlee.configuration import Configuration +from crawlee.http_crawler import HttpCrawler, HttpCrawlingContext +from crawlee.memory_storage_client._memory_storage_client import MemoryStorageClient + +if TYPE_CHECKING: + from pathlib import Path + + from yarl import URL def test_global_configuration_works() -> None: - assert Configuration.get_global_configuration() is Configuration.get_global_configuration() + assert ( + Configuration.get_global_configuration() + is Configuration.get_global_configuration() + is service_locator.get_configuration() + is service_locator.get_configuration() + ) + + +def test_global_configuration_works_reversed() -> None: + assert ( + service_locator.get_configuration() + is service_locator.get_configuration() + is Configuration.get_global_configuration() + is Configuration.get_global_configuration() + ) + + +async def test_storage_not_persisted_when_disabled(tmp_path: Path, httpbin: URL) -> None: + config = Configuration( + persist_storage=False, + write_metadata=False, + crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg] + ) + storage_client = MemoryStorageClient.from_config(config) + + crawler = HttpCrawler(storage_client=storage_client) + + @crawler.router.default_handler + async def default_handler(context: HttpCrawlingContext) -> None: + await context.push_data({'url': context.request.url}) + + await crawler.run([str(httpbin)]) + + # Verify that no files were created in the storage directory. + content = list(tmp_path.iterdir()) + assert content == [], 'Expected the storage directory to be empty, but it is not.' 
+
+
+async def test_storage_persisted_when_enabled(tmp_path: Path, httpbin: URL) -> None:
+    config = Configuration(
+        persist_storage=True,
+        write_metadata=True,
+        crawlee_storage_dir=str(tmp_path),  # type: ignore[call-arg]
+    )
+    storage_client = MemoryStorageClient.from_config(config)
+
+    crawler = HttpCrawler(storage_client=storage_client)
+
+    @crawler.router.default_handler
+    async def default_handler(context: HttpCrawlingContext) -> None:
+        await context.push_data({'url': context.request.url})
+
+    await crawler.run([str(httpbin)])
+
+    # Verify that files were created in the storage directory.
+    content = list(tmp_path.iterdir())
+    assert content != [], 'Expected the storage directory to contain files, but it does not.'
diff --git a/tests/unit/test_service_container.py b/tests/unit/test_service_container.py
deleted file mode 100644
index b2a0b4c1bd..0000000000
--- a/tests/unit/test_service_container.py
+++ /dev/null
@@ -1,92 +0,0 @@
-from __future__ import annotations
-
-from unittest.mock import Mock
-
-import pytest
-
-from crawlee import service_container
-from crawlee.configuration import Configuration
-from crawlee.errors import ServiceConflictError
-from crawlee.events._local_event_manager import LocalEventManager
-from crawlee.memory_storage_client._memory_storage_client import MemoryStorageClient
-
-
-async def test_get_event_manager() -> None:
-    event_manager = service_container.get_event_manager()
-    assert isinstance(event_manager, LocalEventManager)
-
-
-async def test_set_event_manager() -> None:
-    event_manager = Mock()
-    service_container.set_event_manager(event_manager)
-    assert service_container.get_event_manager() is event_manager
-
-
-async def test_overwrite_event_manager() -> None:
-    event_manager = Mock()
-    service_container.set_event_manager(event_manager)
-    service_container.set_event_manager(event_manager)
-
-    with pytest.raises(ServiceConflictError):
-        service_container.set_event_manager(Mock())
-
-
-async def test_get_configuration() -> None:
-    configuration = service_container.get_configuration()
-    assert isinstance(configuration, Configuration)
-
-
-async def test_set_configuration() -> None:
-    configuration = Mock()
-    service_container.set_configuration(configuration)
-    assert service_container.get_configuration() is configuration
-
-
-async def test_overwrite_configuration() -> None:
-    configuration = Mock()
-    service_container.set_configuration(configuration)
-    service_container.set_configuration(configuration)
-
-    with pytest.raises(ServiceConflictError):
-        service_container.set_configuration(Mock())
-
-
-async def test_get_storage_client() -> None:
-    storage_client = service_container.get_storage_client()
-    assert isinstance(storage_client, MemoryStorageClient)
-
-    with pytest.raises(RuntimeError):
-        service_container.get_storage_client(client_type='cloud')
-
-    service_container.set_default_storage_client_type('cloud')
-
-    with pytest.raises(RuntimeError):
-        service_container.get_storage_client()
-
-    storage_client = service_container.get_storage_client(client_type='local')
-    assert isinstance(storage_client, MemoryStorageClient)
-
-    cloud_client = Mock()
-    service_container.set_cloud_storage_client(cloud_client)
-    assert service_container.get_storage_client(client_type='cloud') is cloud_client
-    assert service_container.get_storage_client() is cloud_client
-
-
-async def test_reset_local_storage_client() -> None:
-    storage_client = Mock()
-
-    service_container.set_local_storage_client(storage_client)
-    service_container.set_local_storage_client(storage_client)
-
-    with pytest.raises(ServiceConflictError):
-        service_container.set_local_storage_client(Mock())
-
-
-async def test_reset_cloud_storage_client() -> None:
-    storage_client = Mock()
-
-    service_container.set_cloud_storage_client(storage_client)
-    service_container.set_cloud_storage_client(storage_client)
-
-    with pytest.raises(ServiceConflictError):
-        service_container.set_cloud_storage_client(Mock())
diff --git a/tests/unit/test_service_locator.py b/tests/unit/test_service_locator.py
new file mode 100644
index 0000000000..435a10180f
--- /dev/null
+++ b/tests/unit/test_service_locator.py
@@ -0,0 +1,49 @@
+from __future__ import annotations
+
+import pytest
+
+from crawlee import service_locator
+from crawlee.configuration import Configuration
+from crawlee.errors import ServiceConflictError
+from crawlee.events import LocalEventManager
+from crawlee.memory_storage_client import MemoryStorageClient
+
+
+def test_configuration() -> None:
+    default_config = Configuration()
+    config = service_locator.get_configuration()
+    assert config == default_config
+
+    custom_config = Configuration(default_browser_path='custom_path')
+    service_locator.set_configuration(custom_config)
+    config = service_locator.get_configuration()
+    assert config == custom_config
+
+    with pytest.raises(ServiceConflictError, match='Configuration has already been set.'):
+        service_locator.set_configuration(custom_config)
+
+
+def test_event_manager() -> None:
+    default_event_manager = service_locator.get_event_manager()
+    assert isinstance(default_event_manager, LocalEventManager)
+
+    custom_event_manager = LocalEventManager()
+    service_locator.set_event_manager(custom_event_manager)
+    event_manager = service_locator.get_event_manager()
+    assert event_manager == custom_event_manager
+
+    with pytest.raises(ServiceConflictError, match='EventManager has already been set.'):
+        service_locator.set_event_manager(custom_event_manager)
+
+
+def test_storage_client() -> None:
+    default_storage_client = service_locator.get_storage_client()
+    assert isinstance(default_storage_client, MemoryStorageClient)
+
+    custom_storage_client = MemoryStorageClient.from_config()
+    service_locator.set_storage_client(custom_storage_client)
+    storage_client = service_locator.get_storage_client()
+    assert storage_client == custom_storage_client
+
+    with pytest.raises(ServiceConflictError, match='StorageClient has already been set.'):
+        service_locator.set_storage_client(custom_storage_client)
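Taken together, the new tests pin down the locator's contract: each service may be registered once per process, and registering it again raises. A final usage sketch of that contract, consistent with test_service_locator.py above:

    from crawlee import service_locator
    from crawlee.configuration import Configuration
    from crawlee.errors import ServiceConflictError

    service_locator.set_configuration(Configuration(default_browser_path='custom_path'))

    try:
        service_locator.set_configuration(Configuration())
    except ServiceConflictError:
        # A second set_* call for the same service always conflicts.
        print('configuration was already set')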