Skip to content

Commit 39ce7d4

Browse files
authored
Support arbitrary equality on arbitrary strings for Specifier and SpecifierSet's filter and contains method. (#954)
1 parent 55335c1 commit 39ce7d4

File tree

2 files changed

+250
-48
lines changed

2 files changed

+250
-48
lines changed

src/packaging/specifiers.py

Lines changed: 56 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -268,34 +268,53 @@ def __init__(self, spec: str = "", prereleases: bool | None = None) -> None:
268268
# Specifier version cache
269269
self._spec_version: tuple[str, Version] | None = None
270270

271-
def _get_spec_version(self, version: str) -> Version:
271+
def _get_spec_version(self, version: str) -> Version | None:
272272
"""One element cache, as only one spec Version is needed per Specifier."""
273273
if self._spec_version is not None and self._spec_version[0] == version:
274274
return self._spec_version[1]
275275

276-
version_specifier = Version(version)
276+
version_specifier = _coerce_version(version)
277+
if version_specifier is None:
278+
return None
279+
277280
self._spec_version = (version, version_specifier)
278281
return version_specifier
279282

283+
def _require_spec_version(self, version: str) -> Version:
284+
"""Get spec version, asserting it's valid (not for === operator).
285+
286+
This method should only be called for operators where version
287+
strings are guaranteed to be valid PEP 440 versions (not ===).
288+
"""
289+
spec_version = self._get_spec_version(version)
290+
assert spec_version is not None
291+
return spec_version
292+
280293
@property
281-
def prereleases(self) -> bool:
294+
def prereleases(self) -> bool | None:
282295
# If there is an explicit prereleases set for this, then we'll just
283296
# blindly use that.
284297
if self._prereleases is not None:
285298
return self._prereleases
286299

287300
# Only the "!=" operator does not imply prereleases when
288301
# the version in the specifier is a prerelease.
289-
operator, version = self._spec
302+
operator, version_str = self._spec
290303
if operator != "!=":
291304
# The == specifier with trailing .* cannot include prereleases
292305
# e.g. "==1.0a1.*" is not valid.
293-
if operator == "==" and version.endswith(".*"):
306+
if operator == "==" and version_str.endswith(".*"):
294307
return False
295308

309+
# "===" can have arbitrary string versions, so we cannot parse
310+
# those, we take prereleases as unknown (None) for those.
311+
version = self._get_spec_version(version_str)
312+
if version is None:
313+
return None
314+
296315
# For all other operators, use the check if spec Version
297316
# object implies pre-releases.
298-
if self._get_spec_version(version).is_prerelease:
317+
if version.is_prerelease:
299318
return True
300319

301320
return False
@@ -356,9 +375,10 @@ def _canonical_spec(self) -> tuple[str, str]:
356375
if operator == "===" or version.endswith(".*"):
357376
return operator, version
358377

378+
spec_version = self._require_spec_version(version)
379+
359380
canonical_version = canonicalize_version(
360-
self._get_spec_version(version),
361-
strip_trailing_zero=(operator != "~="),
381+
spec_version, strip_trailing_zero=(operator != "~=")
362382
)
363383

364384
return operator, canonical_version
@@ -451,7 +471,7 @@ def _compare_equal(self, prospective: Version, spec: str) -> bool:
451471
return shortened_prospective == split_spec
452472
else:
453473
# Convert our spec string into a Version
454-
spec_version = self._get_spec_version(spec)
474+
spec_version = self._require_spec_version(spec)
455475

456476
# If the specifier does not have a local segment, then we want to
457477
# act as if the prospective version also does not have a local
@@ -468,18 +488,18 @@ def _compare_less_than_equal(self, prospective: Version, spec: str) -> bool:
468488
# NB: Local version identifiers are NOT permitted in the version
469489
# specifier, so local version labels can be universally removed from
470490
# the prospective version.
471-
return _public_version(prospective) <= self._get_spec_version(spec)
491+
return _public_version(prospective) <= self._require_spec_version(spec)
472492

473493
def _compare_greater_than_equal(self, prospective: Version, spec: str) -> bool:
474494
# NB: Local version identifiers are NOT permitted in the version
475495
# specifier, so local version labels can be universally removed from
476496
# the prospective version.
477-
return _public_version(prospective) >= self._get_spec_version(spec)
497+
return _public_version(prospective) >= self._require_spec_version(spec)
478498

479499
def _compare_less_than(self, prospective: Version, spec_str: str) -> bool:
480500
# Convert our spec to a Version instance, since we'll want to work with
481501
# it as a version.
482-
spec = self._get_spec_version(spec_str)
502+
spec = self._require_spec_version(spec_str)
483503

484504
# Check to see if the prospective version is less than the spec
485505
# version. If it's not we can short circuit and just return False now
@@ -506,7 +526,7 @@ def _compare_less_than(self, prospective: Version, spec_str: str) -> bool:
506526
def _compare_greater_than(self, prospective: Version, spec_str: str) -> bool:
507527
# Convert our spec to a Version instance, since we'll want to work with
508528
# it as a version.
509-
spec = self._get_spec_version(spec_str)
529+
spec = self._require_spec_version(spec_str)
510530

511531
# Check to see if the prospective version is greater than the spec
512532
# version. If it's not we can short circuit and just return False now
@@ -537,7 +557,7 @@ def _compare_greater_than(self, prospective: Version, spec_str: str) -> bool:
537557
# same version in the spec.
538558
return True
539559

540-
def _compare_arbitrary(self, prospective: Version, spec: str) -> bool:
560+
def _compare_arbitrary(self, prospective: Version | str, spec: str) -> bool:
541561
return str(prospective).lower() == str(spec).lower()
542562

543563
def __contains__(self, item: str | Version) -> bool:
@@ -627,9 +647,12 @@ def filter(
627647
for version in iterable:
628648
parsed_version = _coerce_version(version)
629649
if parsed_version is None:
630-
continue
631-
632-
if operator_callable(parsed_version, self.version):
650+
# === operator can match arbitrary (non-version) strings
651+
if self.operator == "===" and self._compare_arbitrary(
652+
version, self.version
653+
):
654+
yield version
655+
elif operator_callable(parsed_version, self.version):
633656
# If it's not a prerelease or prereleases are allowed, yield it directly
634657
if not parsed_version.is_prerelease or include_prereleases:
635658
found_non_prereleases = True
@@ -944,13 +967,12 @@ def contains(
944967
True
945968
"""
946969
version = _coerce_version(item)
947-
if version is None:
948-
return False
949970

950-
if installed and version.is_prerelease:
971+
if version is not None and installed and version.is_prerelease:
951972
prereleases = True
952973

953-
return bool(list(self.filter([version], prereleases=prereleases)))
974+
check_item = item if version is None else version
975+
return bool(list(self.filter([check_item], prereleases=prereleases)))
954976

955977
def filter(
956978
self, iterable: Iterable[UnparsedVersionVar], prereleases: bool | None = None
@@ -1019,22 +1041,28 @@ def filter(
10191041
return (
10201042
item
10211043
for item in iterable
1022-
if (version := _coerce_version(item)) is not None
1023-
and not version.is_prerelease
1044+
if (version := _coerce_version(item)) is None
1045+
or not version.is_prerelease
10241046
)
10251047

10261048
# Finally if prereleases is None, apply PEP 440 logic:
10271049
# exclude prereleases unless there are no final releases that matched.
1028-
filtered: list[UnparsedVersionVar] = []
1050+
filtered_items: list[UnparsedVersionVar] = []
10291051
found_prereleases: list[UnparsedVersionVar] = []
1052+
found_final_release = False
10301053

10311054
for item in iterable:
10321055
parsed_version = _coerce_version(item)
1056+
# Arbitrary strings are always included as it is not
1057+
# possible to determine if they are prereleases,
1058+
# and they have already passed all specifiers.
10331059
if parsed_version is None:
1034-
continue
1035-
if parsed_version.is_prerelease:
1060+
filtered_items.append(item)
1061+
found_prereleases.append(item)
1062+
elif parsed_version.is_prerelease:
10361063
found_prereleases.append(item)
10371064
else:
1038-
filtered.append(item)
1065+
filtered_items.append(item)
1066+
found_final_release = True
10391067

1040-
return iter(filtered if filtered else found_prereleases)
1068+
return iter(filtered_items if found_final_release else found_prereleases)

0 commit comments

Comments
 (0)