Skip to content
6 changes: 6 additions & 0 deletions pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -204,6 +205,7 @@ def __init__(
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> None:
if snapshot:
if causal_consistency:
Expand All @@ -222,6 +224,7 @@ def __init__(
)
self._default_transaction_options = default_transaction_options
self._snapshot = snapshot
self._bind = bind
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to bind/unbind the session in ClientSession.__enter__/__exit__. That way the stack of sessions is managed correctly (ie we call _SESSION.reset(token)). Think about how nested cases will work:

session1 = client.start_session(bind=True)
with session1:
    session2 = client.start_session(bind=True)
    with session2:
        coll.find_one() # uses session2
    coll.find_one() # uses session1
coll.find_one() # uses implicit session

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 3c68a70


@property
def causal_consistency(self) -> bool:
Expand Down Expand Up @@ -1065,6 +1068,9 @@ def __copy__(self) -> NoReturn:
raise TypeError("A AsyncClientSession cannot be copied, create a new session instead")


SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None)


class _EmptyServerSession:
__slots__ = "dirty", "started_retryable_write"

Expand Down
7 changes: 6 additions & 1 deletion pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
from _typeshed import SupportsItems

from bson.codec_options import CodecOptions
from pymongo.asynchronous.client_session import AsyncClientSession
from pymongo.asynchronous.client_session import SESSION, AsyncClientSession
from pymongo.asynchronous.collection import AsyncCollection
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.read_preferences import _ServerMode
Expand Down Expand Up @@ -136,9 +136,14 @@ def __init__(
self._killed = False
self._session: Optional[AsyncClientSession]

_SESSION = SESSION.get()

if session:
self._session = session
self._explicit_session = True
elif _SESSION:
self._session = _SESSION
self._explicit_session = True
else:
self._session = None
self._explicit_session = False
Expand Down
10 changes: 8 additions & 2 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from pymongo.asynchronous import client_session, database, uri_parser
from pymongo.asynchronous.change_stream import AsyncChangeStream, AsyncClusterChangeStream
from pymongo.asynchronous.client_bulk import _AsyncClientBulk
from pymongo.asynchronous.client_session import _EmptyServerSession
from pymongo.asynchronous.client_session import SESSION, _EmptyServerSession
from pymongo.asynchronous.command_cursor import AsyncCommandCursor
from pymongo.asynchronous.settings import TopologySettings
from pymongo.asynchronous.topology import Topology, _ErrorContext
Expand Down Expand Up @@ -1355,13 +1355,18 @@ def _close_cursor_soon(
def _start_session(self, implicit: bool, **kwargs: Any) -> AsyncClientSession:
server_session = _EmptyServerSession()
opts = client_session.SessionOptions(**kwargs)
return client_session.AsyncClientSession(self, server_session, opts, implicit)
bind = opts._bind
session = client_session.AsyncClientSession(self, server_session, opts, implicit)
if bind:
SESSION.set(session)
return session

def start_session(
self,
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[client_session.TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> client_session.AsyncClientSession:
"""Start a logical session.

Expand All @@ -1384,6 +1389,7 @@ def start_session(
causal_consistency=causal_consistency,
default_transaction_options=default_transaction_options,
snapshot=snapshot,
bind=bind,
)

def _ensure_session(
Expand Down
6 changes: 6 additions & 0 deletions pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -203,6 +204,7 @@ def __init__(
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> None:
if snapshot:
if causal_consistency:
Expand All @@ -221,6 +223,7 @@ def __init__(
)
self._default_transaction_options = default_transaction_options
self._snapshot = snapshot
self._bind = bind

@property
def causal_consistency(self) -> bool:
Expand Down Expand Up @@ -1060,6 +1063,9 @@ def __copy__(self) -> NoReturn:
raise TypeError("A ClientSession cannot be copied, create a new session instead")


SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None)


class _EmptyServerSession:
__slots__ = "dirty", "started_retryable_write"

Expand Down
7 changes: 6 additions & 1 deletion pymongo/synchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@

from bson.codec_options import CodecOptions
from pymongo.read_preferences import _ServerMode
from pymongo.synchronous.client_session import ClientSession
from pymongo.synchronous.client_session import SESSION, ClientSession
from pymongo.synchronous.collection import Collection
from pymongo.synchronous.pool import Connection

Expand Down Expand Up @@ -136,9 +136,14 @@ def __init__(
self._killed = False
self._session: Optional[ClientSession]

_SESSION = SESSION.get()

if session:
self._session = session
self._explicit_session = True
elif _SESSION:
self._session = _SESSION
self._explicit_session = True
else:
self._session = None
self._explicit_session = False
Expand Down
10 changes: 8 additions & 2 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
from pymongo.synchronous import client_session, database, uri_parser
from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream
from pymongo.synchronous.client_bulk import _ClientBulk
from pymongo.synchronous.client_session import _EmptyServerSession
from pymongo.synchronous.client_session import SESSION, _EmptyServerSession
from pymongo.synchronous.command_cursor import CommandCursor
from pymongo.synchronous.settings import TopologySettings
from pymongo.synchronous.topology import Topology, _ErrorContext
Expand Down Expand Up @@ -1353,13 +1353,18 @@ def _close_cursor_soon(
def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession:
server_session = _EmptyServerSession()
opts = client_session.SessionOptions(**kwargs)
return client_session.ClientSession(self, server_session, opts, implicit)
bind = opts._bind
session = client_session.ClientSession(self, server_session, opts, implicit)
if bind:
SESSION.set(session)
return session

def start_session(
self,
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[client_session.TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> client_session.ClientSession:
"""Start a logical session.

Expand All @@ -1382,6 +1387,7 @@ def start_session(
causal_consistency=causal_consistency,
default_transaction_options=default_transaction_options,
snapshot=snapshot,
bind=bind,
)

def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]:
Expand Down
7 changes: 7 additions & 0 deletions test/asynchronous/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,13 @@ async def test_cursor_clone(self):
clone = cursor.clone()
self.assertTrue(clone.session is s)

# Explicit session via context variable.
async with self.client.start_session(bind=True) as s:
cursor = coll.find()
self.assertTrue(cursor.session is s)
clone = cursor.clone()
self.assertTrue(clone.session is s)

# No explicit session.
cursor = coll.find(batch_size=2)
await anext(cursor)
Expand Down
7 changes: 7 additions & 0 deletions test/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,13 @@ def test_cursor_clone(self):
clone = cursor.clone()
self.assertTrue(clone.session is s)

# Explicit session via context variable.
with self.client.start_session(bind=True) as s:
cursor = coll.find()
self.assertTrue(cursor.session is s)
clone = cursor.clone()
self.assertTrue(clone.session is s)

# No explicit session.
cursor = coll.find(batch_size=2)
next(cursor)
Expand Down
Loading