diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index d32a5b3204..4dfd36aa49 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -66,7 +66,12 @@ from pymongo.asynchronous.mongo_client import AsyncMongoClient from pymongo.common import CONNECT_TIMEOUT from pymongo.daemon import _spawn_daemon -from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts, TextOpts +from pymongo.encryption_options import ( + AutoEncryptionOpts, + RangeOpts, + TextOpts, + check_min_pymongocrypt, +) from pymongo.errors import ( ConfigurationError, EncryptedCollectionError, @@ -675,6 +680,8 @@ def __init__( "python -m pip install --upgrade 'pymongo[encryption]'" ) + check_min_pymongocrypt() + if not isinstance(codec_options, CodecOptions): raise TypeError( f"codec_options must be an instance of bson.codec_options.CodecOptions, not {type(codec_options)}" diff --git a/pymongo/asynchronous/srv_resolver.py b/pymongo/asynchronous/srv_resolver.py index 8d0d40c276..006abbb616 100644 --- a/pymongo/asynchronous/srv_resolver.py +++ b/pymongo/asynchronous/srv_resolver.py @@ -19,7 +19,7 @@ import random from typing import TYPE_CHECKING, Any, Optional, Union -from pymongo.common import CONNECT_TIMEOUT +from pymongo.common import CONNECT_TIMEOUT, check_for_min_version from pymongo.errors import ConfigurationError if TYPE_CHECKING: @@ -32,6 +32,14 @@ def _have_dnspython() -> bool: try: import dns # noqa: F401 + dns_version, required_version, is_valid = check_for_min_version("dnspython") + if not is_valid: + raise RuntimeError( + f"pymongo requires dnspython>={required_version}, " + f"found version {dns_version}. " + "Install a compatible version with pip" + ) + return True except ImportError: return False @@ -80,6 +88,8 @@ def __init__( srv_service_name: str, srv_max_hosts: int = 0, ): + # Ensure the version of dnspython is compatible. + _have_dnspython() self.__fqdn = fqdn self.__srv = srv_service_name self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT diff --git a/pymongo/common.py b/pymongo/common.py index 5210e72189..e23adac426 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -20,6 +20,7 @@ import warnings from collections import OrderedDict, abc from difflib import get_close_matches +from importlib.metadata import requires, version from typing import ( TYPE_CHECKING, Any, @@ -1092,3 +1093,91 @@ def has_c() -> bool: return True except ImportError: return False + + +class Version(tuple[int, ...]): + """A class that can be used to compare version strings.""" + + def __new__(cls, *version: int) -> Version: + padded_version = cls._padded(version, 4) + return super().__new__(cls, tuple(padded_version)) + + @classmethod + def _padded(cls, iter: Any, length: int, padding: int = 0) -> list[int]: + as_list = list(iter) + if len(as_list) < length: + for _ in range(length - len(as_list)): + as_list.append(padding) + return as_list + + @classmethod + def from_string(cls, version_string: str) -> Version: + mod = 0 + bump_patch_level = False + if version_string.endswith("+"): + version_string = version_string[0:-1] + mod = 1 + elif version_string.endswith("-pre-"): + version_string = version_string[0:-5] + mod = -1 + elif version_string.endswith("-"): + version_string = version_string[0:-1] + mod = -1 + # Deal with .devX substrings + if ".dev" in version_string: + version_string = version_string[0 : version_string.find(".dev")] + mod = -1 + # Deal with '-rcX' substrings + if "-rc" in version_string: + version_string = version_string[0 : version_string.find("-rc")] + mod = -1 + # Deal with git describe generated substrings + elif "-" in version_string: + version_string = version_string[0 : version_string.find("-")] + mod = -1 + bump_patch_level = True + + version = [int(part) for part in version_string.split(".")] + version = cls._padded(version, 3) + # Make from_string and from_version_array agree. For example: + # MongoDB Enterprise > db.runCommand('buildInfo').versionArray + # [ 3, 2, 1, -100 ] + # MongoDB Enterprise > db.runCommand('buildInfo').version + # 3.2.0-97-g1ef94fe + if bump_patch_level: + version[-1] += 1 + version.append(mod) + + return Version(*version) + + @classmethod + def from_version_array(cls, version_array: Any) -> Version: + version = list(version_array) + if version[-1] < 0: + version[-1] = -1 + version = cls._padded(version, 3) + return Version(*version) + + def at_least(self, *other_version: Any) -> bool: + return self >= Version(*other_version) + + def __str__(self) -> str: + return ".".join(map(str, self)) + + +def check_for_min_version(package_name: str) -> tuple[str, str, bool]: + """Test whether an installed package is of the desired version.""" + package_version_str = version(package_name) + package_version = Version.from_string(package_version_str) + # Dependency is expected to be in one of the forms: + # "pymongocrypt<2.0.0,>=1.13.0; extra == 'encryption'" + # 'dnspython<3.0.0,>=1.16.0' + # + requirements = requires("pymongo") + assert requirements is not None + requirement = [i for i in requirements if i.startswith(package_name)][0] # noqa: RUF015 + if ";" in requirement: + requirement = requirement.split(";")[0] + required_version = requirement[requirement.find(">=") + 2 :] + is_valid = package_version >= Version.from_string(required_version) + return package_version_str, required_version, is_valid diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py index da34a3be52..b2037617b0 100644 --- a/pymongo/encryption_options.py +++ b/pymongo/encryption_options.py @@ -23,7 +23,7 @@ from pymongo.uri_parser_shared import _parse_kms_tls_options try: - import pymongocrypt # type:ignore[import-untyped] # noqa: F401 + import pymongocrypt # type:ignore[import-untyped] # noqa: F401 # Check for pymongocrypt>=1.10. from pymongocrypt import synchronous as _ # noqa: F401 @@ -32,7 +32,7 @@ except ImportError: _HAVE_PYMONGOCRYPT = False from bson import int64 -from pymongo.common import validate_is_mapping +from pymongo.common import check_for_min_version, validate_is_mapping from pymongo.errors import ConfigurationError if TYPE_CHECKING: @@ -40,6 +40,18 @@ from pymongo.typings import _AgnosticMongoClient +def check_min_pymongocrypt() -> None: + """Raise an appropriate error if the min pymongocrypt is not installed.""" + pymongocrypt_version, required_version, is_valid = check_for_min_version("pymongocrypt") + if not is_valid: + raise ConfigurationError( + f"client side encryption requires pymongocrypt>={required_version}, " + f"found version {pymongocrypt_version}. " + "Install a compatible version with: " + "python -m pip install 'pymongo[encryption]'" + ) + + class AutoEncryptionOpts: """Options to configure automatic client-side field level encryption.""" @@ -215,6 +227,7 @@ def __init__( "install a compatible version with: " "python -m pip install 'pymongo[encryption]'" ) + check_min_pymongocrypt() if encrypted_fields_map: validate_is_mapping("encrypted_fields_map", encrypted_fields_map) self._encrypted_fields_map = encrypted_fields_map diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index f9d51a9eab..2d666b9763 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -61,7 +61,12 @@ from pymongo import _csot from pymongo.common import CONNECT_TIMEOUT from pymongo.daemon import _spawn_daemon -from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts, TextOpts +from pymongo.encryption_options import ( + AutoEncryptionOpts, + RangeOpts, + TextOpts, + check_min_pymongocrypt, +) from pymongo.errors import ( ConfigurationError, EncryptedCollectionError, @@ -672,6 +677,8 @@ def __init__( "python -m pip install --upgrade 'pymongo[encryption]'" ) + check_min_pymongocrypt() + if not isinstance(codec_options, CodecOptions): raise TypeError( f"codec_options must be an instance of bson.codec_options.CodecOptions, not {type(codec_options)}" diff --git a/pymongo/synchronous/srv_resolver.py b/pymongo/synchronous/srv_resolver.py index f6e99a3ea8..8e492061ae 100644 --- a/pymongo/synchronous/srv_resolver.py +++ b/pymongo/synchronous/srv_resolver.py @@ -19,7 +19,7 @@ import random from typing import TYPE_CHECKING, Any, Optional, Union -from pymongo.common import CONNECT_TIMEOUT +from pymongo.common import CONNECT_TIMEOUT, check_for_min_version from pymongo.errors import ConfigurationError if TYPE_CHECKING: @@ -32,6 +32,14 @@ def _have_dnspython() -> bool: try: import dns # noqa: F401 + dns_version, required_version, is_valid = check_for_min_version("dnspython") + if not is_valid: + raise RuntimeError( + f"pymongo requires dnspython>={required_version}, " + f"found version {dns_version}. " + "Install a compatible version with pip" + ) + return True except ImportError: return False @@ -80,6 +88,8 @@ def __init__( srv_service_name: str, srv_max_hosts: int = 0, ): + # Ensure the version of dnspython is compatible. + _have_dnspython() self.__fqdn = fqdn self.__srv = srv_service_name self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT diff --git a/test/version.py b/test/version.py index 42d53cfcf4..ae6ecb331f 100644 --- a/test/version.py +++ b/test/version.py @@ -15,64 +15,10 @@ """Some tools for running tests based on MongoDB server version.""" from __future__ import annotations +from pymongo.common import Version as BaseVersion -class Version(tuple): - def __new__(cls, *version): - padded_version = cls._padded(version, 4) - return super().__new__(cls, tuple(padded_version)) - - @classmethod - def _padded(cls, iter, length, padding=0): - l = list(iter) - if len(l) < length: - for _ in range(length - len(l)): - l.append(padding) - return l - - @classmethod - def from_string(cls, version_string): - mod = 0 - bump_patch_level = False - if version_string.endswith("+"): - version_string = version_string[0:-1] - mod = 1 - elif version_string.endswith("-pre-"): - version_string = version_string[0:-5] - mod = -1 - elif version_string.endswith("-"): - version_string = version_string[0:-1] - mod = -1 - # Deal with '-rcX' substrings - if "-rc" in version_string: - version_string = version_string[0 : version_string.find("-rc")] - mod = -1 - # Deal with git describe generated substrings - elif "-" in version_string: - version_string = version_string[0 : version_string.find("-")] - mod = -1 - bump_patch_level = True - - version = [int(part) for part in version_string.split(".")] - version = cls._padded(version, 3) - # Make from_string and from_version_array agree. For example: - # MongoDB Enterprise > db.runCommand('buildInfo').versionArray - # [ 3, 2, 1, -100 ] - # MongoDB Enterprise > db.runCommand('buildInfo').version - # 3.2.0-97-g1ef94fe - if bump_patch_level: - version[-1] += 1 - version.append(mod) - - return Version(*version) - - @classmethod - def from_version_array(cls, version_array): - version = list(version_array) - if version[-1] < 0: - version[-1] = -1 - version = cls._padded(version, 3) - return Version(*version) +class Version(BaseVersion): @classmethod def from_client(cls, client): info = client.server_info() @@ -86,9 +32,3 @@ async def async_from_client(cls, client): if "versionArray" in info: return cls.from_version_array(info["versionArray"]) return cls.from_string(info["version"]) - - def at_least(self, *other_version): - return self >= Version(*other_version) - - def __str__(self): - return ".".join(map(str, self))