Skip to content

Commit 3a66229

Browse files
authored
PYTHON-4700 - Convert CSFLE tests to async (mongodb#1907)
1 parent 8118aea commit 3a66229

File tree

15 files changed

+655
-434
lines changed

15 files changed

+655
-434
lines changed

.evergreen/run-tests.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,9 @@ if [ -z "$GREEN_FRAMEWORK" ]; then
257257
# Use --capture=tee-sys so pytest prints test output inline:
258258
# https://docs.pytest.org/en/stable/how-to/capture-stdout-stderr.html
259259
if [ -z "$TEST_SUITES" ]; then
260-
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 $TEST_ARGS
260+
python -m pytest -v --capture=tee-sys --durations=5 $TEST_ARGS
261261
else
262-
python -m pytest -v --capture=tee-sys --durations=5 --maxfail=10 -m $TEST_SUITES $TEST_ARGS
262+
python -m pytest -v --capture=tee-sys --durations=5 -m $TEST_SUITES $TEST_ARGS
263263
fi
264264
else
265265
python green_framework_test.py $GREEN_FRAMEWORK -v $TEST_ARGS

pymongo/asynchronous/encryption.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,20 @@ async def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
180180
while kms_context.bytes_needed > 0:
181181
# CSOT: update timeout.
182182
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
183-
data = conn.recv(kms_context.bytes_needed)
183+
if _IS_SYNC:
184+
data = conn.recv(kms_context.bytes_needed)
185+
else:
186+
from pymongo.network_layer import ( # type: ignore[attr-defined]
187+
async_receive_data_socket,
188+
)
189+
190+
data = await async_receive_data_socket(conn, kms_context.bytes_needed)
184191
if not data:
185192
raise OSError("KMS connection closed")
186193
kms_context.feed(data)
194+
# Async raises an OSError instead of returning empty bytes
195+
except OSError as err:
196+
raise OSError("KMS connection closed") from err
187197
except BLOCKING_IO_ERRORS:
188198
raise socket.timeout("timed out") from None
189199
finally:

pymongo/network_layer.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _is_ready(fut: Future) -> None:
130130
loop.remove_writer(fd)
131131

132132
async def _async_receive_ssl(
133-
conn: _sslConn, length: int, loop: AbstractEventLoop
133+
conn: _sslConn, length: int, loop: AbstractEventLoop, once: Optional[bool] = False
134134
) -> memoryview:
135135
mv = memoryview(bytearray(length))
136136
total_read = 0
@@ -145,6 +145,9 @@ def _is_ready(fut: Future) -> None:
145145
read = conn.recv_into(mv[total_read:])
146146
if read == 0:
147147
raise OSError("connection closed")
148+
# KMS responses update their expected size after the first batch, stop reading after one loop
149+
if once:
150+
return mv[:read]
148151
total_read += read
149152
except BLOCKING_IO_ERRORS as exc:
150153
fd = conn.fileno()
@@ -275,6 +278,28 @@ async def async_receive_data(
275278
sock.settimeout(sock_timeout)
276279

277280

281+
async def async_receive_data_socket(
282+
sock: Union[socket.socket, _sslConn], length: int
283+
) -> memoryview:
284+
sock_timeout = sock.gettimeout()
285+
timeout = sock_timeout
286+
287+
sock.settimeout(0.0)
288+
loop = asyncio.get_event_loop()
289+
try:
290+
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
291+
return await asyncio.wait_for(
292+
_async_receive_ssl(sock, length, loop, once=True), # type: ignore[arg-type]
293+
timeout=timeout,
294+
)
295+
else:
296+
return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type]
297+
except asyncio.TimeoutError as err:
298+
raise socket.timeout("timed out") from err
299+
finally:
300+
sock.settimeout(sock_timeout)
301+
302+
278303
async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
279304
mv = memoryview(bytearray(length))
280305
bytes_read = 0

pymongo/synchronous/encryption.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,10 +180,20 @@ def kms_request(self, kms_context: MongoCryptKmsContext) -> None:
180180
while kms_context.bytes_needed > 0:
181181
# CSOT: update timeout.
182182
conn.settimeout(max(_csot.clamp_remaining(_KMS_CONNECT_TIMEOUT), 0))
183-
data = conn.recv(kms_context.bytes_needed)
183+
if _IS_SYNC:
184+
data = conn.recv(kms_context.bytes_needed)
185+
else:
186+
from pymongo.network_layer import ( # type: ignore[attr-defined]
187+
receive_data_socket,
188+
)
189+
190+
data = receive_data_socket(conn, kms_context.bytes_needed)
184191
if not data:
185192
raise OSError("KMS connection closed")
186193
kms_context.feed(data)
194+
# Async raises an OSError instead of returning empty bytes
195+
except OSError as err:
196+
raise OSError("KMS connection closed") from err
187197
except BLOCKING_IO_ERRORS:
188198
raise socket.timeout("timed out") from None
189199
finally:

test/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -464,11 +464,12 @@ def wrap(*args, **kwargs):
464464
if not self.connected:
465465
pair = self.pair
466466
raise SkipTest(f"Cannot connect to MongoDB on {pair}")
467-
if iscoroutinefunction(condition) and condition():
468-
if wraps_async:
469-
return f(*args, **kwargs)
470-
else:
471-
return f(*args, **kwargs)
467+
if iscoroutinefunction(condition):
468+
if condition():
469+
if wraps_async:
470+
return f(*args, **kwargs)
471+
else:
472+
return f(*args, **kwargs)
472473
elif condition():
473474
if wraps_async:
474475
return f(*args, **kwargs)

test/asynchronous/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -466,11 +466,12 @@ async def wrap(*args, **kwargs):
466466
if not self.connected:
467467
pair = await self.pair
468468
raise SkipTest(f"Cannot connect to MongoDB on {pair}")
469-
if iscoroutinefunction(condition) and await condition():
470-
if wraps_async:
471-
return await f(*args, **kwargs)
472-
else:
473-
return f(*args, **kwargs)
469+
if iscoroutinefunction(condition):
470+
if await condition():
471+
if wraps_async:
472+
return await f(*args, **kwargs)
473+
else:
474+
return f(*args, **kwargs)
474475
elif condition():
475476
if wraps_async:
476477
return await f(*args, **kwargs)

0 commit comments

Comments
 (0)