Skip to content

Commit b1fbe73

Browse files
committed
Add a _require_spec_version to pass type checking
1 parent 621112e commit b1fbe73

File tree

1 file changed

+25
-11
lines changed

1 file changed

+25
-11
lines changed

src/packaging/specifiers.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -268,15 +268,28 @@ def __init__(self, spec: str = "", prereleases: bool | None = None) -> None:
268268
# Specifier version cache
269269
self._spec_version: tuple[str, Version] | None = None
270270

271-
def _get_spec_version(self, version: str) -> Version:
271+
def _get_spec_version(self, version: str) -> Version | None:
272272
"""One element cache, as only one spec Version is needed per Specifier."""
273273
if self._spec_version is not None and self._spec_version[0] == version:
274274
return self._spec_version[1]
275275

276-
version_specifier = Version(version)
276+
version_specifier = _coerce_version(version)
277+
if version_specifier is None:
278+
return None
279+
277280
self._spec_version = (version, version_specifier)
278281
return version_specifier
279282

283+
def _require_spec_version(self, version: str) -> Version:
284+
"""Get spec version, asserting it's valid (not for === operator).
285+
286+
This method should only be called for operators where version
287+
strings are guaranteed to be valid PEP 440 versions (not ===).
288+
"""
289+
spec_version = self._get_spec_version(version)
290+
assert spec_version is not None
291+
return spec_version
292+
280293
@property
281294
def prereleases(self) -> bool | None:
282295
# If there is an explicit prereleases set for this, then we'll just
@@ -295,13 +308,13 @@ def prereleases(self) -> bool | None:
295308

296309
# "===" can have arbitrary string versions, so we cannot
297310
# parse those, we take prereleases as False for those.
298-
version = _coerce_version(item)
311+
version = self._get_spec_version(item)
299312
if version is None:
300313
return None
301314

302315
# For all other operators, use the check if spec Version
303316
# object implies pre-releases.
304-
if self._get_spec_version(version).is_prerelease:
317+
if version.is_prerelease:
305318
return True
306319

307320
return False
@@ -362,9 +375,10 @@ def _canonical_spec(self) -> tuple[str, str]:
362375
if operator == "===" or version.endswith(".*"):
363376
return operator, version
364377

378+
spec_version = self._require_spec_version(version)
379+
365380
canonical_version = canonicalize_version(
366-
self._get_spec_version(version),
367-
strip_trailing_zero=(operator != "~="),
381+
spec_version, strip_trailing_zero=(operator != "~=")
368382
)
369383

370384
return operator, canonical_version
@@ -457,7 +471,7 @@ def _compare_equal(self, prospective: Version, spec: str) -> bool:
457471
return shortened_prospective == split_spec
458472
else:
459473
# Convert our spec string into a Version
460-
spec_version = self._get_spec_version(spec)
474+
spec_version = self._require_spec_version(spec)
461475

462476
# If the specifier does not have a local segment, then we want to
463477
# act as if the prospective version also does not have a local
@@ -474,18 +488,18 @@ def _compare_less_than_equal(self, prospective: Version, spec: str) -> bool:
474488
# NB: Local version identifiers are NOT permitted in the version
475489
# specifier, so local version labels can be universally removed from
476490
# the prospective version.
477-
return _public_version(prospective) <= self._get_spec_version(spec)
491+
return _public_version(prospective) <= self._require_spec_version(spec)
478492

479493
def _compare_greater_than_equal(self, prospective: Version, spec: str) -> bool:
480494
# NB: Local version identifiers are NOT permitted in the version
481495
# specifier, so local version labels can be universally removed from
482496
# the prospective version.
483-
return _public_version(prospective) >= self._get_spec_version(spec)
497+
return _public_version(prospective) >= self._require_spec_version(spec)
484498

485499
def _compare_less_than(self, prospective: Version, spec_str: str) -> bool:
486500
# Convert our spec to a Version instance, since we'll want to work with
487501
# it as a version.
488-
spec = self._get_spec_version(spec_str)
502+
spec = self._require_spec_version(spec_str)
489503

490504
# Check to see if the prospective version is less than the spec
491505
# version. If it's not we can short circuit and just return False now
@@ -512,7 +526,7 @@ def _compare_less_than(self, prospective: Version, spec_str: str) -> bool:
512526
def _compare_greater_than(self, prospective: Version, spec_str: str) -> bool:
513527
# Convert our spec to a Version instance, since we'll want to work with
514528
# it as a version.
515-
spec = self._get_spec_version(spec_str)
529+
spec = self._require_spec_version(spec_str)
516530

517531
# Check to see if the prospective version is greater than the spec
518532
# version. If it's not we can short circuit and just return False now

0 commit comments

Comments
 (0)