Skip to content

Commit 861dbb5

Browse files
committed
Address review
1 parent a0a85c5 commit 861dbb5

File tree

6 files changed

+52
-24
lines changed

6 files changed

+52
-24
lines changed

pymongo/asynchronous/mongo_client.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1566,8 +1566,13 @@ async def close(self) -> None:
15661566
await self._encrypter.close()
15671567
self._closed = True
15681568
if not _IS_SYNC:
1569-
self._topology._monitor_tasks.append(self._kill_cursors_executor) # type: ignore[arg-type]
1570-
join_tasks = [t.join() for t in self._topology._monitor_tasks] # type: ignore[func-returns-value]
1569+
join_tasks = [self._kill_cursors_executor]
1570+
try:
1571+
while self._topology._monitor_tasks:
1572+
join_tasks.append(self._topology._monitor_tasks.pop())
1573+
except IndexError:
1574+
pass
1575+
join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value]
15711576
await asyncio.gather(*join_tasks)
15721577

15731578
if not _IS_SYNC:

pymongo/asynchronous/monitor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,7 @@ def gc_safe_close(self) -> None:
190190
self.cancel_check()
191191

192192
async def join(self) -> None:
193-
await self._executor.join()
194-
await self._rtt_monitor.join()
193+
await asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()])
195194

196195
async def close(self) -> None:
197196
self.gc_safe_close()

pymongo/asynchronous/topology.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,8 @@ async def open(self) -> None:
245245
# Close servers and clear the pools.
246246
for server in self._servers.values():
247247
await server.close()
248-
self._monitor_tasks.append(server._monitor)
248+
if not _IS_SYNC:
249+
self._monitor_tasks.append(server._monitor)
249250
# Reset the session pool to avoid duplicate sessions in
250251
# the child process.
251252
self._session_pool.reset()
@@ -298,9 +299,14 @@ async def select_servers(
298299
]
299300

300301
if not _IS_SYNC and self._monitor_tasks:
301-
joins = [t.join() for t in self._monitor_tasks] # type: ignore[func-returns-value]
302-
await asyncio.gather(*joins)
303-
self._monitor_tasks = []
302+
join_tasks = []
303+
try:
304+
while self._monitor_tasks:
305+
join_tasks.append(self._monitor_tasks.pop())
306+
except IndexError:
307+
pass
308+
join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value]
309+
await asyncio.gather(*join_tasks)
304310

305311
return servers
306312

@@ -532,7 +538,8 @@ async def _process_change(
532538
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
533539
):
534540
await self._srv_monitor.close()
535-
self._monitor_tasks.append(self._srv_monitor)
541+
if not _IS_SYNC:
542+
self._monitor_tasks.append(self._srv_monitor)
536543

537544
# Clear the pool from a failed heartbeat.
538545
if reset_pool:
@@ -708,7 +715,8 @@ async def close(self) -> None:
708715
old_td = self._description
709716
for server in self._servers.values():
710717
await server.close()
711-
self._monitor_tasks.append(server._monitor)
718+
if not _IS_SYNC:
719+
self._monitor_tasks.append(server._monitor)
712720

713721
# Mark all servers Unknown.
714722
self._description = self._description.reset()
@@ -719,7 +727,8 @@ async def close(self) -> None:
719727
# Stop SRV polling thread.
720728
if self._srv_monitor:
721729
await self._srv_monitor.close()
722-
self._monitor_tasks.append(self._srv_monitor)
730+
if not _IS_SYNC:
731+
self._monitor_tasks.append(self._srv_monitor)
723732

724733
self._opened = False
725734
self._closed = True
@@ -959,7 +968,8 @@ async def _update_servers(self) -> None:
959968
for address, server in list(self._servers.items()):
960969
if not self._description.has_server(address):
961970
await server.close()
962-
self._monitor_tasks.append(server._monitor)
971+
if not _IS_SYNC:
972+
self._monitor_tasks.append(server._monitor)
963973
self._servers.pop(address)
964974

965975
def _create_pool_for_server(self, address: _Address) -> Pool:

pymongo/synchronous/mongo_client.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1560,8 +1560,13 @@ def close(self) -> None:
15601560
self._encrypter.close()
15611561
self._closed = True
15621562
if not _IS_SYNC:
1563-
self._topology._monitor_tasks.append(self._kill_cursors_executor) # type: ignore[arg-type]
1564-
join_tasks = [t.join() for t in self._topology._monitor_tasks] # type: ignore[func-returns-value]
1563+
join_tasks = [self._kill_cursors_executor]
1564+
try:
1565+
while self._topology._monitor_tasks:
1566+
join_tasks.append(self._topology._monitor_tasks.pop())
1567+
except IndexError:
1568+
pass
1569+
join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value]
15651570
asyncio.gather(*join_tasks)
15661571

15671572
if not _IS_SYNC:

pymongo/synchronous/monitor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,8 +190,7 @@ def gc_safe_close(self) -> None:
190190
self.cancel_check()
191191

192192
def join(self) -> None:
193-
self._executor.join()
194-
self._rtt_monitor.join()
193+
asyncio.gather(*[self._executor.join(), self._rtt_monitor.join()])
195194

196195
def close(self) -> None:
197196
self.gc_safe_close()

pymongo/synchronous/topology.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,8 @@ def open(self) -> None:
245245
# Close servers and clear the pools.
246246
for server in self._servers.values():
247247
server.close()
248-
self._monitor_tasks.append(server._monitor)
248+
if not _IS_SYNC:
249+
self._monitor_tasks.append(server._monitor)
249250
# Reset the session pool to avoid duplicate sessions in
250251
# the child process.
251252
self._session_pool.reset()
@@ -298,9 +299,14 @@ def select_servers(
298299
]
299300

300301
if not _IS_SYNC and self._monitor_tasks:
301-
joins = [t.join() for t in self._monitor_tasks] # type: ignore[func-returns-value]
302-
asyncio.gather(*joins)
303-
self._monitor_tasks = []
302+
join_tasks = []
303+
try:
304+
while self._monitor_tasks:
305+
join_tasks.append(self._monitor_tasks.pop())
306+
except IndexError:
307+
pass
308+
join_tasks = [t.join() for t in join_tasks] # type: ignore[func-returns-value]
309+
asyncio.gather(*join_tasks)
304310

305311
return servers
306312

@@ -532,7 +538,8 @@ def _process_change(
532538
and self._description.topology_type not in SRV_POLLING_TOPOLOGIES
533539
):
534540
self._srv_monitor.close()
535-
self._monitor_tasks.append(self._srv_monitor)
541+
if not _IS_SYNC:
542+
self._monitor_tasks.append(self._srv_monitor)
536543

537544
# Clear the pool from a failed heartbeat.
538545
if reset_pool:
@@ -706,7 +713,8 @@ def close(self) -> None:
706713
old_td = self._description
707714
for server in self._servers.values():
708715
server.close()
709-
self._monitor_tasks.append(server._monitor)
716+
if not _IS_SYNC:
717+
self._monitor_tasks.append(server._monitor)
710718

711719
# Mark all servers Unknown.
712720
self._description = self._description.reset()
@@ -717,7 +725,8 @@ def close(self) -> None:
717725
# Stop SRV polling thread.
718726
if self._srv_monitor:
719727
self._srv_monitor.close()
720-
self._monitor_tasks.append(self._srv_monitor)
728+
if not _IS_SYNC:
729+
self._monitor_tasks.append(self._srv_monitor)
721730

722731
self._opened = False
723732
self._closed = True
@@ -957,7 +966,8 @@ def _update_servers(self) -> None:
957966
for address, server in list(self._servers.items()):
958967
if not self._description.has_server(address):
959968
server.close()
960-
self._monitor_tasks.append(server._monitor)
969+
if not _IS_SYNC:
970+
self._monitor_tasks.append(server._monitor)
961971
self._servers.pop(address)
962972

963973
def _create_pool_for_server(self, address: _Address) -> Pool:

0 commit comments

Comments
 (0)