From 9e71b5cb8e839af7716b4f3db0a1263f96d8e81b Mon Sep 17 00:00:00 2001
From: Christian Heimes
Date: Fri, 26 Sep 2025 08:41:49 +0200
Subject: [PATCH] fix(resolver): redesign resolver cache

Redesign the resolver cache to address multiple issues.

1. The cache now holds all candidates from the provider. Before, it only
   stored candidates that also fulfilled the requirement and constraints.

2. The cache now takes the sdist URL, GitHub repo, and GitLab project into
   account. Before, the caches for a local PyPI and for pypi.org were mixed.

3. GenericProvider no longer caches candidates. There is no way to construct
   a good cache key, and Bootstrapper uses GenericProvider in a way that does
   not benefit from caching either.

4. There is now just one global cache object for all providers. The new
   design makes it easier to clear all caches or just the cache for a single
   identifier.

5. The custom logic for each provider class is now in the `find_candidates`
   method, which just has to return an iterable of candidates. The caching
   logic is handled by the rest of the code.

6. Consumers can now opt out of caching with the `use_resolver_cache=False`
   argument.

7. All resolver classes now require keyword arguments. The base provider no
   longer takes PyPI-only arguments like `include_sdists`.

Fixes: #766
Signed-off-by: Christian Heimes
---
 src/fromager/bootstrapper.py |   4 +-
 src/fromager/resolver.py     | 277 ++++++++++++++++++++---------------
 tests/test_resolver.py       | 160 +++++++++++---------
 3 files changed, 252 insertions(+), 189 deletions(-)

diff --git a/src/fromager/bootstrapper.py b/src/fromager/bootstrapper.py
index ae0c17da..7635272e 100644
--- a/src/fromager/bootstrapper.py
+++ b/src/fromager/bootstrapper.py
@@ -913,9 +913,11 @@ def _resolve_from_version_source(
             return None
         try:
             # no need to pass req type to enable caching since we are already using the graph as our cache
+            # do not cache candidates
            provider = resolver.GenericProvider(
-                version_source=lambda x, y, z: version_source,
+                version_source=lambda identifier: version_source,
                 constraints=self.ctx.constraints,
+                use_resolver_cache=False,
             )
             return resolver.resolve_from_provider(provider, req)
         except Exception as err:
diff --git a/src/fromager/resolver.py b/src/fromager/resolver.py
index 187d7f74..69bbcaa4 100644
--- a/src/fromager/resolver.py
+++ b/src/fromager/resolver.py
@@ -10,7 +10,6 @@
 import os
 import re
 import typing
-from collections import defaultdict
 from collections.abc import Iterable
 from operator import attrgetter
 from platform import python_version
@@ -180,7 +179,7 @@ def get_project_from_pypi(
     extras: typing.Iterable[str],
     sdist_server_url: str,
     ignore_platform: bool = False,
-) -> typing.Iterable[Candidate]:
+) -> Candidates:
     """Return candidates created from the project name and extras."""
     found_candidates: set[str] = set()
     ignored_candidates: set[str] = set()
@@ -330,34 +329,62 @@
 
 
 RequirementsMap: typing.TypeAlias = typing.Mapping[str, typing.Iterable[Requirement]]
-CandidatesMap: typing.TypeAlias = typing.Mapping[str, typing.Iterable[Candidate]]
+Candidates: typing.TypeAlias = typing.Iterable[Candidate]
+CandidatesMap: typing.TypeAlias = typing.Mapping[str, Candidates]
+# {identifier: {(cls, cache_key): [candidate, ...]}}
+ResolverCache: typing.TypeAlias = dict[
+    str, dict[tuple[type[ExtrasProvider], str], list[Candidate]]
+]
 VersionSource: typing.TypeAlias = typing.Callable[
-    [str, RequirementsMap, CandidatesMap],
+    [str],
     typing.Iterable[tuple[str, str | Version]],
 ]
 
 
 class BaseProvider(ExtrasProvider):
+    resolver_cache: typing.ClassVar[ResolverCache] = {}
+
     def __init__(
         self,
-        include_sdists: bool = True,
-        include_wheels: bool = True,
-        sdist_server_url: str = "https://pypi.org/simple/",
+        *,
         constraints: Constraints | None = None,
         req_type: RequirementType | None = None,
-        ignore_platform: bool = False,
+        use_resolver_cache: bool = True,
     ):
         super().__init__()
-        self.include_sdists = include_sdists
-        self.include_wheels = include_wheels
-        self.sdist_server_url = sdist_server_url
         self.constraints = constraints or Constraints()
         self.req_type = req_type
-        self.ignore_platform = ignore_platform
+        self.use_cache_candidates = use_resolver_cache
+
+    @property
+    def cache_key(self) -> str:
+        """Return cache key for the provider
+
+        The cache key must be unique for each provider configuration, e.g.
+        PyPI URL, GitHub org + repo, ...
+        """
+        raise NotImplementedError()
+
+    def find_candidates(self, identifier: str) -> Candidates:
+        """Find unfiltered candidates"""
+        raise NotImplementedError()
 
     def identify(self, requirement_or_candidate: Requirement | Candidate) -> str:
         return canonicalize_name(requirement_or_candidate.name)
 
+    @classmethod
+    def clear_cache(cls, identifier: str | None = None) -> None:
+        """Clear global resolver cache
+
+        ``None`` clears all caches, an ``identifier`` string clears the
+        cache for an identifier. Raises :exc:`KeyError` for unknown
+        identifiers.
+        """
+        if identifier is None:
+            cls.resolver_cache.clear()
+        else:
+            cls.resolver_cache.pop(canonicalize_name(identifier))
+
     def get_extras_for(
         self,
         requirement_or_candidate: Requirement | Candidate,
@@ -391,31 +418,6 @@ def validate_candidate(
             return True
         return False
 
-    def get_cache(self) -> dict[str, list[Candidate]]:
-        raise NotImplementedError()
-
-    def get_from_cache(
-        self,
-        identifier: str,
-        requirements: RequirementsMap,
-        incompatibilities: CandidatesMap,
-    ) -> list[Candidate]:
-        cache = self.get_cache()
-        # we only want caching for build reqs because for install time reqs we always want to get the latest version
-        # we can't guarantee that the latest version is available in the cache so install time reqs cannot use the cache
-        if self.req_type is None or not self.req_type.is_build_requirement:
-            return []
-        return [
-            c
-            for c in cache[identifier]
-            if self.validate_candidate(identifier, requirements, incompatibilities, c)
-        ]
-
-    def add_to_cache(self, identifier: str, candidates: list[Candidate]) -> None:
-        # we can add candidates to cache even for install type reqs because build time reqs are
-        # allowed to use candidates seen when we were resolving the same req as an install type
-        self.get_cache()[identifier].extend(candidates)
-
     def get_preference(
         self,
         identifier: str,
@@ -459,20 +461,67 @@ def get_dependencies(self, candidate: Candidate) -> list[Requirement]:
         # return candidate.dependencies
         return []
 
+    def _get_cached_candidates(self, identifier: str) -> list[Candidate]:
+        """Get list of cached candidates for identifier and provider
+
+        The method always returns a list. If the cache did not have an entry
+        before, a new empty list is stored in the cache and returned to the
+        caller. The caller can mutate the list in place to update the cache.
+ """ + cls = type(self) + provider_cache = cls.resolver_cache.setdefault(identifier, {}) + candidate_cache = provider_cache.setdefault((cls, self.cache_key), []) + return candidate_cache + + def _find_cached_candidates(self, identifier: str) -> Candidates: + """Find candidates with caching""" + if self.use_cache_candidates: + cached_candidates = self._get_cached_candidates(identifier) + if cached_candidates: + logger.debug( + "%s: use %i cached candidates", + identifier, + len(cached_candidates), + ) + return cached_candidates + candidates = list(self.find_candidates(identifier)) + if self.use_cache_candidates: + # mutate list object in-place + cached_candidates[:] = candidates + logger.debug( + "%s: cache %i unfiltered candidates", + identifier, + len(candidates), + ) + else: + logger.debug( + "%s: got %i unfiltered candidates, ignoring cache", + identifier, + len(candidates), + ) + return candidates + def find_matches( self, identifier: str, requirements: RequirementsMap, incompatibilities: CandidatesMap, - ) -> typing.Iterable[Candidate]: - raise NotImplementedError() + ) -> Candidates: + """Find matching candidates, sorted by version and build tag""" + unfiltered_candidates = self._find_cached_candidates(identifier) + candidates = [ + candidate + for candidate in unfiltered_candidates + if self.validate_candidate( + identifier, requirements, incompatibilities, candidate + ) + ] + return sorted(candidates, key=attrgetter("version", "build_tag"), reverse=True) class PyPIProvider(BaseProvider): """Lookup package and versions from a simple Python index (PyPI)""" - pypi_resolver_cache: typing.ClassVar[dict[str, list[Candidate]]] = defaultdict(list) - def __init__( self, include_sdists: bool = True, @@ -481,18 +530,34 @@ def __init__( constraints: Constraints | None = None, req_type: RequirementType | None = None, ignore_platform: bool = False, + *, + use_resolver_cache: bool = True, ): super().__init__( - include_sdists=include_sdists, - include_wheels=include_wheels, - sdist_server_url=sdist_server_url, constraints=constraints, req_type=req_type, - ignore_platform=ignore_platform, + use_resolver_cache=use_resolver_cache, ) + self.include_sdists = include_sdists + self.include_wheels = include_wheels + self.sdist_server_url = sdist_server_url + self.ignore_platform = ignore_platform - def get_cache(self) -> dict[str, list[Candidate]]: - return PyPIProvider.pypi_resolver_cache + @property + def cache_key(self) -> str: + # ignore platform parameter changes behavior of find_candidates() + if self.ignore_platform: + return f"{self.sdist_server_url}+ignore_platform" + else: + return self.sdist_server_url + + def find_candidates(self, identifier: str) -> Candidates: + return get_project_from_pypi( + identifier, + set(), + self.sdist_server_url, + self.ignore_platform, + ) def validate_candidate( self, @@ -526,23 +591,8 @@ def find_matches( identifier: str, requirements: RequirementsMap, incompatibilities: CandidatesMap, - ) -> typing.Iterable[Candidate]: - candidates = self.get_from_cache(identifier, requirements, incompatibilities) - if not candidates: - # Need to pass the extras to the search, so they - # are added to the candidate at creation - we - # treat candidates as immutable once created. 
- for candidate in get_project_from_pypi( - identifier, - set(), - self.sdist_server_url, - self.ignore_platform, - ): - if self.validate_candidate( - identifier, requirements, incompatibilities, candidate - ): - candidates.append(candidate) - self.add_to_cache(identifier, candidates) + ) -> Candidates: + candidates = super().find_matches(identifier, requirements, incompatibilities) if not candidates: # Try to construct a meaningful error message that points out the # type(s) of files the resolver has been told it can choose as a @@ -579,18 +629,21 @@ def __call__(self, identifier: str, item: str) -> Version | None: class GenericProvider(BaseProvider): """Lookup package and version by using a callback""" - generic_resolver_cache: typing.ClassVar[dict[str, list[Candidate]]] = defaultdict( - list - ) - def __init__( self, version_source: VersionSource, constraints: Constraints | None = None, req_type: RequirementType | None = None, matcher: MatchFunction | re.Pattern | None = None, + *, + # generic provider does not implement caching + use_resolver_cache: bool = False, ): - super().__init__(constraints=constraints, req_type=req_type) + super().__init__( + constraints=constraints, + req_type=req_type, + use_resolver_cache=use_resolver_cache, + ) self._version_source = version_source if matcher is None: self._match_function = self._default_match_function @@ -624,42 +677,25 @@ def _re_match_function( logger.debug(f"{identifier}: could not parse version from {value}: {err}") return None - def get_cache(self) -> dict[str, list[Candidate]]: - return GenericProvider.generic_resolver_cache + @property + def cache_key(self) -> str: + raise NotImplementedError("GenericProvider does not implement caching") - def find_matches( - self, - identifier: str, - requirements: RequirementsMap, - incompatibilities: CandidatesMap, - ) -> typing.Iterable[Candidate]: - candidates = self.get_from_cache(identifier, requirements, incompatibilities) + def find_candidates(self, identifier) -> Candidates: + candidates: list[Candidate] = [] version: Version | None - - if not candidates: - # Need to pass the extras to the search, so they - # are added to the candidate at creation - we - # treat candidates as immutable once created. 
- for url, item in self._version_source( - identifier, requirements, incompatibilities - ): - if isinstance(item, Version): - version = item - else: - version = self._match_function(identifier, item) - if version is None: - logger.debug(f"{identifier}: match function ignores {item}") - continue - assert isinstance(version, Version) - version = version - candidate = Candidate(identifier, version, url=url) - if self.validate_candidate( - identifier, requirements, incompatibilities, candidate - ): - candidates.append(candidate) - self.add_to_cache(identifier, candidates) - - return sorted(candidates, key=attrgetter("version"), reverse=True) + for url, item in self._version_source(identifier): + if isinstance(item, Version): + version = item + else: + version = self._match_function(identifier, item) + if version is None: + logger.debug(f"{identifier}: match function ignores {item}") + continue + assert isinstance(version, Version) + version = version + candidates.append(Candidate(identifier, version, url=url)) + return candidates class GitHubTagProvider(GenericProvider): @@ -670,9 +706,6 @@ class GitHubTagProvider(GenericProvider): host = "github.com:443" api_url = "https://api.{self.host}/repos/{self.organization}/{self.repo}/tags" - github_resolver_cache: typing.ClassVar[dict[str, list[Candidate]]] = defaultdict( - list - ) def __init__( self, @@ -680,23 +713,27 @@ def __init__( repo: str, constraints: Constraints | None = None, matcher: MatchFunction | re.Pattern | None = None, + *, + req_type: RequirementType | None = None, + use_resolver_cache: bool = True, ): super().__init__( - version_source=self._find_tags, constraints=constraints, + req_type=req_type, + use_resolver_cache=use_resolver_cache, + version_source=self._find_tags, matcher=matcher, ) self.organization = organization self.repo = repo - def get_cache(self) -> dict[str, list[Candidate]]: - return GitHubTagProvider.github_resolver_cache + @property + def cache_key(self) -> str: + return f"{self.organization}/{self.repo}" def _find_tags( self, identifier: str, - requirements: RequirementsMap, - incompatibilities: CandidatesMap, ) -> Iterable[tuple[str, Version]]: headers = {"accept": "application/vnd.github+json"} @@ -735,20 +772,21 @@ def _find_tags( class GitLabTagProvider(GenericProvider): """Lookup tarball and version from GitLab git tags""" - gitlab_resolver_cache: typing.ClassVar[dict[str, list[Candidate]]] = defaultdict( - list - ) - def __init__( self, project_path: str, server_url: str = "https://gitlab.com", constraints: Constraints | None = None, matcher: MatchFunction | re.Pattern | None = None, + *, + req_type: RequirementType | None = None, + use_resolver_cache: bool = True, ) -> None: super().__init__( - version_source=self._find_tags, constraints=constraints, + req_type=req_type, + use_resolver_cache=use_resolver_cache, + version_source=self._find_tags, matcher=matcher, ) self.server_url = server_url.rstrip("/") @@ -763,14 +801,13 @@ def __init__( f"{self.server_url}/api/v4/projects/{encoded_path}/repository/tags" ) - def get_cache(self) -> dict[str, list[Candidate]]: - return GitLabTagProvider.gitlab_resolver_cache + @property + def cache_key(self) -> str: + return f"{self.server_url}/{self.project_path}" def _find_tags( self, identifier: str, - requirements: RequirementsMap, - incompatibilities: CandidatesMap, ) -> Iterable[tuple[str, Version]]: nexturl: str = self.api_url while nexturl: diff --git a/tests/test_resolver.py b/tests/test_resolver.py index a966a4ee..7634f1ec 100644 --- a/tests/test_resolver.py +++ 
b/tests/test_resolver.py @@ -1,5 +1,5 @@ -import collections import re +import typing import pytest import requests_mock @@ -8,7 +8,6 @@ from packaging.version import Version from fromager import constraints, resolver -from fromager.requirements_file import RequirementType _hydra_core_simple_response = """ @@ -40,12 +39,11 @@ @pytest.fixture(autouse=True) def reset_cache(): - resolver.PyPIProvider.pypi_resolver_cache = collections.defaultdict(list) - resolver.GenericProvider.generic_resolver_cache = collections.defaultdict(list) - resolver.GitHubTagProvider.github_resolver_cache = collections.defaultdict(list) + resolver.BaseProvider.clear_cache() -def test_provider_choose_wheel(): +@pytest.fixture +def pypi_hydra_resolver() -> typing.Generator[resolvelib.AbstractResolver, None, None]: with requests_mock.Mocker() as r: r.get( "https://pypi.org/simple/hydra-core/", @@ -53,86 +51,104 @@ def test_provider_choose_wheel(): ) provider = resolver.PyPIProvider(include_sdists=False) - reporter = resolvelib.BaseReporter() - rslvr = resolvelib.Resolver(provider, reporter) + reporter: resolvelib.BaseReporter = resolvelib.BaseReporter() + yield resolvelib.Resolver(provider, reporter) - result = rslvr.resolve([Requirement("hydra-core")]) - assert "hydra-core" in result.mapping - candidate = result.mapping["hydra-core"] - assert ( - candidate.url - == "https://files.pythonhosted.org/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-2-py3-none-any.whl#sha256=fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b" +@pytest.fixture +def gitlab_decile_resolver() -> typing.Generator[ + resolvelib.AbstractResolver, None, None +]: + with requests_mock.Mocker() as r: + r.get( + "https://gitlab.com/api/v4/projects/mirrors%2Fgithub%2Fdecile-team%2Fsubmodlib/repository/tags", + text=_gitlab_submodlib_repo_response, ) - assert str(candidate.version) == "1.3.2" + provider = resolver.GitLabTagProvider( + project_path="mirrors/github/decile-team/submodlib", + server_url="https://gitlab.com", + matcher=re.compile("v(.*)"), # with match object + ) + reporter: resolvelib.BaseReporter = resolvelib.BaseReporter() + yield resolvelib.Resolver(provider, reporter) -def test_provider_cache(): + +@pytest.fixture +def github_fromager_resolver() -> typing.Generator[ + resolvelib.AbstractResolver, None, None +]: + with requests_mock.Mocker() as r: + r.get( + "https://api.github.com:443/repos/python-wheel-build/fromager", + text=_github_fromager_repo_response, + ) + r.get( + "https://api.github.com:443/repos/python-wheel-build/fromager/tags", + text=_github_fromager_tag_response, + ) + + provider = resolver.GitHubTagProvider( + organization="python-wheel-build", repo="fromager" + ) + reporter: resolvelib.BaseReporter = resolvelib.BaseReporter() + yield resolvelib.Resolver(provider, reporter) + + +def test_provider_choose_wheel(): with requests_mock.Mocker() as r: r.get( "https://pypi.org/simple/hydra-core/", text=_hydra_core_simple_response, ) - # fill the cache provider = resolver.PyPIProvider(include_sdists=False) reporter = resolvelib.BaseReporter() rslvr = resolvelib.Resolver(provider, reporter) - result = rslvr.resolve([Requirement("hydra-core<1.3")]) - candidate = result.mapping["hydra-core"] - assert str(candidate.version) == "1.2.2" - assert "hydra-core" in resolver.PyPIProvider.pypi_resolver_cache - assert len(resolver.PyPIProvider.pypi_resolver_cache["hydra-core"]) == 1 - # store a copy of the cache - cache_copy = { - "hydra-core": 
resolver.PyPIProvider.pypi_resolver_cache["hydra-core"][:] - } + result = rslvr.resolve([Requirement("hydra-core")]) + assert "hydra-core" in result.mapping - # resolve for build requirement should end up with the already seen older version - provider = resolver.PyPIProvider( - include_sdists=False, req_type=RequirementType.BUILD_SDIST - ) - reporter = resolvelib.BaseReporter() - rslvr = resolvelib.Resolver(provider, reporter) - result = rslvr.resolve([Requirement("hydra-core>=1.2")]) candidate = result.mapping["hydra-core"] - assert str(candidate.version) == "1.2.2" - assert "hydra-core" in resolver.PyPIProvider.pypi_resolver_cache - assert len(resolver.PyPIProvider.pypi_resolver_cache["hydra-core"]) == 1 - - # resolve for install requirement should ignore the already seen older version - provider = resolver.PyPIProvider( - include_sdists=False, req_type=RequirementType.INSTALL + assert ( + candidate.url + == "https://files.pythonhosted.org/packages/c6/50/e0edd38dcd63fb26a8547f13d28f7a008bc4a3fd4eb4ff030673f22ad41a/hydra_core-1.3.2-2-py3-none-any.whl#sha256=fa0238a9e31df3373b35b0bfb672c34cc92718d21f81311d8996a16de1141d8b" ) - reporter = resolvelib.BaseReporter() - rslvr = resolvelib.Resolver(provider, reporter) - result = rslvr.resolve([Requirement("hydra-core>=1.2")]) - candidate = result.mapping["hydra-core"] assert str(candidate.version) == "1.3.2" - # have to restore the cache so that 1.3.2 doesn't get picked up from there - resolver.PyPIProvider.pypi_resolver_cache = cache_copy - # double check that the restoration worked - provider = resolver.PyPIProvider( - include_sdists=False, req_type=RequirementType.BUILD_SDIST - ) - reporter = resolvelib.BaseReporter() - rslvr = resolvelib.Resolver(provider, reporter) - result = rslvr.resolve([Requirement("hydra-core>=1.2")]) - candidate = result.mapping["hydra-core"] - assert str(candidate.version) == "1.2.2" +def test_provider_cache_key_pypi(pypi_hydra_resolver) -> None: + req = Requirement("hydra-core<1.3") - # if resolving for build but with different conditions, don't use cache - provider = resolver.PyPIProvider( - include_wheels=False, req_type=RequirementType.BUILD_SDIST - ) - reporter = resolvelib.BaseReporter() - rslvr = resolvelib.Resolver(provider, reporter) - result = rslvr.resolve([Requirement("hydra-core>=1.2")]) - candidate = result.mapping["hydra-core"] - assert str(candidate.version) == "1.3.2" + # fill the cache + provider = pypi_hydra_resolver.provider + assert provider.cache_key == "https://pypi.org/simple/" + req_cache = provider._get_cached_candidates(req.name) + assert req_cache == [] + + result = pypi_hydra_resolver.resolve([req]) + candidate = result.mapping[req.name] + assert str(candidate.version) == "1.2.2" + + resolver_cache = resolver.BaseProvider.resolver_cache + assert req.name in resolver_cache + assert (resolver.PyPIProvider, provider.cache_key) in resolver_cache[req.name] + # mutated in place + assert provider._get_cached_candidates(req.name) is req_cache + assert len(provider._get_cached_candidates(req.name)) == 7 + assert len(req_cache) == 7 + + +def test_provider_cache_key_gitlab(gitlab_decile_resolver) -> None: + provider = gitlab_decile_resolver.provider + assert ( + provider.cache_key == "https://gitlab.com/mirrors/github/decile-team/submodlib" + ) + + +def test_provider_cache_key_github(github_fromager_resolver) -> None: + provider = github_fromager_resolver.provider + assert provider.cache_key == "python-wheel-build/fromager" def test_provider_choose_wheel_prereleases(): @@ -593,7 +609,9 @@ def 
test_resolve_github():
             text=_github_fromager_tag_response,
         )
 
-        provider = resolver.GitHubTagProvider("python-wheel-build", "fromager")
+        provider = resolver.GitHubTagProvider(
+            organization="python-wheel-build", repo="fromager"
+        )
         reporter = resolvelib.BaseReporter()
         rslvr = resolvelib.Resolver(provider, reporter)
 
@@ -623,7 +641,7 @@ def test_github_constraint_mismatch():
         )
 
         provider = resolver.GitHubTagProvider(
-            "python-wheel-build", "fromager", constraints=constraint
+            organization="python-wheel-build", repo="fromager", constraints=constraint
         )
         reporter = resolvelib.BaseReporter()
         rslvr = resolvelib.Resolver(provider, reporter)
@@ -646,7 +664,7 @@ def test_github_constraint_match():
         )
 
         provider = resolver.GitHubTagProvider(
-            "python-wheel-build", "fromager", constraints=constraint
+            organization="python-wheel-build", repo="fromager", constraints=constraint
        )
         reporter = resolvelib.BaseReporter()
         rslvr = resolvelib.Resolver(provider, reporter)
@@ -667,7 +685,7 @@ def test_resolve_generic():
     def _versions(*args, **kwds):
         return [("url", "1.2"), ("url", "1.3"), ("url", "1.4.1")]
 
-    provider = resolver.GenericProvider(_versions, None)
+    provider = resolver.GenericProvider(version_source=_versions)
     reporter = resolvelib.BaseReporter()
     rslvr = resolvelib.Resolver(provider, reporter)
 
@@ -677,6 +695,12 @@ def _versions(*args, **kwds):
     candidate = result.mapping["fromager"]
     assert str(candidate.version) == "1.4.1"
 
+    # generic provider does not use resolver cache
+    assert not resolver.BaseProvider.resolver_cache
+
+    with pytest.raises(NotImplementedError):
+        assert provider.cache_key
+
 
 _gitlab_submodlib_repo_response = """
 [