Skip to content

Commit 1270d90

Browse files
feat: add context manager support to Connection class (#1320)
Add __enter__ and __exit__ methods to Connection for use with Python's `with` statement. This enables automatic connection cleanup, particularly useful for serverless environments (AWS Lambda, Cloud Functions). Usage: with dj.Connection(host, user, password) as conn: schema = dj.schema('my_schema', connection=conn) # perform operations # connection automatically closed Closes #1081 Co-authored-by: Claude Opus 4.5 <[email protected]>
1 parent 1673be8 commit 1270d90

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

src/datajoint/connection.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,45 @@ def close(self) -> None:
287287
"""Close the database connection."""
288288
self._conn.close()
289289

290+
def __enter__(self) -> "Connection":
291+
"""
292+
Enter context manager.
293+
294+
Returns
295+
-------
296+
Connection
297+
This connection object.
298+
299+
Examples
300+
--------
301+
>>> with dj.Connection(host, user, password) as conn:
302+
... schema = dj.schema('my_schema', connection=conn)
303+
... # perform operations
304+
... # connection automatically closed
305+
"""
306+
return self
307+
308+
def __exit__(self, exc_type, exc_val, exc_tb) -> bool:
309+
"""
310+
Exit context manager and close connection.
311+
312+
Parameters
313+
----------
314+
exc_type : type or None
315+
Exception type if an exception was raised.
316+
exc_val : Exception or None
317+
Exception instance if an exception was raised.
318+
exc_tb : traceback or None
319+
Traceback if an exception was raised.
320+
321+
Returns
322+
-------
323+
bool
324+
False to propagate exceptions.
325+
"""
326+
self.close()
327+
return False
328+
290329
def register(self, schema) -> None:
291330
"""
292331
Register a schema with this connection.

tests/integration/test_connection.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,36 @@ def test_dj_connection_class(connection_test):
4646
assert connection_test.is_connected
4747

4848

49+
def test_connection_context_manager(db_creds_test):
50+
"""
51+
Connection should support context manager protocol for automatic cleanup.
52+
"""
53+
# Test basic context manager usage
54+
with dj.Connection(**db_creds_test) as conn:
55+
assert conn.is_connected
56+
# Verify we can use the connection
57+
result = conn.query("SELECT 1").fetchone()
58+
assert result[0] == 1
59+
60+
# Connection should be closed after exiting context
61+
assert not conn.is_connected
62+
63+
64+
def test_connection_context_manager_exception(db_creds_test):
65+
"""
66+
Connection should close even when exception is raised inside context.
67+
"""
68+
conn = None
69+
with pytest.raises(ValueError):
70+
with dj.Connection(**db_creds_test) as conn:
71+
assert conn.is_connected
72+
raise ValueError("Test exception")
73+
74+
# Connection should still be closed after exception
75+
assert conn is not None
76+
assert not conn.is_connected
77+
78+
4979
def test_persistent_dj_conn(db_creds_root):
5080
"""
5181
conn() method should provide persistent connection across calls.

0 commit comments

Comments
 (0)