Skip to content
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .evergreen/run-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,9 @@ if [ -z "$GREEN_FRAMEWORK" ]; then
# Use --capture=tee-sys so pytest prints test output inline:
# https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html
if [ -z "$TEST_SUITES" ]; then
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS
python -m pytest -v --capture=tee-sys --durations=5 $TEST_ARGS
else
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 -m $TEST_SUITES $TEST_ARGS
python -m pytest -v --capture=tee-sys --durations=5 -m $TEST_SUITES $TEST_ARGS
fi
else
python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS
Expand Down
62 changes: 48 additions & 14 deletions pymongo/asynchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,37 @@ def __init__(
self.opts = opts
self._spawned = False

async def _async_kms_request(
self,
kms_context: MongoCryptKmsContext,
host: str,
port: Optional[int],
opts: PoolOptions,
message: bytes,
) -> None:
from pymongo.network_layer import async_receive_data_socket # type: ignore[attr-defined]

try:
conn = await _configured_socket((host, port), opts)
try:
await async_sendall(conn, message)
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = await async_receive_data_socket(conn, kms_context.bytes_needed)
kms_context.feed(data)
except OSError as err:
raise OSError("KMS connection closed") from err
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
finally:
conn.close()
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except Exception as error:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)

async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
"""Complete a KMS request.

Expand Down Expand Up @@ -174,20 +205,23 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
)
host, port = parse_host(endpoint, _HTTPS_PORT)
try:
conn = await _configured_socket((host, port), opts)
try:
await async_sendall(conn, message)
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = conn.recv(kms_context.bytes_needed)
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
finally:
conn.close()
if _IS_SYNC:
conn = await _configured_socket((host, port), opts)
try:
await async_sendall(conn, message)
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = conn.recv(kms_context.bytes_needed)
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
finally:
conn.close()
else:
await self._async_kms_request(kms_context, host, port, opts, message)
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except Exception as error:
Expand Down
26 changes: 25 additions & 1 deletion pymongo/network_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _is_ready(fut: Future) -> None:
loop.remove_writer(fd)

async def _async_receive_ssl(
conn: _sslConn, length: int, loop: AbstractEventLoop
conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
) -> memoryview:
mv = memoryview(bytearray(length))
total_read = 0
Expand All @@ -145,6 +145,8 @@ def _is_ready(fut: Future) -> None:
read = conn.recv_into(mv[total_read:])
if read == 0:
raise OSError("connection closed")
if once:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why only once?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to support KMS requests, which require that the read returns the first batch immediately to accurately update the expected length.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That doesn't sound right. If we know how much we need to read we should be looping until we read that amount.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This loop inside _async_kms_request changes kms_context.bytes_needed with each recv and feed call. If we allow async_receive_data_socket to run until the initial kms_context.bytes_needed number of bytes is read, it will result in an error.

while kms_context.bytes_needed > 0:
    # CSOT: update timeout.
    conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
    data = await async_receive_data_socket(conn, kms_context.bytes_needed)
    kms_context.feed(data)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the behavior I observed that led to this change:

kms_context.bytes_needed = N
async_receive_data_socket returns with X bytes read, X < N
kms_context.feed is called
kms_context.bytes_needed is now not N - X, but a different value

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, looks like that's intentional behavior in libmongocrypt:

/* Before we've read the Content-Length header in an HTTP response,
 * we don't know how many bytes we'll need. So return this value
 * in kms_ctx_bytes_needed until we are fed the Content-Length.
 */
#define DEFAULT_MAX_KMS_BYTE_REQUEST 1024

https://github.com/mongodb/libmongocrypt/blob/7aeaec4/src/mongocrypt-kms-ctx.c#L47-L51

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's a little odd but this approach allows us to reuse as much code as possible.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add Shane's note as a comment for future readers?

return mv[:read]
total_read += read
except BLOCKING_IO_ERRORS as exc:
fd = conn.fileno()
Expand Down Expand Up @@ -275,6 +277,28 @@ async def async_receive_data(
sock.settimeout(sock_timeout)


async def async_receive_data_socket(
sock: Union[socket.socket, _sslConn], length: int
) -> memoryview:
sock_timeout = sock.gettimeout()
timeout = sock_timeout

sock.settimeout(0.0)
loop = asyncio.get_event_loop()
try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
return await asyncio.wait_for(
_async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
timeout=timeout,
)
else:
return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type]
except asyncio.TimeoutError as err:
raise socket.timeout("timed out") from err
finally:
sock.settimeout(sock_timeout)


async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
mv = memoryview(bytearray(length))
bytes_read = 0
Expand Down
62 changes: 48 additions & 14 deletions pymongo/synchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,37 @@ def __init__(
self.opts = opts
self._spawned = False

def _async_kms_request(
self,
kms_context: MongoCryptKmsContext,
host: str,
port: Optional[int],
opts: PoolOptions,
message: bytes,
) -> None:
from pymongo.network_layer import receive_data_socket # type: ignore[attr-defined]

try:
conn = _configured_socket((host, port), opts)
try:
sendall(conn, message)
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = receive_data_socket(conn, kms_context.bytes_needed)
kms_context.feed(data)
except OSError as err:
raise OSError("KMS connection closed") from err
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
finally:
conn.close()
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except Exception as error:
# Wrap I/O errors in PyMongo exceptions.
_raise_connection_failure((host, port), error)

def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
"""Complete a KMS request.

Expand Down Expand Up @@ -174,20 +205,23 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
)
host, port = parse_host(endpoint, _HTTPS_PORT)
try:
conn = _configured_socket((host, port), opts)
try:
sendall(conn, message)
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = conn.recv(kms_context.bytes_needed)
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
finally:
conn.close()
if _IS_SYNC:
conn = _configured_socket((host, port), opts)
try:
sendall(conn, message)
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = conn.recv(kms_context.bytes_needed)
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
finally:
conn.close()
else:
self._async_kms_request(kms_context, host, port, opts, message)
except (PyMongoError, MongoCryptError):
raise # Propagate pymongo errors directly.
except Exception as error:
Expand Down
11 changes: 6 additions & 5 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,12 @@ def wrap(*args, **kwargs):
if not self.connected:
pair = self.pair
raise SkipTest(f"Cannot connect to MongoDB on {pair}")
if iscoroutinefunction(condition) and condition():
if wraps_async:
return f(*args, **kwargs)
else:
return f(*args, **kwargs)
if iscoroutinefunction(condition):
if condition():
if wraps_async:
return f(*args, **kwargs)
else:
return f(*args, **kwargs)
elif condition():
if wraps_async:
return f(*args, **kwargs)
Expand Down
11 changes: 6 additions & 5 deletions test/asynchronous/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,11 +466,12 @@ async def wrap(*args, **kwargs):
if not self.connected:
pair = await self.pair
raise SkipTest(f"Cannot connect to MongoDB on {pair}")
if iscoroutinefunction(condition) and await condition():
if wraps_async:
return await f(*args, **kwargs)
else:
return f(*args, **kwargs)
if iscoroutinefunction(condition):
if await condition():
if wraps_async:
return await f(*args, **kwargs)
else:
return f(*args, **kwargs)
elif condition():
if wraps_async:
return await f(*args, **kwargs)
Expand Down
Loading
Loading