Skip to content

Commit 2ac4ea3

Browse files
committed
Address review
1 parent 861dbb5 commit 2ac4ea3

File tree

6 files changed

+36
-44
lines changed

6 files changed

+36
-44
lines changed

pymongo/asynchronous/mongo_client.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,14 +1566,9 @@ async def close(self) -> None:
15661566
await self._encrypter.close()
15671567
self._closed = True
15681568
if not _IS_SYNC:
1569-
join_tasks = [self._kill_cursors_executor]
1570-
try:
1571-
while self._topology._monitor_tasks:
1572-
join_tasks.append(self._topology._monitor_tasks.pop())
1573-
except IndexError:
1574-
pass
1575-
join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value]
1576-
await asyncio.gather(*join_tasks)
1569+
await asyncio.gather(
1570+
*[self._topology.cleanup_monitors(), self._kill_cursors_executor.join()] # type: ignore[func-returns-value]
1571+
)
15771572

15781573
if not _IS_SYNC:
15791574
# Add support for contextlib.aclosing.

pymongo/asynchronous/monitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def gc_safe_close(self) -> None:
190190
self.cancel_check()
191191

192192
async def join(self) -> None:
193-
await asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()])
193+
await asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()]) # type: ignore[func-returns-value]
194194

195195
async def close(self) -> None:
196196
self.gc_safe_close()

pymongo/asynchronous/topology.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -289,27 +289,19 @@ async def select_servers(
289289
else:
290290
server_timeout = server_selection_timeout
291291

292+
# Cleanup any completed monitor tasks safely
293+
if not _IS_SYNC and self._monitor_tasks:
294+
await self.cleanup_monitors()
295+
292296
async with self._lock:
293297
server_descriptions = await self._select_servers_loop(
294298
selector, server_timeout, operation, operation_id, address
295299
)
296300

297-
servers = [
301+
return [
298302
cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions
299303
]
300304

301-
if not _IS_SYNC and self._monitor_tasks:
302-
join_tasks = []
303-
try:
304-
while self._monitor_tasks:
305-
join_tasks.append(self._monitor_tasks.pop())
306-
except IndexError:
307-
pass
308-
join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value]
309-
await asyncio.gather(*join_tasks)
310-
311-
return servers
312-
313305
async def _select_servers_loop(
314306
self,
315307
selector: Callable[[Selection], Selection],
@@ -1057,6 +1049,15 @@ def _error_message(self, selector: Callable[[Selection], Selection]) -> str:
10571049
else:
10581050
return ",".join(str(server.error) for server in servers if server.error)
10591051

1052+
async def cleanup_monitors(self) -> None:
1053+
tasks = []
1054+
try:
1055+
while self._monitor_tasks:
1056+
tasks.append(self._monitor_tasks.pop())
1057+
except IndexError:
1058+
pass
1059+
await asyncio.gather(*[t.join() for t in tasks]) # type: ignore[func-returns-value]
1060+
10601061
def __repr__(self) -> str:
10611062
msg = ""
10621063
if not self._opened:

pymongo/synchronous/mongo_client.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,14 +1560,9 @@ def close(self) -> None:
15601560
self._encrypter.close()
15611561
self._closed = True
15621562
if not _IS_SYNC:
1563-
join_tasks = [self._kill_cursors_executor]
1564-
try:
1565-
while self._topology._monitor_tasks:
1566-
join_tasks.append(self._topology._monitor_tasks.pop())
1567-
except IndexError:
1568-
pass
1569-
join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value]
1570-
asyncio.gather(*join_tasks)
1563+
asyncio.gather(
1564+
*[self._topology.cleanup_monitors(), self._kill_cursors_executor.join()] # type: ignore[func-returns-value]
1565+
)
15711566

15721567
if not _IS_SYNC:
15731568
# Add support for contextlib.closing.

pymongo/synchronous/monitor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def gc_safe_close(self) -> None:
190190
self.cancel_check()
191191

192192
def join(self) -> None:
193-
asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()])
193+
asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()]) # type: ignore[func-returns-value]
194194

195195
def close(self) -> None:
196196
self.gc_safe_close()

pymongo/synchronous/topology.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -289,27 +289,19 @@ def select_servers(
289289
else:
290290
server_timeout = server_selection_timeout
291291

292+
# Cleanup any completed monitor tasks safely
293+
if not _IS_SYNC and self._monitor_tasks:
294+
self.cleanup_monitors()
295+
292296
with self._lock:
293297
server_descriptions = self._select_servers_loop(
294298
selector, server_timeout, operation, operation_id, address
295299
)
296300

297-
servers = [
301+
return [
298302
cast(Server, self.get_server_by_address(sd.address)) for sd in server_descriptions
299303
]
300304

301-
if not _IS_SYNC and self._monitor_tasks:
302-
join_tasks = []
303-
try:
304-
while self._monitor_tasks:
305-
join_tasks.append(self._monitor_tasks.pop())
306-
except IndexError:
307-
pass
308-
join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value]
309-
asyncio.gather(*join_tasks)
310-
311-
return servers
312-
313305
def _select_servers_loop(
314306
self,
315307
selector: Callable[[Selection], Selection],
@@ -1055,6 +1047,15 @@ def _error_message(self, selector: Callable[[Selection], Selection]) -> str:
10551047
else:
10561048
return ",".join(str(server.error) for server in servers if server.error)
10571049

1050+
def cleanup_monitors(self) -> None:
1051+
tasks = []
1052+
try:
1053+
while self._monitor_tasks:
1054+
tasks.append(self._monitor_tasks.pop())
1055+
except IndexError:
1056+
pass
1057+
asyncio.gather(*[t.join() for t in tasks]) # type: ignore[func-returns-value]
1058+
10581059
def __repr__(self) -> str:
10591060
msg = ""
10601061
if not self._opened:

0 commit comments

Comments
 (0)