Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .evergreen/run-tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -257,9 +257,9 @@ if [ -z "$GREEN_FRAMEWORK" ]; then
# Use --capture=tee-sys so pytest prints test output inline:
# https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html
if [ -z "$TEST_SUITES" ]; then
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS
python -m pytest -v --capture=tee-sys --durations=5 $TEST_ARGS
else
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 -m $TEST_SUITES $TEST_ARGS
python -m pytest -v --capture=tee-sys --durations=5 -m $TEST_SUITES $TEST_ARGS
fi
else
python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS
Expand Down
12 changes: 11 additions & 1 deletion pymongo/asynchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,20 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = conn.recv(kms_context.bytes_needed)
if _IS_SYNC:
data = conn.recv(kms_context.bytes_needed)
else:
from pymongo.network_layer import ( # type: ignore[attr-defined]
async_receive_data_socket,
)

data = await async_receive_data_socket(conn, kms_context.bytes_needed)
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
# Async raises an OSError instead of returning empty bytes
except OSError as err:
raise OSError("KMS connection closed") from err
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
finally:
Expand Down
27 changes: 26 additions & 1 deletion pymongo/network_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def _is_ready(fut: Future) -> None:
loop.remove_writer(fd)

async def _async_receive_ssl(
conn: _sslConn, length: int, loop: AbstractEventLoop
conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
) -> memoryview:
mv = memoryview(bytearray(length))
total_read = 0
Expand All @@ -145,6 +145,9 @@ def _is_ready(fut: Future) -> None:
read = conn.recv_into(mv[total_read:])
if read == 0:
raise OSError("connection closed")
# KMS responses update their expected size after the first batch, stop reading after one loop
if once:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why only once?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is to support KMS requests, which require that the read returns the first batch immediately to accurately update the expected length.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That doesn't sound right. If we know how much we need to read we should be looping until we read that amount.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This loop inside _async_kms_request changes kms_context.bytes_needed with each recv and feed call. If we allow async_receive_data_socket to run until the initial kms_context.bytes_needed number of bytes is read, it will result in an error.

while kms_context.bytes_needed > 0:
    # CSOT: update timeout.
    conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
    data = await async_receive_data_socket(conn, kms_context.bytes_needed)
    kms_context.feed(data)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the behavior I observed that led to this change:

kms_context.bytes_needed = N
async_receive_data_socket returns with X bytes read, X < N
kms_context.feed is called
kms_context.bytes_needed is now not N - X, but a different value

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh I see, looks like that's intentional behavior in libmongocrypt:

/* Before we've read the Content-Length header in an HTTP response,
 * we don't know how many bytes we'll need. So return this value
 * in kms_ctx_bytes_needed until we are fed the Content-Length.
 */
#define DEFAULT_MAX_KMS_BYTE_REQUEST 1024

https://github.com/mongodb/libmongocrypt/blob/7aeaec4/src/mongocrypt-kms-ctx.c#L47-L51

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's a little odd but this approach allows us to reuse as much code as possible.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add Shane's note as a comment for future readers?

return mv[:read]
total_read += read
except BLOCKING_IO_ERRORS as exc:
fd = conn.fileno()
Expand Down Expand Up @@ -275,6 +278,28 @@ async def async_receive_data(
sock.settimeout(sock_timeout)


async def async_receive_data_socket(
sock: Union[socket.socket, _sslConn], length: int
) -> memoryview:
sock_timeout = sock.gettimeout()
timeout = sock_timeout

sock.settimeout(0.0)
loop = asyncio.get_event_loop()
try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
return await asyncio.wait_for(
_async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
timeout=timeout,
)
else:
return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type]
except asyncio.TimeoutError as err:
raise socket.timeout("timed out") from err
finally:
sock.settimeout(sock_timeout)


async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
mv = memoryview(bytearray(length))
bytes_read = 0
Expand Down
12 changes: 11 additions & 1 deletion pymongo/synchronous/encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,10 +180,20 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = conn.recv(kms_context.bytes_needed)
if _IS_SYNC:
data = conn.recv(kms_context.bytes_needed)
else:
from pymongo.network_layer import ( # type: ignore[attr-defined]
receive_data_socket,
)

data = receive_data_socket(conn, kms_context.bytes_needed)
if not data:
raise OSError("KMS connection closed")
kms_context.feed(data)
# Async raises an OSError instead of returning empty bytes
except OSError as err:
raise OSError("KMS connection closed") from err
except BLOCKING_IO_ERRORS:
raise socket.timeout("timed out") from None
finally:
Expand Down
11 changes: 6 additions & 5 deletions test/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,11 +464,12 @@ def wrap(*args, **kwargs):
if not self.connected:
pair = self.pair
raise SkipTest(f"Cannot connect to MongoDB on {pair}")
if iscoroutinefunction(condition) and condition():
if wraps_async:
return f(*args, **kwargs)
else:
return f(*args, **kwargs)
if iscoroutinefunction(condition):
if condition():
if wraps_async:
return f(*args, **kwargs)
else:
return f(*args, **kwargs)
elif condition():
if wraps_async:
return f(*args, **kwargs)
Expand Down
11 changes: 6 additions & 5 deletions test/asynchronous/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,11 +466,12 @@ async def wrap(*args, **kwargs):
if not self.connected:
pair = await self.pair
raise SkipTest(f"Cannot connect to MongoDB on {pair}")
if iscoroutinefunction(condition) and await condition():
if wraps_async:
return await f(*args, **kwargs)
else:
return f(*args, **kwargs)
if iscoroutinefunction(condition):
if await condition():
if wraps_async:
return await f(*args, **kwargs)
else:
return f(*args, **kwargs)
elif condition():
if wraps_async:
return await f(*args, **kwargs)
Expand Down
Loading
Loading