Skip to content

Commit 5bbfe0f

Browse files
committed
Support arbitrary equality on arbitrary strings
1 parent f61a88b commit 5bbfe0f

File tree

2 files changed

+163
-32
lines changed

2 files changed

+163
-32
lines changed

src/packaging/specifiers.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -279,24 +279,30 @@ def _get_spec_version(self, version: str) -> Version:
279279
return version_specifier
280280

281281
@property
282-
def prereleases(self) -> bool:
282+
def prereleases(self) -> bool | None:
283283
# If there is an explicit prereleases set for this, then we'll just
284284
# blindly use that.
285285
if self._prereleases is not None:
286286
return self._prereleases
287287

288288
# Only the "!=" operator does not imply prereleases when
289289
# the version in the specifier is a prerelease.
290-
operator, version = self._spec
290+
operator, item = self._spec
291291
if operator != "!=":
292292
# The == specifier can include a trailing .*, if it does we
293293
# want to remove before parsing.
294-
if operator == "==" and version.endswith(".*"):
295-
version = version[:-2]
294+
if operator == "==" and item.endswith(".*"):
295+
item = item[:-2]
296+
297+
# "===" can have arbitrary string versions, so we cannot
298+
# parse those, we take prereleases as False for those.
299+
version = _coerce_version(item)
300+
if version is None:
301+
return None
296302

297303
# Parse the version, and if it is a pre-release than this
298304
# specifier allows pre-releases.
299-
if Version(version).is_prerelease:
305+
if version.is_prerelease:
300306
return True
301307

302308
return False
@@ -538,7 +544,7 @@ def _compare_greater_than(self, prospective: Version, spec_str: str) -> bool:
538544
# same version in the spec.
539545
return True
540546

541-
def _compare_arbitrary(self, prospective: Version, spec: str) -> bool:
547+
def _compare_arbitrary(self, prospective: Version | str, spec: str) -> bool:
542548
return str(prospective).lower() == str(spec).lower()
543549

544550
def __contains__(self, item: str | Version) -> bool:
@@ -628,9 +634,15 @@ def filter(
628634
for version in iterable:
629635
parsed_version = _coerce_version(version)
630636
if parsed_version is None:
631-
continue
632-
633-
if operator_callable(parsed_version, self.version):
637+
# === operator can match arbitrary (non-version) strings
638+
if self.operator == "===" and self._compare_arbitrary(
639+
version, self.version
640+
):
641+
yield version
642+
# != operator: non-version strings pass through (they're "not equal")
643+
elif self.operator == "!=":
644+
yield version
645+
elif operator_callable(parsed_version, self.version):
634646
# If it's not a prerelease or prereleases are allowed, yield it directly
635647
if not parsed_version.is_prerelease or include_prereleases:
636648
found_non_prereleases = True
@@ -943,13 +955,12 @@ def contains(
943955
True
944956
"""
945957
version = _coerce_version(item)
946-
if version is None:
947-
return False
948958

949-
if installed and version.is_prerelease:
959+
if version is not None and installed and version.is_prerelease:
950960
prereleases = True
951961

952-
return bool(list(self.filter([version], prereleases=prereleases)))
962+
check_item = item if version is None else version
963+
return bool(list(self.filter([check_item], prereleases=prereleases)))
953964

954965
def filter(
955966
self, iterable: Iterable[UnparsedVersionVar], prereleases: bool | None = None
@@ -1029,9 +1040,7 @@ def filter(
10291040

10301041
for item in iterable:
10311042
parsed_version = _coerce_version(item)
1032-
if parsed_version is None:
1033-
continue
1034-
if parsed_version.is_prerelease:
1043+
if parsed_version is not None and parsed_version.is_prerelease:
10351044
found_prereleases.append(item)
10361045
else:
10371046
filtered.append(item)

tests/test_specifiers.py

Lines changed: 138 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -513,15 +513,26 @@ def test_specifiers(self, version: str, spec_str: str, expected: bool) -> None:
513513
assert not spec.contains(Version(version))
514514

515515
@pytest.mark.parametrize(
516-
("spec_str", "version"),
516+
("spec_str", "version", "expected"),
517517
[
518-
("==1.0", "not a valid version"),
519-
("===invalid", "invalid"),
518+
("==1.0", "not a valid version", False),
519+
(">=1.0", "not a valid version", False),
520+
(">1.0", "not a valid version", False),
521+
("<=1.0", "not a valid version", False),
522+
("<1.0", "not a valid version", False),
523+
("~=1.0", "not a valid version", False),
524+
# Test invalid versions with != (should pass as "not equal")
525+
("!=1.0", "not a valid version", True),
526+
("!=1.0", "not a valid version", True),
527+
("!=2.0.*", "not a valid version", True),
528+
# Test with arbitrary equality (===)
529+
("===invalid", "invalid", True),
530+
("===foobar", "invalid", False),
520531
],
521532
)
522-
def test_invalid_spec(self, spec_str: str, version: str) -> None:
533+
def test_invalid_version(self, spec_str: str, version, expected: str) -> None:
523534
spec = Specifier(spec_str, prereleases=True)
524-
assert not spec.contains(version)
535+
assert spec.contains(version) == expected
525536

526537
@pytest.mark.parametrize(
527538
(
@@ -584,9 +595,15 @@ def test_specifier_prereleases_set(
584595
[
585596
("1.0.0", "===1.0", False),
586597
("1.0.dev0", "===1.0", False),
587-
# Test identity comparison by itself
598+
# Test exact arbitrary equality (===)
588599
("1.0", "===1.0", True),
589600
("1.0.dev0", "===1.0.dev0", True),
601+
# Test that local versions don't match
602+
("1.0+downstream1", "===1.0", False),
603+
("1.0", "===1.0+downstream1", False),
604+
# Test with arbitrary (non-version) strings
605+
("foobar", "===foobar", True),
606+
("foobar", "===baz", False),
590607
# Test case insensitivity for pre-release versions
591608
("1.0a1", "===1.0a1", True),
592609
("1.0A1", "===1.0A1", True),
@@ -632,17 +649,11 @@ def test_specifier_prereleases_set(
632649
("1.0A1.POST2.DEV3", "===1.0a1.post2.dev3", True),
633650
],
634651
)
635-
def test_specifiers_identity(
652+
def test_arbitrary_equality(
636653
self, version: str, spec_str: str, expected: bool
637654
) -> None:
638655
spec = Specifier(spec_str)
639-
640-
if expected:
641-
# Identity comparisons only support the plain string form
642-
assert version in spec
643-
else:
644-
# Identity comparisons only support the plain string form
645-
assert version not in spec
656+
assert spec.contains(version) == expected
646657

647658
@pytest.mark.parametrize(
648659
("specifier", "expected"),
@@ -726,6 +737,33 @@ def test_specifiers_prereleases(
726737
# Test that invalid versions are discarded
727738
(">=1.0", None, None, ["not a valid version"], []),
728739
(">=1.0", None, None, ["1.0", "not a valid version"], ["1.0"]),
740+
# Test arbitrary equality (===)
741+
("===foobar", None, None, ["foobar", "foo", "bar"], ["foobar"]),
742+
("===foobar", None, None, ["foo", "bar"], []),
743+
# Test that === does not match with zero padding
744+
("===1.0", None, None, ["1.0", "1.0.0", "2.0"], ["1.0"]),
745+
# Test that === does not match with local versions
746+
("===1.0", None, None, ["1.0", "1.0+downstream1"], ["1.0"]),
747+
# Test === with mix of valid versions and arbitrary strings
748+
(
749+
"===foobar",
750+
None,
751+
None,
752+
["foobar", "1.0", "2.0a1", "invalid"],
753+
["foobar"],
754+
),
755+
("===1.0", None, None, ["1.0", "foobar", "invalid", "1.0.0"], ["1.0"]),
756+
# Test != with invalid versions (should pass through as "not equal")
757+
("!=1.0", None, None, ["invalid", "foobar"], ["invalid", "foobar"]),
758+
("!=1.0", None, None, ["1.0", "invalid", "2.0"], ["invalid", "2.0"]),
759+
(
760+
"!=2.0.*",
761+
None,
762+
None,
763+
["invalid", "foobar", "2.0"],
764+
["invalid", "foobar"],
765+
),
766+
("!=2.0.*", None, None, ["1.0", "invalid", "2.0.0"], ["1.0", "invalid"]),
729767
],
730768
)
731769
def test_specifier_filter(
@@ -1041,12 +1079,61 @@ def test_specifier_contains_installed_prereleases(
10411079
(">=1.0,<=2.0dev", True, False, ["1.0", "1.5a1"], ["1.0"]),
10421080
(">=1.0,<=2.0dev", False, True, ["1.0", "1.5a1"], ["1.0", "1.5a1"]),
10431081
# Test that invalid versions are discarded
1044-
("", None, None, ["invalid version"], []),
1082+
("", None, None, ["invalid version"], ["invalid version"]),
10451083
("", None, False, ["invalid version"], []),
10461084
("", False, None, ["invalid version"], []),
1047-
("", None, None, ["1.0", "invalid version"], ["1.0"]),
1085+
("", None, None, ["1.0", "invalid version"], ["1.0", "invalid version"]),
10481086
("", None, False, ["1.0", "invalid version"], ["1.0"]),
10491087
("", False, None, ["1.0", "invalid version"], ["1.0"]),
1088+
# Test arbitrary equality (===)
1089+
("===foobar", None, None, ["foobar", "foo", "bar"], ["foobar"]),
1090+
("===foobar", None, None, ["foo", "bar"], []),
1091+
# Test that === does not match with zero padding
1092+
("===1.0", None, None, ["1.0", "1.0.0", "2.0"], ["1.0"]),
1093+
# Test that === does not match with local versions
1094+
("===1.0", None, None, ["1.0", "1.0+downstream1"], ["1.0"]),
1095+
# Test === combined with other operators (arbitrary string)
1096+
(">=1.0,===foobar", None, None, ["foobar", "1.0", "2.0"], []),
1097+
("!= 2.0,===foobar", None, None, ["foobar", "2.0", "bar"], ["foobar"]),
1098+
# Test === combined with other operators (version string)
1099+
(">=1.0,===1.5", None, None, ["1.0", "1.5", "2.0"], ["1.5"]),
1100+
(">=2.0,===1.5", None, None, ["1.0", "1.5", "2.0"], []),
1101+
# Test === with mix of valid and invalid versions
1102+
(
1103+
"===foobar",
1104+
None,
1105+
None,
1106+
["foobar", "1.0", "invalid", "2.0a1"],
1107+
["foobar"],
1108+
),
1109+
("===1.0", None, None, ["1.0", "foobar", "invalid", "1.0.0"], ["1.0"]),
1110+
(">=1.0,===1.5", None, None, ["1.5", "foobar", "invalid"], ["1.5"]),
1111+
# Test != with invalid versions (should pass through as "not equal")
1112+
("!=1.0", None, None, ["invalid", "foobar"], ["invalid", "foobar"]),
1113+
("!=1.0", None, None, ["1.0", "invalid", "2.0"], ["invalid", "2.0"]),
1114+
(
1115+
"!=2.0.*",
1116+
None,
1117+
None,
1118+
["invalid", "foobar", "2.0"],
1119+
["invalid", "foobar"],
1120+
),
1121+
("!=2.0.*", None, None, ["1.0", "invalid", "2.0.0"], ["1.0", "invalid"]),
1122+
# Test != with invalid versions combined with other operators
1123+
(
1124+
"!=1.0,!=2.0",
1125+
None,
1126+
None,
1127+
["invalid", "1.0", "2.0", "3.0"],
1128+
["invalid", "3.0"],
1129+
),
1130+
(
1131+
">=1.0,!=2.0",
1132+
None,
1133+
None,
1134+
["invalid", "1.0", "2.0", "3.0"],
1135+
["1.0", "3.0"],
1136+
),
10501137
],
10511138
)
10521139
def test_specifier_filter(
@@ -1442,6 +1529,41 @@ def test_contains_rejects_invalid_specifier(
14421529
spec = SpecifierSet(specifier, prereleases=True)
14431530
assert not spec.contains(input)
14441531

1532+
@pytest.mark.parametrize(
1533+
("version", "specifier", "expected"),
1534+
[
1535+
# Test arbitrary equality (===) with arbitrary strings
1536+
("foobar", "===foobar", True),
1537+
("foo", "===foobar", False),
1538+
("bar", "===foobar", False),
1539+
# Test that === does not match with zero padding
1540+
("1.0", "===1.0", True),
1541+
("1.0.0", "===1.0", False),
1542+
# Test that === does not match with local versions
1543+
("1.0", "===1.0+downstream1", False),
1544+
("1.0+downstream1", "===1.0", False),
1545+
# Test === combined with other operators (arbitrary string)
1546+
("foobar", "===foobar,!=1.0", True),
1547+
("1.0", "===foobar,!=1.0", False),
1548+
("foobar", ">=1.0,===foobar", False),
1549+
# Test === combined with other operators (version string)
1550+
("1.5", ">=1.0,===1.5", True),
1551+
("1.5", ">=2.0,===1.5", False), # Doesn't meet >=2.0
1552+
("2.5", ">=1.0,===2.5", True),
1553+
# Test != with invalid versions (should pass as "not equal")
1554+
("invalid", "!=1.0", True),
1555+
("foobar", "!=1.0", True),
1556+
("invalid", "!=2.0.*", True),
1557+
# Test != with invalid versions combined with other operators
1558+
("invalid", "!=1.0,!=2.0", True),
1559+
("foobar", ">=1.0,!=2.0", False),
1560+
("1.5", ">=1.0,!=2.0", True),
1561+
],
1562+
)
1563+
def test_contains_arbitrary_equality_contains(self, version, specifier, expected):
1564+
spec = SpecifierSet(specifier)
1565+
assert spec.contains(version) == expected
1566+
14451567
@pytest.mark.parametrize(
14461568
("specifier", "expected"),
14471569
[

0 commit comments

Comments
 (0)