Skip to content
84 changes: 56 additions & 28 deletions src/packaging/specifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,34 +268,53 @@ def __init__(self, spec: str = "", prereleases: bool | None = None) -> None:
# Specifier version cache
self._spec_version: tuple[str, Version] | None = None

def _get_spec_version(self, version: str) -> Version:
def _get_spec_version(self, version: str) -> Version | None:
"""One element cache, as only one spec Version is needed per Specifier."""
if self._spec_version is not None and self._spec_version[0] == version:
return self._spec_version[1]

version_specifier = Version(version)
version_specifier = _coerce_version(version)
if version_specifier is None:
return None

self._spec_version = (version, version_specifier)
return version_specifier

def _require_spec_version(self, version: str) -> Version:
"""Get spec version, asserting it's valid (not for === operator).

This method should only be called for operators where version
strings are guaranteed to be valid PEP 440 versions (not ===).
"""
spec_version = self._get_spec_version(version)
assert spec_version is not None
return spec_version

@property
def prereleases(self) -> bool:
def prereleases(self) -> bool | None:
# If there is an explicit prereleases set for this, then we'll just
# blindly use that.
if self._prereleases is not None:
return self._prereleases

# Only the "!=" operator does not imply prereleases when
# the version in the specifier is a prerelease.
operator, version = self._spec
operator, item = self._spec
if operator != "!=":
# The == specifier with trailing .* cannot include prereleases
# e.g. "==1.0a1.*" is not valid.
if operator == "==" and version.endswith(".*"):
if operator == "==" and item.endswith(".*"):
return False

# "===" can have arbitrary string versions, so we cannot
# parse those, we take prereleases as False for those.
version = self._get_spec_version(item)
if version is None:
return None

# For all other operators, use the check if spec Version
# object implies pre-releases.
if self._get_spec_version(version).is_prerelease:
if version.is_prerelease:
return True

return False
Expand Down Expand Up @@ -356,9 +375,10 @@ def _canonical_spec(self) -> tuple[str, str]:
if operator == "===" or version.endswith(".*"):
return operator, version

spec_version = self._require_spec_version(version)

canonical_version = canonicalize_version(
self._get_spec_version(version),
strip_trailing_zero=(operator != "~="),
spec_version, strip_trailing_zero=(operator != "~=")
)

return operator, canonical_version
Expand Down Expand Up @@ -451,7 +471,7 @@ def _compare_equal(self, prospective: Version, spec: str) -> bool:
return shortened_prospective == split_spec
else:
# Convert our spec string into a Version
spec_version = self._get_spec_version(spec)
spec_version = self._require_spec_version(spec)

# If the specifier does not have a local segment, then we want to
# act as if the prospective version also does not have a local
Expand All @@ -468,18 +488,18 @@ def _compare_less_than_equal(self, prospective: Version, spec: str) -> bool:
# NB: Local version identifiers are NOT permitted in the version
# specifier, so local version labels can be universally removed from
# the prospective version.
return _public_version(prospective) <= self._get_spec_version(spec)
return _public_version(prospective) <= self._require_spec_version(spec)

def _compare_greater_than_equal(self, prospective: Version, spec: str) -> bool:
# NB: Local version identifiers are NOT permitted in the version
# specifier, so local version labels can be universally removed from
# the prospective version.
return _public_version(prospective) >= self._get_spec_version(spec)
return _public_version(prospective) >= self._require_spec_version(spec)

def _compare_less_than(self, prospective: Version, spec_str: str) -> bool:
# Convert our spec to a Version instance, since we'll want to work with
# it as a version.
spec = self._get_spec_version(spec_str)
spec = self._require_spec_version(spec_str)

# Check to see if the prospective version is less than the spec
# version. If it's not we can short circuit and just return False now
Expand All @@ -506,7 +526,7 @@ def _compare_less_than(self, prospective: Version, spec_str: str) -> bool:
def _compare_greater_than(self, prospective: Version, spec_str: str) -> bool:
# Convert our spec to a Version instance, since we'll want to work with
# it as a version.
spec = self._get_spec_version(spec_str)
spec = self._require_spec_version(spec_str)

# Check to see if the prospective version is greater than the spec
# version. If it's not we can short circuit and just return False now
Expand Down Expand Up @@ -537,7 +557,7 @@ def _compare_greater_than(self, prospective: Version, spec_str: str) -> bool:
# same version in the spec.
return True

def _compare_arbitrary(self, prospective: Version, spec: str) -> bool:
def _compare_arbitrary(self, prospective: Version | str, spec: str) -> bool:
return str(prospective).lower() == str(spec).lower()

def __contains__(self, item: str | Version) -> bool:
Expand Down Expand Up @@ -627,9 +647,12 @@ def filter(
for version in iterable:
parsed_version = _coerce_version(version)
if parsed_version is None:
continue

if operator_callable(parsed_version, self.version):
# === operator can match arbitrary (non-version) strings
if self.operator == "===" and self._compare_arbitrary(
version, self.version
):
yield version
elif operator_callable(parsed_version, self.version):
# If it's not a prerelease or prereleases are allowed, yield it directly
if not parsed_version.is_prerelease or include_prereleases:
found_non_prereleases = True
Expand Down Expand Up @@ -944,13 +967,12 @@ def contains(
True
"""
version = _coerce_version(item)
if version is None:
return False

if installed and version.is_prerelease:
if version is not None and installed and version.is_prerelease:
prereleases = True

return bool(list(self.filter([version], prereleases=prereleases)))
check_item = item if version is None else version
return bool(list(self.filter([check_item], prereleases=prereleases)))

def filter(
self, iterable: Iterable[UnparsedVersionVar], prereleases: bool | None = None
Expand Down Expand Up @@ -1019,22 +1041,28 @@ def filter(
return (
item
for item in iterable
if (version := _coerce_version(item)) is not None
and not version.is_prerelease
if (version := _coerce_version(item)) is None
or not version.is_prerelease
)

# Finally if prereleases is None, apply PEP 440 logic:
# exclude prereleases unless there are no final releases that matched.
filtered: list[UnparsedVersionVar] = []
filtered_items: list[UnparsedVersionVar] = []
found_prereleases: list[UnparsedVersionVar] = []
found_final_release = False

for item in iterable:
parsed_version = _coerce_version(item)
# Arbitrary strings are always included as it is not
# possible to determine if they are prereleases,
# and they have already passed all specifiers.
if parsed_version is None:
continue
if parsed_version.is_prerelease:
filtered_items.append(item)
found_prereleases.append(item)
elif parsed_version.is_prerelease:
found_prereleases.append(item)
else:
filtered.append(item)
filtered_items.append(item)
found_final_release = True

return iter(filtered if filtered else found_prereleases)
return iter(filtered_items if found_final_release else found_prereleases)
Loading
Loading