Skip to content
64 changes: 46 additions & 18 deletions pymongo/asynchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import contextlib
import enum
import socket
import time as time # noqa: PLC0414 # needed in sync version
import uuid
import weakref
from copy import deepcopy
Expand Down Expand Up @@ -63,7 +64,11 @@
from pymongo.asynchronous.cursor import AsyncCursor
from pymongo.asynchronous.database import AsyncDatabase
from pymongo.asynchronous.mongo_client import AsyncMongoClient
from pymongo.asynchronous.pool import _configured_socket, _raise_connection_failure
from pymongo.asynchronous.pool import (
_configured_socket,
_get_timeout_details,
_raise_connection_failure,
)
from pymongo.common import CONNECT_TIMEOUT
from pymongo.daemon import _spawn_daemon
from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts
Expand All @@ -72,7 +77,7 @@
EncryptedCollectionError,
EncryptionError,
InvalidOperation,
PyMongoError,
NetworkTimeout,
ServerSelectionTimeoutError,
)
from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall
Expand All @@ -88,6 +93,9 @@
if TYPE_CHECKING:
from pymongocrypt.mongocrypt import MongoCryptKmsContext

from pymongo.pyopenssl_context import _sslConn
from pymongo.typings import _Address


_IS_SYNC = False

Expand All @@ -103,6 +111,13 @@
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)


async def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
try:
return await _configured_socket(address, opts)
except Exception as exc:
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))


@contextlib.contextmanager
def _wrap_encryption_errors() -> Iterator[None]:
"""Context manager to wrap encryption related errors."""
Expand Down Expand Up @@ -166,18 +181,22 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
None, # crlfile
False, # allow_invalid_certificates
False, # allow_invalid_hostnames
False,
) # disable_ocsp_endpoint_check
False, # disable_ocsp_endpoint_check
)
# CSOT: set timeout for socket creation.
connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001)
opts = PoolOptions(
connect_timeout=connect_timeout,
socket_timeout=connect_timeout,
ssl_context=ctx,
)
host, port = parse_host(endpoint, _HTTPS_PORT)
address = parse_host(endpoint, _HTTPS_PORT)
sleep_u = kms_context.usleep
if sleep_u:
sleep_sec = float(sleep_u) / 1e6
await asyncio.sleep(sleep_sec)
try:
conn = await _configured_socket((host, port), opts)
conn = await _connect_kms(address, opts)
try:
await async_sendall(conn, message)
while kms_context.bytes_needed > 0:
Expand All @@ -194,20 +213,29 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
# Async raises an OSError instead of returning empty bytes
except OSError as err:
raise OSError("KMS connection closed") from err
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:
# Wrap I/O errors in PyMongo exceptions.
if isinstance(exc, BLOCKING_IO_ERRORS):
exc = socket.timeout("timed out")
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
finally:
conn.close()
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except asyncio.CancelledError:
raise
except Exception as error:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:
remaining = _csot.remaining()
if isinstance(exc, NetworkTimeout) or (remaining is not None and remaining <= 0):
raise
# Mark this attempt as failed and defer to libmongocrypt to retry.
try:
kms_context.fail()
except MongoCryptError as final_err:
exc = MongoCryptError(
f"{final_err}, last attempt failed with: {exc}", final_err.code
)
raise exc from final_err

async def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
"""Get the collection info for a namespace.
Expand Down
64 changes: 46 additions & 18 deletions pymongo/synchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import contextlib
import enum
import socket
import time as time # noqa: PLC0414 # needed in sync version
import uuid
import weakref
from copy import deepcopy
Expand Down Expand Up @@ -67,7 +68,7 @@
EncryptedCollectionError,
EncryptionError,
InvalidOperation,
PyMongoError,
NetworkTimeout,
ServerSelectionTimeoutError,
)
from pymongo.network_layer import BLOCKING_IO_ERRORS, sendall
Expand All @@ -80,14 +81,21 @@
from pymongo.synchronous.cursor import Cursor
from pymongo.synchronous.database import Database
from pymongo.synchronous.mongo_client import MongoClient
from pymongo.synchronous.pool import _configured_socket, _raise_connection_failure
from pymongo.synchronous.pool import (
_configured_socket,
_get_timeout_details,
_raise_connection_failure,
)
from pymongo.typings import _DocumentType, _DocumentTypeArg
from pymongo.uri_parser import parse_host
from pymongo.write_concern import WriteConcern

if TYPE_CHECKING:
from pymongocrypt.mongocrypt import MongoCryptKmsContext

from pymongo.pyopenssl_context import _sslConn
from pymongo.typings import _Address


_IS_SYNC = True

Expand All @@ -103,6 +111,13 @@
_KEY_VAULT_OPTS = CodecOptions(document_class=RawBSONDocument)


def _connect_kms(address: _Address, opts: PoolOptions) -> Union[socket.socket, _sslConn]:
try:
return _configured_socket(address, opts)
except Exception as exc:
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))


@contextlib.contextmanager
def _wrap_encryption_errors() -> Iterator[None]:
"""Context manager to wrap encryption related errors."""
Expand Down Expand Up @@ -166,18 +181,22 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
None, # crlfile
False, # allow_invalid_certificates
False, # allow_invalid_hostnames
False,
) # disable_ocsp_endpoint_check
False, # disable_ocsp_endpoint_check
)
# CSOT: set timeout for socket creation.
connect_timeout = max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0.001)
opts = PoolOptions(
connect_timeout=connect_timeout,
socket_timeout=connect_timeout,
ssl_context=ctx,
)
host, port = parse_host(endpoint, _HTTPS_PORT)
address = parse_host(endpoint, _HTTPS_PORT)
sleep_u = kms_context.usleep
if sleep_u:
sleep_sec = float(sleep_u) / 1e6
time.sleep(sleep_sec)
try:
conn = _configured_socket((host, port), opts)
conn = _connect_kms(address, opts)
try:
sendall(conn, message)
while kms_context.bytes_needed > 0:
Expand All @@ -194,20 +213,29 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
# Async raises an OSError instead of returning empty bytes
except OSError as err:
raise OSError("KMS connection closed") from err
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:
# Wrap I/O errors in PyMongo exceptions.
if isinstance(exc, BLOCKING_IO_ERRORS):
exc = socket.timeout("timed out")
_raise_connection_failure(address, exc, timeout_details=_get_timeout_details(opts))
finally:
conn.close()
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except asyncio.CancelledError:
raise
except Exception as error:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)
except MongoCryptError:
raise # Propagate MongoCryptError errors directly.
except Exception as exc:
remaining = _csot.remaining()
if isinstance(exc, NetworkTimeout) or (remaining is not None and remaining <= 0):
raise
# Mark this attempt as failed and defer to libmongocrypt to retry.
try:
kms_context.fail()
except MongoCryptError as final_err:
exc = MongoCryptError(
f"{final_err}, last attempt failed with: {exc}", final_err.code
)
raise exc from final_err

def collection_info(self, database: str, filter: bytes) -> Optional[bytes]:
"""Get the collection info for a namespace.
Expand Down
86 changes: 84 additions & 2 deletions test/asynchronous/test_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import base64
import copy
import http.client
import json
import os
import pathlib
import re
Expand Down Expand Up @@ -91,6 +93,7 @@
WriteError,
)
from pymongo.operations import InsertOne, ReplaceOne, UpdateOne
from pymongo.ssl_support import get_ssl_context
from pymongo.write_concern import WriteConcern

_IS_SYNC = False
Expand Down Expand Up @@ -1366,9 +1369,8 @@ async def test_04_aws_endpoint_invalid_port(self):
"key": ("arn:aws:kms:us-east-1:579766882180:key/89fcc2c4-08b0-4bd9-9f25-e30687b580d0"),
"endpoint": "kms.us-east-1.amazonaws.com:12345",
}
with self.assertRaisesRegex(EncryptionError, "kms.us-east-1.amazonaws.com:12345") as ctx:
with self.assertRaisesRegex(EncryptionError, "kms.us-east-1.amazonaws.com:12345"):
await self.client_encryption.create_data_key("aws", master_key=master_key)
self.assertIsInstance(ctx.exception.cause, AutoReconnect)

@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
async def test_05_aws_endpoint_wrong_region(self):
Expand Down Expand Up @@ -2853,6 +2855,86 @@ async def test_accepts_trim_factor_0(self):
assert len(payload) > len(self.payload_defaults)


# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#24-kms-retry-tests
class TestKmsRetryProse(AsyncEncryptionIntegrationTest):
@unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set")
async def asyncSetUp(self):
await super().asyncSetUp()
# 1, create client with only tlsCAFile.
providers: dict = copy.deepcopy(ALL_KMS_PROVIDERS)
providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9003"
providers["gcp"]["endpoint"] = "127.0.0.1:9003"
kms_tls_opts = {
p: {"tlsCAFile": CA_PEM, "tlsCertificateKeyFile": CLIENT_PEM} for p in providers
}
self.client_encryption = self.create_client_encryption(
providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts
)

async def http_post(self, path, data=None):
# Note, the connection to the mock server needs to be closed after
# each request because the server is single threaded.
ctx: ssl.SSLContext = get_ssl_context(
CLIENT_PEM, # certfile
None, # passphrase
CA_PEM, # ca_certs
None, # crlfile
False, # allow_invalid_certificates
False, # allow_invalid_hostnames
False, # disable_ocsp_endpoint_check
)
conn = http.client.HTTPSConnection("127.0.0.1:9003", context=ctx)
try:
if data is not None:
headers = {"Content-type": "application/json"}
body = json.dumps(data)
else:
headers = {}
body = None
conn.request("POST", path, body, headers)
res = conn.getresponse()
res.read()
finally:
conn.close()

async def _test(self, provider, master_key):
await self.http_post("/reset")
# Case 1: createDataKey and encrypt with TCP retry
await self.http_post("/set_failpoint/network", {"count": 1})
key_id = await self.client_encryption.create_data_key(provider, master_key=master_key)
await self.http_post("/set_failpoint/network", {"count": 1})
await self.client_encryption.encrypt(
123, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id
)

# Case 2: createDataKey and encrypt with HTTP retry
await self.http_post("/set_failpoint/http", {"count": 1})
key_id = await self.client_encryption.create_data_key(provider, master_key=master_key)
await self.http_post("/set_failpoint/http", {"count": 1})
await self.client_encryption.encrypt(
123, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id
)

# Case 3: createDataKey fails after too many retries
await self.http_post("/set_failpoint/network", {"count": 4})
with self.assertRaisesRegex(EncryptionError, "KMS request failed after"):
await self.client_encryption.create_data_key(provider, master_key=master_key)

async def test_kms_retry(self):
await self._test("aws", {"region": "foo", "key": "bar", "endpoint": "127.0.0.1:9003"})
await self._test("azure", {"keyVaultEndpoint": "127.0.0.1:9003", "keyName": "foo"})
await self._test(
"gcp",
{
"projectId": "foo",
"location": "bar",
"keyRing": "baz",
"keyName": "qux",
"endpoint": "127.0.0.1:9003",
},
)


# https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.md#automatic-data-encryption-keys
class TestAutomaticDecryptionKeys(AsyncEncryptionIntegrationTest):
@async_client_context.require_no_standalone
Expand Down
Loading
Loading