Skip to content

Commit 615e453

Browse files
committed
<ADD>: add disconnect_all_async
1 parent 7a07dc1 commit 615e453

File tree

2 files changed

+89
-5
lines changed

2 files changed

+89
-5
lines changed

mongoengine/connection.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
"disconnect",
4040
"disconnect_async",
4141
"disconnect_all",
42+
"disconnect_all_async",
4243
"get_connection",
4344
"get_db",
4445
"get_async_db",
@@ -685,11 +686,8 @@ async def disconnect_async(alias=DEFAULT_CONNECTION_NAME):
685686
if connection:
686687
# Only close if this is the last reference to this connection
687688
if all(connection is not c for c in _connections.values()):
688-
if is_async_connection(alias):
689-
# For AsyncMongoClient, we need to call close() method
690-
connection.close()
691-
else:
692-
connection.close()
689+
# AsyncMongoClient.close() is a coroutine, must be awaited
690+
await connection.close()
693691

694692
# Clean up database references
695693
if alias in _dbs:
@@ -708,6 +706,23 @@ async def disconnect_async(alias=DEFAULT_CONNECTION_NAME):
708706
del _connection_types[alias]
709707

710708

709+
async def disconnect_all_async():
710+
"""Close all registered async database connections.
711+
712+
This is the async version of disconnect_all() that properly closes
713+
AsyncMongoClient connections. It will only close async connections,
714+
leaving sync connections untouched.
715+
"""
716+
# Get list of all async connections
717+
async_aliases = [
718+
alias for alias in list(_connections.keys()) if is_async_connection(alias)
719+
]
720+
721+
# Disconnect each async connection
722+
for alias in async_aliases:
723+
await disconnect_async(alias)
724+
725+
711726
def get_async_db(alias=DEFAULT_CONNECTION_NAME, reconnect=False):
712727
"""Get the async database for a given alias.
713728

tests/test_async_connection.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
connect,
88
connect_async,
99
disconnect,
10+
disconnect_all_async,
1011
disconnect_async,
1112
get_async_db,
1213
is_async_connection,
@@ -178,3 +179,71 @@ async def test_reconnect_async_different_settings(self):
178179

179180
# Clean up
180181
await disconnect_async("reconnect_test2")
182+
183+
@pytest.mark.asyncio
184+
async def test_disconnect_all_async(self):
185+
"""Test disconnect_all_async only disconnects async connections."""
186+
# Create mix of sync and async connections
187+
connect(db="sync_db1", alias="sync1")
188+
connect(db="sync_db2", alias="sync2")
189+
await connect_async(db="async_db1", alias="async1")
190+
await connect_async(db="async_db2", alias="async2")
191+
await connect_async(db="async_db3", alias="async3")
192+
193+
# Verify connections exist
194+
assert not is_async_connection("sync1")
195+
assert not is_async_connection("sync2")
196+
assert is_async_connection("async1")
197+
assert is_async_connection("async2")
198+
assert is_async_connection("async3")
199+
200+
from mongoengine.connection import _connections
201+
202+
assert len(_connections) == 5
203+
204+
# Disconnect all async connections
205+
await disconnect_all_async()
206+
207+
# Verify only async connections were disconnected
208+
assert "sync1" in _connections
209+
assert "sync2" in _connections
210+
assert "async1" not in _connections
211+
assert "async2" not in _connections
212+
assert "async3" not in _connections
213+
214+
# Verify sync connections still work
215+
assert not is_async_connection("sync1")
216+
assert not is_async_connection("sync2")
217+
218+
# Clean up remaining sync connections
219+
disconnect("sync1")
220+
disconnect("sync2")
221+
222+
@pytest.mark.asyncio
223+
async def test_disconnect_all_async_empty(self):
224+
"""Test disconnect_all_async when no connections exist."""
225+
# Should not raise any errors
226+
await disconnect_all_async()
227+
228+
@pytest.mark.asyncio
229+
async def test_disconnect_all_async_only_sync(self):
230+
"""Test disconnect_all_async when only sync connections exist."""
231+
# Create only sync connections
232+
connect(db="sync_db1", alias="sync1")
233+
connect(db="sync_db2", alias="sync2")
234+
235+
from mongoengine.connection import _connections
236+
237+
assert len(_connections) == 2
238+
239+
# Disconnect all async (should do nothing)
240+
await disconnect_all_async()
241+
242+
# Verify sync connections still exist
243+
assert len(_connections) == 2
244+
assert "sync1" in _connections
245+
assert "sync2" in _connections
246+
247+
# Clean up
248+
disconnect("sync1")
249+
disconnect("sync2")

0 commit comments

Comments
 (0)