Skip to content

Commit ebb1d62

Browse files
committed
PYTHON-4542 Improved sessions API
- Via context variable.
1 parent d0b0dc3 commit ebb1d62

File tree

8 files changed

+54
-6
lines changed

8 files changed

+54
-6
lines changed

pymongo/asynchronous/client_session.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@
139139
import time
140140
import uuid
141141
from collections.abc import Mapping as _Mapping
142+
from contextvars import ContextVar
142143
from typing import (
143144
TYPE_CHECKING,
144145
Any,
@@ -204,6 +205,7 @@ def __init__(
204205
causal_consistency: Optional[bool] = None,
205206
default_transaction_options: Optional[TransactionOptions] = None,
206207
snapshot: Optional[bool] = False,
208+
bind: Optional[bool] = False,
207209
) -> None:
208210
if snapshot:
209211
if causal_consistency:
@@ -222,6 +224,7 @@ def __init__(
222224
)
223225
self._default_transaction_options = default_transaction_options
224226
self._snapshot = snapshot
227+
self._bind = bind
225228

226229
@property
227230
def causal_consistency(self) -> bool:
@@ -1065,6 +1068,9 @@ def __copy__(self) -> NoReturn:
10651068
raise TypeError("A AsyncClientSession cannot be copied, create a new session instead")
10661069

10671070

1071+
SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None)
1072+
1073+
10681074
class _EmptyServerSession:
10691075
__slots__ = "dirty", "started_retryable_write"
10701076

pymongo/asynchronous/cursor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@
6363
from _typeshed import SupportsItems
6464

6565
from bson.codec_options import CodecOptions
66-
from pymongo.asynchronous.client_session import AsyncClientSession
66+
from pymongo.asynchronous.client_session import SESSION, AsyncClientSession
6767
from pymongo.asynchronous.collection import AsyncCollection
6868
from pymongo.asynchronous.pool import AsyncConnection
6969
from pymongo.read_preferences import _ServerMode
@@ -136,9 +136,14 @@ def __init__(
136136
self._killed = False
137137
self._session: Optional[AsyncClientSession]
138138

139+
_SESSION = SESSION.get()
140+
139141
if session:
140142
self._session = session
141143
self._explicit_session = True
144+
elif _SESSION:
145+
self._session = _SESSION
146+
self._explicit_session = True
142147
else:
143148
self._session = None
144149
self._explicit_session = False

pymongo/asynchronous/mongo_client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
from pymongo.asynchronous import client_session, database, uri_parser
6666
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
6767
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
68-
from pymongo.asynchronous.client_session import _EmptyServerSession
68+
from pymongo.asynchronous.client_session import SESSION, _EmptyServerSession
6969
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
7070
from pymongo.asynchronous.settings import TopologySettings
7171
from pymongo.asynchronous.topology import Topology, _ErrorContext
@@ -1355,13 +1355,18 @@ def _close_cursor_soon(
13551355
def _start_session(self, implicit: bool, **kwargs: Any) -> AsyncClientSession:
13561356
server_session = _EmptyServerSession()
13571357
opts = client_session.SessionOptions(**kwargs)
1358-
return client_session.AsyncClientSession(self, server_session, opts, implicit)
1358+
bind = opts._bind
1359+
session = client_session.AsyncClientSession(self, server_session, opts, implicit)
1360+
if bind:
1361+
SESSION.set(session)
1362+
return session
13591363

13601364
def start_session(
13611365
self,
13621366
causal_consistency: Optional[bool] = None,
13631367
default_transaction_options: Optional[client_session.TransactionOptions] = None,
13641368
snapshot: Optional[bool] = False,
1369+
bind: Optional[bool] = False,
13651370
) -> client_session.AsyncClientSession:
13661371
"""Start a logical session.
13671372
@@ -1384,6 +1389,7 @@ def start_session(
13841389
causal_consistency=causal_consistency,
13851390
default_transaction_options=default_transaction_options,
13861391
snapshot=snapshot,
1392+
bind=bind,
13871393
)
13881394

13891395
def _ensure_session(

pymongo/synchronous/client_session.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@
139139
import time
140140
import uuid
141141
from collections.abc import Mapping as _Mapping
142+
from contextvars import ContextVar
142143
from typing import (
143144
TYPE_CHECKING,
144145
Any,
@@ -203,6 +204,7 @@ def __init__(
203204
causal_consistency: Optional[bool] = None,
204205
default_transaction_options: Optional[TransactionOptions] = None,
205206
snapshot: Optional[bool] = False,
207+
bind: Optional[bool] = False,
206208
) -> None:
207209
if snapshot:
208210
if causal_consistency:
@@ -221,6 +223,7 @@ def __init__(
221223
)
222224
self._default_transaction_options = default_transaction_options
223225
self._snapshot = snapshot
226+
self._bind = bind
224227

225228
@property
226229
def causal_consistency(self) -> bool:
@@ -1060,6 +1063,9 @@ def __copy__(self) -> NoReturn:
10601063
raise TypeError("A ClientSession cannot be copied, create a new session instead")
10611064

10621065

1066+
SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None)
1067+
1068+
10631069
class _EmptyServerSession:
10641070
__slots__ = "dirty", "started_retryable_write"
10651071

pymongo/synchronous/cursor.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464

6565
from bson.codec_options import CodecOptions
6666
from pymongo.read_preferences import _ServerMode
67-
from pymongo.synchronous.client_session import ClientSession
67+
from pymongo.synchronous.client_session import SESSION, ClientSession
6868
from pymongo.synchronous.collection import Collection
6969
from pymongo.synchronous.pool import Connection
7070

@@ -136,9 +136,14 @@ def __init__(
136136
self._killed = False
137137
self._session: Optional[ClientSession]
138138

139+
_SESSION = SESSION.get()
140+
139141
if session:
140142
self._session = session
141143
self._explicit_session = True
144+
elif _SESSION:
145+
self._session = _SESSION
146+
self._explicit_session = True
142147
else:
143148
self._session = None
144149
self._explicit_session = False

pymongo/synchronous/mongo_client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@
107107
from pymongo.synchronous import client_session, database, uri_parser
108108
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
109109
from pymongo.synchronous.client_bulk import _ClientBulk
110-
from pymongo.synchronous.client_session import _EmptyServerSession
110+
from pymongo.synchronous.client_session import SESSION, _EmptyServerSession
111111
from pymongo.synchronous.command_cursor import CommandCursor
112112
from pymongo.synchronous.settings import TopologySettings
113113
from pymongo.synchronous.topology import Topology, _ErrorContext
@@ -1353,13 +1353,18 @@ def _close_cursor_soon(
13531353
def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession:
13541354
server_session = _EmptyServerSession()
13551355
opts = client_session.SessionOptions(**kwargs)
1356-
return client_session.ClientSession(self, server_session, opts, implicit)
1356+
bind = opts._bind
1357+
session = client_session.ClientSession(self, server_session, opts, implicit)
1358+
if bind:
1359+
SESSION.set(session)
1360+
return session
13571361

13581362
def start_session(
13591363
self,
13601364
causal_consistency: Optional[bool] = None,
13611365
default_transaction_options: Optional[client_session.TransactionOptions] = None,
13621366
snapshot: Optional[bool] = False,
1367+
bind: Optional[bool] = False,
13631368
) -> client_session.ClientSession:
13641369
"""Start a logical session.
13651370
@@ -1382,6 +1387,7 @@ def start_session(
13821387
causal_consistency=causal_consistency,
13831388
default_transaction_options=default_transaction_options,
13841389
snapshot=snapshot,
1390+
bind=bind,
13851391
)
13861392

13871393
def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]:

test/asynchronous/test_session.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,13 @@ async def test_cursor_clone(self):
380380
clone = cursor.clone()
381381
self.assertTrue(clone.session is s)
382382

383+
# Explicit session via context variable.
384+
async with self.client.start_session(bind=True) as s:
385+
cursor = coll.find()
386+
self.assertTrue(cursor.session is s)
387+
clone = cursor.clone()
388+
self.assertTrue(clone.session is s)
389+
383390
# No explicit session.
384391
cursor = coll.find(batch_size=2)
385392
await anext(cursor)

test/test_session.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,13 @@ def test_cursor_clone(self):
380380
clone = cursor.clone()
381381
self.assertTrue(clone.session is s)
382382

383+
# Explicit session via context variable.
384+
with self.client.start_session(bind=True) as s:
385+
cursor = coll.find()
386+
self.assertTrue(cursor.session is s)
387+
clone = cursor.clone()
388+
self.assertTrue(clone.session is s)
389+
383390
# No explicit session.
384391
cursor = coll.find(batch_size=2)
385392
next(cursor)

0 commit comments

Comments
 (0)