Skip to content

Commit 93a7824

Browse files
committed
Support arbitrary equality on arbitrary strings
1 parent a70ead8 commit 93a7824

File tree

2 files changed

+161
-30
lines changed

2 files changed

+161
-30
lines changed

src/packaging/specifiers.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -278,21 +278,27 @@ def _get_spec_version(self, version: str) -> Version:
278278
return version_specifier
279279

280280
@property
281-
def prereleases(self) -> bool:
281+
def prereleases(self) -> bool | None:
282282
# If there is an explicit prereleases set for this, then we'll just
283283
# blindly use that.
284284
if self._prereleases is not None:
285285
return self._prereleases
286286

287287
# Only the "!=" operator does not imply prereleases when
288288
# the version in the specifier is a prerelease.
289-
operator, version = self._spec
289+
operator, item = self._spec
290290
if operator != "!=":
291291
# The == specifier with trailing .* cannot include prereleases
292292
# e.g. "==1.0a1.*" is not valid.
293-
if operator == "==" and version.endswith(".*"):
293+
if operator == "==" and item.endswith(".*"):
294294
return False
295295

296+
# "===" can have arbitrary string versions, so we cannot
297+
# parse those, we take prereleases as False for those.
298+
version = _coerce_version(item)
299+
if version is None:
300+
return None
301+
296302
# For all other operators, use the check if spec Version
297303
# object implies pre-releases.
298304
if self._get_spec_version(version).is_prerelease:
@@ -537,7 +543,7 @@ def _compare_greater_than(self, prospective: Version, spec_str: str) -> bool:
537543
# same version in the spec.
538544
return True
539545

540-
def _compare_arbitrary(self, prospective: Version, spec: str) -> bool:
546+
def _compare_arbitrary(self, prospective: Version | str, spec: str) -> bool:
541547
return str(prospective).lower() == str(spec).lower()
542548

543549
def __contains__(self, item: str | Version) -> bool:
@@ -627,9 +633,15 @@ def filter(
627633
for version in iterable:
628634
parsed_version = _coerce_version(version)
629635
if parsed_version is None:
630-
continue
631-
632-
if operator_callable(parsed_version, self.version):
636+
# === operator can match arbitrary (non-version) strings
637+
if self.operator == "===" and self._compare_arbitrary(
638+
version, self.version
639+
):
640+
yield version
641+
# != operator: non-version strings pass through (they're "not equal")
642+
elif self.operator == "!=":
643+
yield version
644+
elif operator_callable(parsed_version, self.version):
633645
# If it's not a prerelease or prereleases are allowed, yield it directly
634646
if not parsed_version.is_prerelease or include_prereleases:
635647
found_non_prereleases = True
@@ -944,13 +956,12 @@ def contains(
944956
True
945957
"""
946958
version = _coerce_version(item)
947-
if version is None:
948-
return False
949959

950-
if installed and version.is_prerelease:
960+
if version is not None and installed and version.is_prerelease:
951961
prereleases = True
952962

953-
return bool(list(self.filter([version], prereleases=prereleases)))
963+
check_item = item if version is None else version
964+
return bool(list(self.filter([check_item], prereleases=prereleases)))
954965

955966
def filter(
956967
self, iterable: Iterable[UnparsedVersionVar], prereleases: bool | None = None
@@ -1030,9 +1041,7 @@ def filter(
10301041

10311042
for item in iterable:
10321043
parsed_version = _coerce_version(item)
1033-
if parsed_version is None:
1034-
continue
1035-
if parsed_version.is_prerelease:
1044+
if parsed_version is not None and parsed_version.is_prerelease:
10361045
found_prereleases.append(item)
10371046
else:
10381047
filtered.append(item)

tests/test_specifiers.py

Lines changed: 138 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -517,15 +517,26 @@ def test_specifiers(self, version: str, spec_str: str, expected: bool) -> None:
517517
assert not spec.contains(Version(version))
518518

519519
@pytest.mark.parametrize(
520-
("spec_str", "version"),
520+
("spec_str", "version", "expected"),
521521
[
522-
("==1.0", "not a valid version"),
523-
("===invalid", "invalid"),
522+
("==1.0", "not a valid version", False),
523+
(">=1.0", "not a valid version", False),
524+
(">1.0", "not a valid version", False),
525+
("<=1.0", "not a valid version", False),
526+
("<1.0", "not a valid version", False),
527+
("~=1.0", "not a valid version", False),
528+
# Test invalid versions with != (should pass as "not equal")
529+
("!=1.0", "not a valid version", True),
530+
("!=1.0", "not a valid version", True),
531+
("!=2.0.*", "not a valid version", True),
532+
# Test with arbitrary equality (===)
533+
("===invalid", "invalid", True),
534+
("===foobar", "invalid", False),
524535
],
525536
)
526-
def test_invalid_spec(self, spec_str: str, version: str) -> None:
537+
def test_invalid_version(self, spec_str: str, version, expected: str) -> None:
527538
spec = Specifier(spec_str, prereleases=True)
528-
assert not spec.contains(version)
539+
assert spec.contains(version) == expected
529540

530541
@pytest.mark.parametrize(
531542
(
@@ -588,9 +599,15 @@ def test_specifier_prereleases_set(
588599
[
589600
("1.0.0", "===1.0", False),
590601
("1.0.dev0", "===1.0", False),
591-
# Test identity comparison by itself
602+
# Test exact arbitrary equality (===)
592603
("1.0", "===1.0", True),
593604
("1.0.dev0", "===1.0.dev0", True),
605+
# Test that local versions don't match
606+
("1.0+downstream1", "===1.0", False),
607+
("1.0", "===1.0+downstream1", False),
608+
# Test with arbitrary (non-version) strings
609+
("foobar", "===foobar", True),
610+
("foobar", "===baz", False),
594611
# Test case insensitivity for pre-release versions
595612
("1.0a1", "===1.0a1", True),
596613
("1.0A1", "===1.0A1", True),
@@ -636,17 +653,11 @@ def test_specifier_prereleases_set(
636653
("1.0A1.POST2.DEV3", "===1.0a1.post2.dev3", True),
637654
],
638655
)
639-
def test_specifiers_identity(
656+
def test_arbitrary_equality(
640657
self, version: str, spec_str: str, expected: bool
641658
) -> None:
642659
spec = Specifier(spec_str)
643-
644-
if expected:
645-
# Identity comparisons only support the plain string form
646-
assert version in spec
647-
else:
648-
# Identity comparisons only support the plain string form
649-
assert version not in spec
660+
assert spec.contains(version) == expected
650661

651662
@pytest.mark.parametrize(
652663
("specifier", "expected"),
@@ -730,6 +741,33 @@ def test_specifiers_prereleases(
730741
# Test that invalid versions are discarded
731742
(">=1.0", None, None, ["not a valid version"], []),
732743
(">=1.0", None, None, ["1.0", "not a valid version"], ["1.0"]),
744+
# Test arbitrary equality (===)
745+
("===foobar", None, None, ["foobar", "foo", "bar"], ["foobar"]),
746+
("===foobar", None, None, ["foo", "bar"], []),
747+
# Test that === does not match with zero padding
748+
("===1.0", None, None, ["1.0", "1.0.0", "2.0"], ["1.0"]),
749+
# Test that === does not match with local versions
750+
("===1.0", None, None, ["1.0", "1.0+downstream1"], ["1.0"]),
751+
# Test === with mix of valid versions and arbitrary strings
752+
(
753+
"===foobar",
754+
None,
755+
None,
756+
["foobar", "1.0", "2.0a1", "invalid"],
757+
["foobar"],
758+
),
759+
("===1.0", None, None, ["1.0", "foobar", "invalid", "1.0.0"], ["1.0"]),
760+
# Test != with invalid versions (should pass through as "not equal")
761+
("!=1.0", None, None, ["invalid", "foobar"], ["invalid", "foobar"]),
762+
("!=1.0", None, None, ["1.0", "invalid", "2.0"], ["invalid", "2.0"]),
763+
(
764+
"!=2.0.*",
765+
None,
766+
None,
767+
["invalid", "foobar", "2.0"],
768+
["invalid", "foobar"],
769+
),
770+
("!=2.0.*", None, None, ["1.0", "invalid", "2.0.0"], ["1.0", "invalid"]),
733771
],
734772
)
735773
def test_specifier_filter(
@@ -1190,12 +1228,61 @@ def test_specifier_contains_installed_prereleases(
11901228
(">=1.0,<=2.0dev", True, False, ["1.0", "1.5a1"], ["1.0"]),
11911229
(">=1.0,<=2.0dev", False, True, ["1.0", "1.5a1"], ["1.0", "1.5a1"]),
11921230
# Test that invalid versions are discarded
1193-
("", None, None, ["invalid version"], []),
1231+
("", None, None, ["invalid version"], ["invalid version"]),
11941232
("", None, False, ["invalid version"], []),
11951233
("", False, None, ["invalid version"], []),
1196-
("", None, None, ["1.0", "invalid version"], ["1.0"]),
1234+
("", None, None, ["1.0", "invalid version"], ["1.0", "invalid version"]),
11971235
("", None, False, ["1.0", "invalid version"], ["1.0"]),
11981236
("", False, None, ["1.0", "invalid version"], ["1.0"]),
1237+
# Test arbitrary equality (===)
1238+
("===foobar", None, None, ["foobar", "foo", "bar"], ["foobar"]),
1239+
("===foobar", None, None, ["foo", "bar"], []),
1240+
# Test that === does not match with zero padding
1241+
("===1.0", None, None, ["1.0", "1.0.0", "2.0"], ["1.0"]),
1242+
# Test that === does not match with local versions
1243+
("===1.0", None, None, ["1.0", "1.0+downstream1"], ["1.0"]),
1244+
# Test === combined with other operators (arbitrary string)
1245+
(">=1.0,===foobar", None, None, ["foobar", "1.0", "2.0"], []),
1246+
("!= 2.0,===foobar", None, None, ["foobar", "2.0", "bar"], ["foobar"]),
1247+
# Test === combined with other operators (version string)
1248+
(">=1.0,===1.5", None, None, ["1.0", "1.5", "2.0"], ["1.5"]),
1249+
(">=2.0,===1.5", None, None, ["1.0", "1.5", "2.0"], []),
1250+
# Test === with mix of valid and invalid versions
1251+
(
1252+
"===foobar",
1253+
None,
1254+
None,
1255+
["foobar", "1.0", "invalid", "2.0a1"],
1256+
["foobar"],
1257+
),
1258+
("===1.0", None, None, ["1.0", "foobar", "invalid", "1.0.0"], ["1.0"]),
1259+
(">=1.0,===1.5", None, None, ["1.5", "foobar", "invalid"], ["1.5"]),
1260+
# Test != with invalid versions (should pass through as "not equal")
1261+
("!=1.0", None, None, ["invalid", "foobar"], ["invalid", "foobar"]),
1262+
("!=1.0", None, None, ["1.0", "invalid", "2.0"], ["invalid", "2.0"]),
1263+
(
1264+
"!=2.0.*",
1265+
None,
1266+
None,
1267+
["invalid", "foobar", "2.0"],
1268+
["invalid", "foobar"],
1269+
),
1270+
("!=2.0.*", None, None, ["1.0", "invalid", "2.0.0"], ["1.0", "invalid"]),
1271+
# Test != with invalid versions combined with other operators
1272+
(
1273+
"!=1.0,!=2.0",
1274+
None,
1275+
None,
1276+
["invalid", "1.0", "2.0", "3.0"],
1277+
["invalid", "3.0"],
1278+
),
1279+
(
1280+
">=1.0,!=2.0",
1281+
None,
1282+
None,
1283+
["invalid", "1.0", "2.0", "3.0"],
1284+
["1.0", "3.0"],
1285+
),
11991286
],
12001287
)
12011288
def test_specifier_filter(
@@ -1591,6 +1678,41 @@ def test_contains_rejects_invalid_specifier(
15911678
spec = SpecifierSet(specifier, prereleases=True)
15921679
assert not spec.contains(input)
15931680

1681+
@pytest.mark.parametrize(
1682+
("version", "specifier", "expected"),
1683+
[
1684+
# Test arbitrary equality (===) with arbitrary strings
1685+
("foobar", "===foobar", True),
1686+
("foo", "===foobar", False),
1687+
("bar", "===foobar", False),
1688+
# Test that === does not match with zero padding
1689+
("1.0", "===1.0", True),
1690+
("1.0.0", "===1.0", False),
1691+
# Test that === does not match with local versions
1692+
("1.0", "===1.0+downstream1", False),
1693+
("1.0+downstream1", "===1.0", False),
1694+
# Test === combined with other operators (arbitrary string)
1695+
("foobar", "===foobar,!=1.0", True),
1696+
("1.0", "===foobar,!=1.0", False),
1697+
("foobar", ">=1.0,===foobar", False),
1698+
# Test === combined with other operators (version string)
1699+
("1.5", ">=1.0,===1.5", True),
1700+
("1.5", ">=2.0,===1.5", False), # Doesn't meet >=2.0
1701+
("2.5", ">=1.0,===2.5", True),
1702+
# Test != with invalid versions (should pass as "not equal")
1703+
("invalid", "!=1.0", True),
1704+
("foobar", "!=1.0", True),
1705+
("invalid", "!=2.0.*", True),
1706+
# Test != with invalid versions combined with other operators
1707+
("invalid", "!=1.0,!=2.0", True),
1708+
("foobar", ">=1.0,!=2.0", False),
1709+
("1.5", ">=1.0,!=2.0", True),
1710+
],
1711+
)
1712+
def test_contains_arbitrary_equality_contains(self, version, specifier, expected):
1713+
spec = SpecifierSet(specifier)
1714+
assert spec.contains(version) == expected
1715+
15941716
@pytest.mark.parametrize(
15951717
("specifier", "expected"),
15961718
[

0 commit comments

Comments
 (0)