Skip to content

Commit fdab81f

Browse files
committed
Support arbitrary equality on arbitrary strings
1 parent b68980b commit fdab81f

File tree

2 files changed

+163
-32
lines changed

2 files changed

+163
-32
lines changed

src/packaging/specifiers.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -248,24 +248,30 @@ def __init__(self, spec: str = "", prereleases: bool | None = None) -> None:
248248
self._prereleases = prereleases
249249

250250
@property
251-
def prereleases(self) -> bool:
251+
def prereleases(self) -> bool | None:
252252
# If there is an explicit prereleases set for this, then we'll just
253253
# blindly use that.
254254
if self._prereleases is not None:
255255
return self._prereleases
256256

257257
# Only the "!=" operator does not imply prereleases when
258258
# the version in the specifier is a prerelease.
259-
operator, version = self._spec
259+
operator, item = self._spec
260260
if operator != "!=":
261261
# The == specifier can include a trailing .*, if it does we
262262
# want to remove before parsing.
263-
if operator == "==" and version.endswith(".*"):
264-
version = version[:-2]
263+
if operator == "==" and item.endswith(".*"):
264+
item = item[:-2]
265+
266+
# "===" can have arbitrary string versions, so we cannot
267+
# parse those, we take prereleases as False for those.
268+
version = _coerce_version(item)
269+
if version is None:
270+
return None
265271

266272
# Parse the version, and if it is a pre-release than this
267273
# specifier allows pre-releases.
268-
if Version(version).is_prerelease:
274+
if version.is_prerelease:
269275
return True
270276

271277
return False
@@ -500,7 +506,7 @@ def _compare_greater_than(self, prospective: Version, spec_str: str) -> bool:
500506
# same version in the spec.
501507
return True
502508

503-
def _compare_arbitrary(self, prospective: Version, spec: str) -> bool:
509+
def _compare_arbitrary(self, prospective: Version | str, spec: str) -> bool:
504510
return str(prospective).lower() == str(spec).lower()
505511

506512
def __contains__(self, item: str | Version) -> bool:
@@ -590,9 +596,15 @@ def filter(
590596
for version in iterable:
591597
parsed_version = _coerce_version(version)
592598
if parsed_version is None:
593-
continue
594-
595-
if operator_callable(parsed_version, self.version):
599+
# === operator can match arbitrary (non-version) strings
600+
if self.operator == "===" and self._compare_arbitrary(
601+
version, self.version
602+
):
603+
yield version
604+
# != operator: non-version strings pass through (they're "not equal")
605+
elif self.operator == "!=":
606+
yield version
607+
elif operator_callable(parsed_version, self.version):
596608
# If it's not a prerelease or prereleases are allowed, yield it directly
597609
if not parsed_version.is_prerelease or include_prereleases:
598610
found_non_prereleases = True
@@ -905,13 +917,12 @@ def contains(
905917
True
906918
"""
907919
version = _coerce_version(item)
908-
if version is None:
909-
return False
910920

911-
if installed and version.is_prerelease:
921+
if version is not None and installed and version.is_prerelease:
912922
prereleases = True
913923

914-
return bool(list(self.filter([version], prereleases=prereleases)))
924+
check_item = item if version is None else version
925+
return bool(list(self.filter([check_item], prereleases=prereleases)))
915926

916927
def filter(
917928
self, iterable: Iterable[UnparsedVersionVar], prereleases: bool | None = None
@@ -991,9 +1002,7 @@ def filter(
9911002

9921003
for item in iterable:
9931004
parsed_version = _coerce_version(item)
994-
if parsed_version is None:
995-
continue
996-
if parsed_version.is_prerelease:
1005+
if parsed_version is not None and parsed_version.is_prerelease:
9971006
found_prereleases.append(item)
9981007
else:
9991008
filtered.append(item)

tests/test_specifiers.py

Lines changed: 138 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -495,15 +495,26 @@ def test_specifiers(self, version, spec, expected):
495495
assert not spec.contains(Version(version))
496496

497497
@pytest.mark.parametrize(
498-
("spec", "version"),
498+
("spec", "version", "expected"),
499499
[
500-
("==1.0", "not a valid version"),
501-
("===invalid", "invalid"),
500+
("==1.0", "not a valid version", False),
501+
(">=1.0", "not a valid version", False),
502+
(">1.0", "not a valid version", False),
503+
("<=1.0", "not a valid version", False),
504+
("<1.0", "not a valid version", False),
505+
("~=1.0", "not a valid version", False),
506+
# Test invalid versions with != (should pass as "not equal")
507+
("!=1.0", "not a valid version", True),
508+
("!=1.0", "not a valid version", True),
509+
("!=2.0.*", "not a valid version", True),
510+
# Test with arbitrary equality (===)
511+
("===invalid", "invalid", True),
512+
("===foobar", "invalid", False),
502513
],
503514
)
504-
def test_invalid_spec(self, spec, version):
515+
def test_invalid_version(self, spec, version, expected):
505516
spec = Specifier(spec, prereleases=True)
506-
assert not spec.contains(version)
517+
assert spec.contains(version) == expected
507518

508519
@pytest.mark.parametrize(
509520
(
@@ -569,20 +580,20 @@ def test_specifier_prereleases_set(
569580
[
570581
("1.0.0", "===1.0", False),
571582
("1.0.dev0", "===1.0", False),
572-
# Test identity comparison by itself
583+
# Test exact arbitrary equality (===)
573584
("1.0", "===1.0", True),
574585
("1.0.dev0", "===1.0.dev0", True),
586+
# Test that local versions don't match
587+
("1.0+downstream1", "===1.0", False),
588+
("1.0", "===1.0+downstream1", False),
589+
# Test with arbitrary (non-version) strings
590+
("foobar", "===foobar", True),
591+
("foobar", "===baz", False),
575592
],
576593
)
577-
def test_specifiers_identity(self, version, spec, expected):
594+
def test_arbitrary_equality(self, version, spec, expected):
578595
spec = Specifier(spec)
579-
580-
if expected:
581-
# Identity comparisons only support the plain string form
582-
assert version in spec
583-
else:
584-
# Identity comparisons only support the plain string form
585-
assert version not in spec
596+
assert spec.contains(version) == expected
586597

587598
@pytest.mark.parametrize(
588599
("specifier", "expected"),
@@ -659,6 +670,33 @@ def test_specifiers_prereleases(
659670
# Test that invalid versions are discarded
660671
(">=1.0", None, None, ["not a valid version"], []),
661672
(">=1.0", None, None, ["1.0", "not a valid version"], ["1.0"]),
673+
# Test arbitrary equality (===)
674+
("===foobar", None, None, ["foobar", "foo", "bar"], ["foobar"]),
675+
("===foobar", None, None, ["foo", "bar"], []),
676+
# Test that === does not match with zero padding
677+
("===1.0", None, None, ["1.0", "1.0.0", "2.0"], ["1.0"]),
678+
# Test that === does not match with local versions
679+
("===1.0", None, None, ["1.0", "1.0+downstream1"], ["1.0"]),
680+
# Test === with mix of valid versions and arbitrary strings
681+
(
682+
"===foobar",
683+
None,
684+
None,
685+
["foobar", "1.0", "2.0a1", "invalid"],
686+
["foobar"],
687+
),
688+
("===1.0", None, None, ["1.0", "foobar", "invalid", "1.0.0"], ["1.0"]),
689+
# Test != with invalid versions (should pass through as "not equal")
690+
("!=1.0", None, None, ["invalid", "foobar"], ["invalid", "foobar"]),
691+
("!=1.0", None, None, ["1.0", "invalid", "2.0"], ["invalid", "2.0"]),
692+
(
693+
"!=2.0.*",
694+
None,
695+
None,
696+
["invalid", "foobar", "2.0"],
697+
["invalid", "foobar"],
698+
),
699+
("!=2.0.*", None, None, ["1.0", "invalid", "2.0.0"], ["1.0", "invalid"]),
662700
],
663701
)
664702
def test_specifier_filter(
@@ -975,12 +1013,61 @@ def test_specifier_contains_installed_prereleases(
9751013
(">=1.0,<=2.0dev", True, False, ["1.0", "1.5a1"], ["1.0"]),
9761014
(">=1.0,<=2.0dev", False, True, ["1.0", "1.5a1"], ["1.0", "1.5a1"]),
9771015
# Test that invalid versions are discarded
978-
("", None, None, ["invalid version"], []),
1016+
("", None, None, ["invalid version"], ["invalid version"]),
9791017
("", None, False, ["invalid version"], []),
9801018
("", False, None, ["invalid version"], []),
981-
("", None, None, ["1.0", "invalid version"], ["1.0"]),
1019+
("", None, None, ["1.0", "invalid version"], ["1.0", "invalid version"]),
9821020
("", None, False, ["1.0", "invalid version"], ["1.0"]),
9831021
("", False, None, ["1.0", "invalid version"], ["1.0"]),
1022+
# Test arbitrary equality (===)
1023+
("===foobar", None, None, ["foobar", "foo", "bar"], ["foobar"]),
1024+
("===foobar", None, None, ["foo", "bar"], []),
1025+
# Test that === does not match with zero padding
1026+
("===1.0", None, None, ["1.0", "1.0.0", "2.0"], ["1.0"]),
1027+
# Test that === does not match with local versions
1028+
("===1.0", None, None, ["1.0", "1.0+downstream1"], ["1.0"]),
1029+
# Test === combined with other operators (arbitrary string)
1030+
(">=1.0,===foobar", None, None, ["foobar", "1.0", "2.0"], []),
1031+
("!= 2.0,===foobar", None, None, ["foobar", "2.0", "bar"], ["foobar"]),
1032+
# Test === combined with other operators (version string)
1033+
(">=1.0,===1.5", None, None, ["1.0", "1.5", "2.0"], ["1.5"]),
1034+
(">=2.0,===1.5", None, None, ["1.0", "1.5", "2.0"], []),
1035+
# Test === with mix of valid and invalid versions
1036+
(
1037+
"===foobar",
1038+
None,
1039+
None,
1040+
["foobar", "1.0", "invalid", "2.0a1"],
1041+
["foobar"],
1042+
),
1043+
("===1.0", None, None, ["1.0", "foobar", "invalid", "1.0.0"], ["1.0"]),
1044+
(">=1.0,===1.5", None, None, ["1.5", "foobar", "invalid"], ["1.5"]),
1045+
# Test != with invalid versions (should pass through as "not equal")
1046+
("!=1.0", None, None, ["invalid", "foobar"], ["invalid", "foobar"]),
1047+
("!=1.0", None, None, ["1.0", "invalid", "2.0"], ["invalid", "2.0"]),
1048+
(
1049+
"!=2.0.*",
1050+
None,
1051+
None,
1052+
["invalid", "foobar", "2.0"],
1053+
["invalid", "foobar"],
1054+
),
1055+
("!=2.0.*", None, None, ["1.0", "invalid", "2.0.0"], ["1.0", "invalid"]),
1056+
# Test != with invalid versions combined with other operators
1057+
(
1058+
"!=1.0,!=2.0",
1059+
None,
1060+
None,
1061+
["invalid", "1.0", "2.0", "3.0"],
1062+
["invalid", "3.0"],
1063+
),
1064+
(
1065+
">=1.0,!=2.0",
1066+
None,
1067+
None,
1068+
["invalid", "1.0", "2.0", "3.0"],
1069+
["1.0", "3.0"],
1070+
),
9841071
],
9851072
)
9861073
def test_specifier_filter(
@@ -1363,6 +1450,41 @@ def test_contains_rejects_invalid_specifier(self, specifier, input):
13631450
spec = SpecifierSet(specifier, prereleases=True)
13641451
assert not spec.contains(input)
13651452

1453+
@pytest.mark.parametrize(
1454+
("version", "specifier", "expected"),
1455+
[
1456+
# Test arbitrary equality (===) with arbitrary strings
1457+
("foobar", "===foobar", True),
1458+
("foo", "===foobar", False),
1459+
("bar", "===foobar", False),
1460+
# Test that === does not match with zero padding
1461+
("1.0", "===1.0", True),
1462+
("1.0.0", "===1.0", False),
1463+
# Test that === does not match with local versions
1464+
("1.0", "===1.0+downstream1", False),
1465+
("1.0+downstream1", "===1.0", False),
1466+
# Test === combined with other operators (arbitrary string)
1467+
("foobar", "===foobar,!=1.0", True),
1468+
("1.0", "===foobar,!=1.0", False),
1469+
("foobar", ">=1.0,===foobar", False),
1470+
# Test === combined with other operators (version string)
1471+
("1.5", ">=1.0,===1.5", True),
1472+
("1.5", ">=2.0,===1.5", False), # Doesn't meet >=2.0
1473+
("2.5", ">=1.0,===2.5", True),
1474+
# Test != with invalid versions (should pass as "not equal")
1475+
("invalid", "!=1.0", True),
1476+
("foobar", "!=1.0", True),
1477+
("invalid", "!=2.0.*", True),
1478+
# Test != with invalid versions combined with other operators
1479+
("invalid", "!=1.0,!=2.0", True),
1480+
("foobar", ">=1.0,!=2.0", False),
1481+
("1.5", ">=1.0,!=2.0", True),
1482+
],
1483+
)
1484+
def test_contains_arbitrary_equality_contains(self, version, specifier, expected):
1485+
spec = SpecifierSet(specifier)
1486+
assert spec.contains(version) == expected
1487+
13661488
@pytest.mark.parametrize(
13671489
("specifier", "expected"),
13681490
[

0 commit comments

Comments
 (0)