-
Notifications
You must be signed in to change notification settings - Fork 1.1k
PYTHON-4700 - Convert CSFLE tests to async #1907
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 19 commits
4158326
acc8bdc
e6d4b11
d6a0429
1922702
97f39dd
7a4df29
c589fca
689b0b8
6feec44
17388d1
b472d3a
3bdd381
8e6a30d
8ca4b0e
a229746
88bf8d8
59a95cc
b12abc5
72303d6
2a77285
24ba64a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -130,7 +130,7 @@ def _is_ready(fut: Future) -> None: | |
loop.remove_writer(fd) | ||
|
||
async def _async_receive_ssl( | ||
conn: _sslConn, length: int, loop: AbstractEventLoop | ||
conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False | ||
) -> memoryview: | ||
mv = memoryview(bytearray(length)) | ||
total_read = 0 | ||
|
@@ -145,6 +145,8 @@ def _is_ready(fut: Future) -> None: | |
read = conn.recv_into(mv[total_read:]) | ||
if read == 0: | ||
raise OSError("connection closed") | ||
if once: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why only once? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is to support KMS requests, which require that the read returns the first batch immediately to accurately update the expected length. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That doesn't sound right. If we know how much we need to read we should be looping until we read that amount. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This loop inside while kms_context.bytes_needed > 0:
# CSOT: update timeout.
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
data = await async_receive_data_socket(conn, kms_context.bytes_needed)
kms_context.feed(data) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is the behavior I observed that led to this change: kms_context.bytes_needed = N There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh I see, looks like that's intentional behavior in libmongocrypt:
https://github.com/mongodb/libmongocrypt/blob/7aeaec4/src/mongocrypt-kms-ctx.c#L47-L51 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, it's a little odd but this approach allows us to reuse as much code as possible. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can you add Shane's note as a comment for future readers? |
||
return mv[:read] | ||
total_read += read | ||
except BLOCKING_IO_ERRORS as exc: | ||
fd = conn.fileno() | ||
|
@@ -275,6 +277,28 @@ async def async_receive_data( | |
sock.settimeout(sock_timeout) | ||
|
||
|
||
async def async_receive_data_socket( | ||
sock: Union[socket.socket, _sslConn], length: int | ||
) -> memoryview: | ||
sock_timeout = sock.gettimeout() | ||
timeout = sock_timeout | ||
|
||
sock.settimeout(0.0) | ||
loop = asyncio.get_event_loop() | ||
try: | ||
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)): | ||
return await asyncio.wait_for( | ||
_async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type] | ||
timeout=timeout, | ||
) | ||
else: | ||
return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type] | ||
except asyncio.TimeoutError as err: | ||
raise socket.timeout("timed out") from err | ||
finally: | ||
sock.settimeout(sock_timeout) | ||
|
||
|
||
async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview: | ||
mv = memoryview(bytearray(length)) | ||
bytes_read = 0 | ||
|
Uh oh!
There was an error while loading. Please reload this page.