diff --git a/src/fromager/bootstrapper.py b/src/fromager/bootstrapper.py index ae0c17da..7635272e 100644 --- a/src/fromager/bootstrapper.py +++ b/src/fromager/bootstrapper.py @@ -913,9 +913,11 @@ def _resolve_from_version_source( return None try: # no need to pass req type to enable caching since we are already using the graph as our cache + # do not cache candidates provider = resolver.GenericProvider( - version_source=lambda x, y, z: version_source, + version_source=lambda identifier: version_source, constraints=self.ctx.constraints, + use_resolver_cache=False, ) return resolver.resolve_from_provider(provider, req) except Exception as err: diff --git a/src/fromager/resolver.py b/src/fromager/resolver.py index 187d7f74..69bbcaa4 100644 --- a/src/fromager/resolver.py +++ b/src/fromager/resolver.py @@ -10,7 +10,6 @@ import os import re import typing -from collections import defaultdict from collections.abc import Iterable from operator import attrgetter from platform import python_version @@ -180,7 +179,7 @@ def get_project_from_pypi( extras: typing.Iterable[str], sdist_server_url: str, ignore_platform: bool = False, -) -> typing.Iterable[Candidate]: +) -> Candidates: """Return candidates created from the project name and extras.""" found_candidates: set[str] = set() ignored_candidates: set[str] = set() @@ -330,34 +329,62 @@ def get_project_from_pypi( RequirementsMap: typing.TypeAlias = typing.Mapping[str, typing.Iterable[Requirement]] -CandidatesMap: typing.TypeAlias = typing.Mapping[str, typing.Iterable[Candidate]] +Candidates: typing.TypeAlias = typing.Iterable[Candidate] +CandidatesMap: typing.TypeAlias = typing.Mapping[str, Candidates] +# {identifier: [cls, cachekey]: list[candidates]}} +ResolverCache: typing.TypeAlias = dict[ + str, dict[tuple[type[ExtrasProvider], str], list[Candidate]] +] VersionSource: typing.TypeAlias = typing.Callable[ - [str, RequirementsMap, CandidatesMap], + [str], typing.Iterable[tuple[str, str | Version]], ] class BaseProvider(ExtrasProvider): + resolver_cache: typing.ClassVar[ResolverCache] = {} + def __init__( self, - include_sdists: bool = True, - include_wheels: bool = True, - sdist_server_url: str = "https://pypi.org/simple/", + *, constraints: Constraints | None = None, req_type: RequirementType | None = None, - ignore_platform: bool = False, + use_resolver_cache: bool = True, ): super().__init__() - self.include_sdists = include_sdists - self.include_wheels = include_wheels - self.sdist_server_url = sdist_server_url self.constraints = constraints or Constraints() self.req_type = req_type - self.ignore_platform = ignore_platform + self.use_cache_candidates = use_resolver_cache + + @property + def cache_key(self) -> str: + """Return cache key for the provider + + The cache key must be unique for each provider configuration, e.g. + PyPI URL, GitHub org + repo, ... + """ + raise NotImplementedError() + + def find_candidates(self, identifier: str) -> Candidates: + """Find unfiltered candidates""" + raise NotImplementedError() def identify(self, requirement_or_candidate: Requirement | Candidate) -> str: return canonicalize_name(requirement_or_candidate.name) + @classmethod + def clear_cache(cls, identifier: str | None = None) -> None: + """Clear global resolver cache + + ``None`` clears all caches, an ``identifier`` string clears the + cache for an identifier. Raises :exc:`KeyError` for unknown + identifiers. + """ + if identifier is None: + cls.resolver_cache.clear() + else: + cls.resolver_cache.pop(canonicalize_name(identifier)) + def get_extras_for( self, requirement_or_candidate: Requirement | Candidate, @@ -391,31 +418,6 @@ def validate_candidate( return True return False - def get_cache(self) -> dict[str, list[Candidate]]: - raise NotImplementedError() - - def get_from_cache( - self, - identifier: str, - requirements: RequirementsMap, - incompatibilities: CandidatesMap, - ) -> list[Candidate]: - cache = self.get_cache() - # we only want caching for build reqs because for install time reqs we always want to get the latest version - # we can't guarantee that the latest version is available in the cache so install time reqs cannot use the cache - if self.req_type is None or not self.req_type.is_build_requirement: - return [] - return [ - c - for c in cache[identifier] - if self.validate_candidate(identifier, requirements, incompatibilities, c) - ] - - def add_to_cache(self, identifier: str, candidates: list[Candidate]) -> None: - # we can add candidates to cache even for install type reqs because build time reqs are - # allowed to use candidates seen when we were resolving the same req as an install type - self.get_cache()[identifier].extend(candidates) - def get_preference( self, identifier: str, @@ -459,20 +461,67 @@ def get_dependencies(self, candidate: Candidate) -> list[Requirement]: # return candidate.dependencies return [] + def _get_cached_candidates(self, identifier: str) -> list[Candidate]: + """Get list of cached candidates for identifier and provider + + The method always returns a list. If the cache did not have an entry + before, a new empty list is stored in the cache and returned to the + caller. The caller can mutate the list in place to update the cache. + """ + cls = type(self) + provider_cache = cls.resolver_cache.setdefault(identifier, {}) + candidate_cache = provider_cache.setdefault((cls, self.cache_key), []) + return candidate_cache + + def _find_cached_candidates(self, identifier: str) -> Candidates: + """Find candidates with caching""" + if self.use_cache_candidates: + cached_candidates = self._get_cached_candidates(identifier) + if cached_candidates: + logger.debug( + "%s: use %i cached candidates", + identifier, + len(cached_candidates), + ) + return cached_candidates + candidates = list(self.find_candidates(identifier)) + if self.use_cache_candidates: + # mutate list object in-place + cached_candidates[:] = candidates + logger.debug( + "%s: cache %i unfiltered candidates", + identifier, + len(candidates), + ) + else: + logger.debug( + "%s: got %i unfiltered candidates, ignoring cache", + identifier, + len(candidates), + ) + return candidates + def find_matches( self, identifier: str, requirements: RequirementsMap, incompatibilities: CandidatesMap, - ) -> typing.Iterable[Candidate]: - raise NotImplementedError() + ) -> Candidates: + """Find matching candidates, sorted by version and build tag""" + unfiltered_candidates = self._find_cached_candidates(identifier) + candidates = [ + candidate + for candidate in unfiltered_candidates + if self.validate_candidate( + identifier, requirements, incompatibilities, candidate + ) + ] + return sorted(candidates, key=attrgetter("version", "build_tag"), reverse=True) class PyPIProvider(BaseProvider): """Lookup package and versions from a simple Python index (PyPI)""" - pypi_resolver_cache: typing.ClassVar[dict[str, list[Candidate]]] = defaultdict(list) - def __init__( self, include_sdists: bool = True, @@ -481,18 +530,34 @@ def __init__( constraints: Constraints | None = None, req_type: RequirementType | None = None, ignore_platform: bool = False, + *, + use_resolver_cache: bool = True, ): super().__init__( - include_sdists=include_sdists, - include_wheels=include_wheels, - sdist_server_url=sdist_server_url, constraints=constraints, req_type=req_type, - ignore_platform=ignore_platform, + use_resolver_cache=use_resolver_cache, ) + self.include_sdists = include_sdists + self.include_wheels = include_wheels + self.sdist_server_url = sdist_server_url + self.ignore_platform = ignore_platform - def get_cache(self) -> dict[str, list[Candidate]]: - return PyPIProvider.pypi_resolver_cache + @property + def cache_key(self) -> str: + # ignore platform parameter changes behavior of find_candidates() + if self.ignore_platform: + return f"{self.sdist_server_url}+ignore_platform" + else: + return self.sdist_server_url + + def find_candidates(self, identifier: str) -> Candidates: + return get_project_from_pypi( + identifier, + set(), + self.sdist_server_url, + self.ignore_platform, + ) def validate_candidate( self, @@ -526,23 +591,8 @@ def find_matches( identifier: str, requirements: RequirementsMap, incompatibilities: CandidatesMap, - ) -> typing.Iterable[Candidate]: - candidates = self.get_from_cache(identifier, requirements, incompatibilities) - if not candidates: - # Need to pass the extras to the search, so they - # are added to the candidate at creation - we - # treat candidates as immutable once created. - for candidate in get_project_from_pypi( - identifier, - set(), - self.sdist_server_url, - self.ignore_platform, - ): - if self.validate_candidate( - identifier, requirements, incompatibilities, candidate - ): - candidates.append(candidate) - self.add_to_cache(identifier, candidates) + ) -> Candidates: + candidates = super().find_matches(identifier, requirements, incompatibilities) if not candidates: # Try to construct a meaningful error message that points out the # type(s) of files the resolver has been told it can choose as a @@ -579,18 +629,21 @@ def __call__(self, identifier: str, item: str) -> Version | None: class GenericProvider(BaseProvider): """Lookup package and version by using a callback""" - generic_resolver_cache: typing.ClassVar[dict[str, list[Candidate]]] = defaultdict( - list - ) - def __init__( self, version_source: VersionSource, constraints: Constraints | None = None, req_type: RequirementType | None = None, matcher: MatchFunction | re.Pattern | None = None, + *, + # generic provider does not implement caching + use_resolver_cache: bool = False, ): - super().__init__(constraints=constraints, req_type=req_type) + super().__init__( + constraints=constraints, + req_type=req_type, + use_resolver_cache=use_resolver_cache, + ) self._version_source = version_source if matcher is None: self._match_function = self._default_match_function @@ -624,42 +677,25 @@ def _re_match_function( logger.debug(f"{identifier}: could not parse version from {value}: {err}") return None - def get_cache(self) -> dict[str, list[Candidate]]: - return GenericProvider.generic_resolver_cache + @property + def cache_key(self) -> str: + raise NotImplementedError("GenericProvider does not implement caching") - def find_matches( - self, - identifier: str, - requirements: RequirementsMap, - incompatibilities: CandidatesMap, - ) -> typing.Iterable[Candidate]: - candidates = self.get_from_cache(identifier, requirements, incompatibilities) + def find_candidates(self, identifier) -> Candidates: + candidates: list[Candidate] = [] version: Version | None - - if not candidates: - # Need to pass the extras to the search, so they - # are added to the candidate at creation - we - # treat candidates as immutable once created. - for url, item in self._version_source( - identifier, requirements, incompatibilities - ): - if isinstance(item, Version): - version = item - else: - version = self._match_function(identifier, item) - if version is None: - logger.debug(f"{identifier}: match function ignores {item}") - continue - assert isinstance(version, Version) - version = version - candidate = Candidate(identifier, version, url=url) - if self.validate_candidate( - identifier, requirements, incompatibilities, candidate - ): - candidates.append(candidate) - self.add_to_cache(identifier, candidates) - - return sorted(candidates, key=attrgetter("version"), reverse=True) + for url, item in self._version_source(identifier): + if isinstance(item, Version): + version = item + else: + version = self._match_function(identifier, item) + if version is None: + logger.debug(f"{identifier}: match function ignores {item}") + continue + assert isinstance(version, Version) + version = version + candidates.append(Candidate(identifier, version, url=url)) + return candidates class GitHubTagProvider(GenericProvider): @@ -670,9 +706,6 @@ class GitHubTagProvider(GenericProvider): host = "github.com:443" api_url = "https://api.{self.host}/repos/{self.organization}/{self.repo}/tags" - github_resolver_cache: typing.ClassVar[dict[str, list[Candidate]]] = defaultdict( - list - ) def __init__( self, @@ -680,23 +713,27 @@ def __init__( repo: str, constraints: Constraints | None = None, matcher: MatchFunction | re.Pattern | None = None, + *, + req_type: RequirementType | None = None, + use_resolver_cache: bool = True, ): super().__init__( - version_source=self._find_tags, constraints=constraints, + req_type=req_type, + use_resolver_cache=use_resolver_cache, + version_source=self._find_tags, matcher=matcher, ) self.organization = organization self.repo = repo - def get_cache(self) -> dict[str, list[Candidate]]: - return GitHubTagProvider.github_resolver_cache + @property + def cache_key(self) -> str: + return f"{self.organization}/{self.repo}" def _find_tags( self, identifier: str, - requirements: RequirementsMap, - incompatibilities: CandidatesMap, ) -> Iterable[tuple[str, Version]]: headers = {"accept": "application/vnd.github+json"} @@ -735,20 +772,21 @@ def _find_tags( class GitLabTagProvider(GenericProvider): """Lookup tarball and version from GitLab git tags""" - gitlab_resolver_cache: typing.ClassVar[dict[str, list[Candidate]]] = defaultdict( - list - ) - def __init__( self, project_path: str, server_url: str = "https://gitlab.com", constraints: Constraints | None = None, matcher: MatchFunction | re.Pattern | None = None, + *, + req_type: RequirementType | None = None, + use_resolver_cache: bool = True, ) -> None: super().__init__( - version_source=self._find_tags, constraints=constraints, + req_type=req_type, + use_resolver_cache=use_resolver_cache, + version_source=self._find_tags, matcher=matcher, ) self.server_url = server_url.rstrip("/") @@ -763,14 +801,13 @@ def __init__( f"{self.server_url}/api/v4/projects/{encoded_path}/repository/tags" ) - def get_cache(self) -> dict[str, list[Candidate]]: - return GitLabTagProvider.gitlab_resolver_cache + @property + def cache_key(self) -> str: + return f"{self.server_url}/{self.project_path}" def _find_tags( self, identifier: str, - requirements: RequirementsMap, - incompatibilities: CandidatesMap, ) -> Iterable[tuple[str, Version]]: nexturl: str = self.api_url while nexturl: diff --git a/tests/test_resolver.py b/tests/test_resolver.py index a966a4ee..7634f1ec 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -1,5 +1,5 @@ -import collections import re +import typing import pytest import requests_mock @@ -8,7 +8,6 @@ from packaging.version import Version from fromager import constraints, resolver -from fromager.requirements_file import RequirementType _hydra_core_simple_response = """ @@ -40,12 +39,11 @@ @pytest.fixture(autouse=True) def reset_cache(): - resolver.PyPIProvider.pypi_resolver_cache = collections.defaultdict(list) - resolver.GenericProvider.generic_resolver_cache = collections.defaultdict(list) - resolver.GitHubTagProvider.github_resolver_cache = collections.defaultdict(list) + resolver.BaseProvider.clear_cache() -def test_provider_choose_wheel(): +@pytest.fixture +def pypi_hydra_resolver() -> typing.Generator[resolvelib.AbstractResolver, None, None]: with requests_mock.Mocker() as r: r.get( "https://pypi.org/simple/hydra-core/", @@ -53,86 +51,104 @@ def test_provider_choose_wheel(): ) provider = resolver.PyPIProvider(include_sdists=False) - reporter = resolvelib.BaseReporter() - rslvr = resolvelib.Resolver(provider, reporter) + reporter: resolvelib.BaseReporter = resolvelib.BaseReporter() + yield resolvelib.Resolver(provider, reporter) - result = rslvr.resolve([Requirement("hydra-core")]) - assert "hydra-core" in result.mapping - candidate = result.mapping["hydra-core"] - assert ( - candidate.url - == "https://files.pythonhosted.org/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-2-py3-none-any.whl#sha256=fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b" +@pytest.fixture +def gitlab_decile_resolver() -> typing.Generator[ + resolvelib.AbstractResolver, None, None +]: + with requests_mock.Mocker() as r: + r.get( + "https://gitlab.com/api/v4/projects/mirrors%2Fgithub%2Fdecile-team%2Fsubmodlib/repository/tags", + text=_gitlab_submodlib_repo_response, ) - assert str(candidate.version) == "1.3.2" + provider = resolver.GitLabTagProvider( + project_path="mirrors/github/decile-team/submodlib", + server_url="https://gitlab.com", + matcher=re.compile("v(.*)"), # with match object + ) + reporter: resolvelib.BaseReporter = resolvelib.BaseReporter() + yield resolvelib.Resolver(provider, reporter) -def test_provider_cache(): + +@pytest.fixture +def github_fromager_resolver() -> typing.Generator[ + resolvelib.AbstractResolver, None, None +]: + with requests_mock.Mocker() as r: + r.get( + "https://api.github.com:443/repos/python-wheel-build/fromager", + text=_github_fromager_repo_response, + ) + r.get( + "https://api.github.com:443/repos/python-wheel-build/fromager/tags", + text=_github_fromager_tag_response, + ) + + provider = resolver.GitHubTagProvider( + organization="python-wheel-build", repo="fromager" + ) + reporter: resolvelib.BaseReporter = resolvelib.BaseReporter() + yield resolvelib.Resolver(provider, reporter) + + +def test_provider_choose_wheel(): with requests_mock.Mocker() as r: r.get( "https://pypi.org/simple/hydra-core/", text=_hydra_core_simple_response, ) - # fill the cache provider = resolver.PyPIProvider(include_sdists=False) reporter = resolvelib.BaseReporter() rslvr = resolvelib.Resolver(provider, reporter) - result = rslvr.resolve([Requirement("hydra-core<1.3")]) - candidate = result.mapping["hydra-core"] - assert str(candidate.version) == "1.2.2" - assert "hydra-core" in resolver.PyPIProvider.pypi_resolver_cache - assert len(resolver.PyPIProvider.pypi_resolver_cache["hydra-core"]) == 1 - # store a copy of the cache - cache_copy = { - "hydra-core": resolver.PyPIProvider.pypi_resolver_cache["hydra-core"][:] - } + result = rslvr.resolve([Requirement("hydra-core")]) + assert "hydra-core" in result.mapping - # resolve for build requirement should end up with the already seen older version - provider = resolver.PyPIProvider( - include_sdists=False, req_type=RequirementType.BUILD_SDIST - ) - reporter = resolvelib.BaseReporter() - rslvr = resolvelib.Resolver(provider, reporter) - result = rslvr.resolve([Requirement("hydra-core>=1.2")]) candidate = result.mapping["hydra-core"] - assert str(candidate.version) == "1.2.2" - assert "hydra-core" in resolver.PyPIProvider.pypi_resolver_cache - assert len(resolver.PyPIProvider.pypi_resolver_cache["hydra-core"]) == 1 - - # resolve for install requirement should ignore the already seen older version - provider = resolver.PyPIProvider( - include_sdists=False, req_type=RequirementType.INSTALL + assert ( + candidate.url + == "https://files.pythonhosted.org/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-2-py3-none-any.whl#sha256=fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b" ) - reporter = resolvelib.BaseReporter() - rslvr = resolvelib.Resolver(provider, reporter) - result = rslvr.resolve([Requirement("hydra-core>=1.2")]) - candidate = result.mapping["hydra-core"] assert str(candidate.version) == "1.3.2" - # have to restore the cache so that 1.3.2 doesn't get picked up from there - resolver.PyPIProvider.pypi_resolver_cache = cache_copy - # double check that the restoration worked - provider = resolver.PyPIProvider( - include_sdists=False, req_type=RequirementType.BUILD_SDIST - ) - reporter = resolvelib.BaseReporter() - rslvr = resolvelib.Resolver(provider, reporter) - result = rslvr.resolve([Requirement("hydra-core>=1.2")]) - candidate = result.mapping["hydra-core"] - assert str(candidate.version) == "1.2.2" +def test_provider_cache_key_pypi(pypi_hydra_resolver) -> None: + req = Requirement("hydra-core<1.3") - # if resolving for build but with different conditions, don't use cache - provider = resolver.PyPIProvider( - include_wheels=False, req_type=RequirementType.BUILD_SDIST - ) - reporter = resolvelib.BaseReporter() - rslvr = resolvelib.Resolver(provider, reporter) - result = rslvr.resolve([Requirement("hydra-core>=1.2")]) - candidate = result.mapping["hydra-core"] - assert str(candidate.version) == "1.3.2" + # fill the cache + provider = pypi_hydra_resolver.provider + assert provider.cache_key == "https://pypi.org/simple/" + req_cache = provider._get_cached_candidates(req.name) + assert req_cache == [] + + result = pypi_hydra_resolver.resolve([req]) + candidate = result.mapping[req.name] + assert str(candidate.version) == "1.2.2" + + resolver_cache = resolver.BaseProvider.resolver_cache + assert req.name in resolver_cache + assert (resolver.PyPIProvider, provider.cache_key) in resolver_cache[req.name] + # mutated in place + assert provider._get_cached_candidates(req.name) is req_cache + assert len(provider._get_cached_candidates(req.name)) == 7 + assert len(req_cache) == 7 + + +def test_provider_cache_key_gitlab(gitlab_decile_resolver) -> None: + provider = gitlab_decile_resolver.provider + assert ( + provider.cache_key == "https://gitlab.com/mirrors/github/decile-team/submodlib" + ) + + +def test_provider_cache_key_github(github_fromager_resolver) -> None: + provider = github_fromager_resolver.provider + assert provider.cache_key == "python-wheel-build/fromager" def test_provider_choose_wheel_prereleases(): @@ -593,7 +609,9 @@ def test_resolve_github(): text=_github_fromager_tag_response, ) - provider = resolver.GitHubTagProvider("python-wheel-build", "fromager") + provider = resolver.GitHubTagProvider( + organization="python-wheel-build", repo="fromager" + ) reporter = resolvelib.BaseReporter() rslvr = resolvelib.Resolver(provider, reporter) @@ -623,7 +641,7 @@ def test_github_constraint_mismatch(): ) provider = resolver.GitHubTagProvider( - "python-wheel-build", "fromager", constraints=constraint + organization="python-wheel-build", repo="fromager", constraints=constraint ) reporter = resolvelib.BaseReporter() rslvr = resolvelib.Resolver(provider, reporter) @@ -646,7 +664,7 @@ def test_github_constraint_match(): ) provider = resolver.GitHubTagProvider( - "python-wheel-build", "fromager", constraints=constraint + organization="python-wheel-build", repo="fromager", constraints=constraint ) reporter = resolvelib.BaseReporter() rslvr = resolvelib.Resolver(provider, reporter) @@ -667,7 +685,7 @@ def test_resolve_generic(): def _versions(*args, **kwds): return [("url", "1.2"), ("url", "1.3"), ("url", "1.4.1")] - provider = resolver.GenericProvider(_versions, None) + provider = resolver.GenericProvider(version_source=_versions) reporter = resolvelib.BaseReporter() rslvr = resolvelib.Resolver(provider, reporter) @@ -677,6 +695,12 @@ def _versions(*args, **kwds): candidate = result.mapping["fromager"] assert str(candidate.version) == "1.4.1" + # generic provider does not use resolver cache + assert not resolver.BaseProvider.resolver_cache + + with pytest.raises(NotImplementedError): + assert provider.cache_key + _gitlab_submodlib_repo_response = """ [