Skip to content
12 changes: 12 additions & 0 deletions pymongo/asynchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -204,6 +205,7 @@ def __init__(
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> None:
if snapshot:
if causal_consistency:
Expand All @@ -222,6 +224,7 @@ def __init__(
)
self._default_transaction_options = default_transaction_options
self._snapshot = snapshot
self._bind = bind
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have to bind/unbind the session in ClientSession.__enter__/__exit__. That way the stack of sessions is managed correctly (ie we call _SESSION.reset(token)). Think about how nested cases will work:

session1 = client.start_session(bind=True)
with session1:
    session2 = client.start_session(bind=True)
    with session2:
        coll.find_one() # uses session2
    coll.find_one() # uses session1
coll.find_one() # uses implicit session

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 3c68a70


@property
def causal_consistency(self) -> bool:
Expand Down Expand Up @@ -514,6 +517,7 @@ def __init__(
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None, client)
self._token = None

async def end_session(self) -> None:
"""Finish this session. If a transaction has started, abort it.
Expand Down Expand Up @@ -545,9 +549,14 @@ def _check_ended(self) -> None:
raise InvalidOperation("Cannot use ended session")

async def __aenter__(self) -> AsyncClientSession:
if self._options._bind:
self._token = _SESSION.set(self)
return self

async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._token:
_SESSION.reset(self._token)
self._token = None
await self._end_session(lock=True)

@property
Expand Down Expand Up @@ -1065,6 +1074,9 @@ def __copy__(self) -> NoReturn:
raise TypeError("A AsyncClientSession cannot be copied, create a new session instead")


_SESSION: ContextVar[Optional[AsyncClientSession]] = ContextVar("SESSION", default=None)


class _EmptyServerSession:
__slots__ = "dirty", "started_retryable_write"

Expand Down
7 changes: 7 additions & 0 deletions pymongo/asynchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,16 @@ def __init__(
self._killed = False
self._session: Optional[AsyncClientSession]

from .client_session import _SESSION

bound_session = _SESSION.get()

if session:
self._session = session
self._explicit_session = True
elif bound_session:
self._session = bound_session
self._explicit_session = True
else:
self._session = None
self._explicit_session = False
Expand Down
2 changes: 2 additions & 0 deletions pymongo/asynchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1362,6 +1362,7 @@ def start_session(
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[client_session.TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> client_session.AsyncClientSession:
"""Start a logical session.

Expand All @@ -1384,6 +1385,7 @@ def start_session(
causal_consistency=causal_consistency,
default_transaction_options=default_transaction_options,
snapshot=snapshot,
bind=bind,
)

def _ensure_session(
Expand Down
12 changes: 12 additions & 0 deletions pymongo/synchronous/client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
import time
import uuid
from collections.abc import Mapping as _Mapping
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -203,6 +204,7 @@ def __init__(
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> None:
if snapshot:
if causal_consistency:
Expand All @@ -221,6 +223,7 @@ def __init__(
)
self._default_transaction_options = default_transaction_options
self._snapshot = snapshot
self._bind = bind

@property
def causal_consistency(self) -> bool:
Expand Down Expand Up @@ -513,6 +516,7 @@ def __init__(
# Is this an implicitly created session?
self._implicit = implicit
self._transaction = _Transaction(None, client)
self._token = None

def end_session(self) -> None:
"""Finish this session. If a transaction has started, abort it.
Expand Down Expand Up @@ -544,9 +548,14 @@ def _check_ended(self) -> None:
raise InvalidOperation("Cannot use ended session")

def __enter__(self) -> ClientSession:
if self._options._bind:
self._token = _SESSION.set(self)
return self

def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
if self._token:
_SESSION.reset(self._token)
self._token = None
self._end_session(lock=True)

@property
Expand Down Expand Up @@ -1060,6 +1069,9 @@ def __copy__(self) -> NoReturn:
raise TypeError("A ClientSession cannot be copied, create a new session instead")


_SESSION: ContextVar[Optional[ClientSession]] = ContextVar("SESSION", default=None)


class _EmptyServerSession:
__slots__ = "dirty", "started_retryable_write"

Expand Down
7 changes: 7 additions & 0 deletions pymongo/synchronous/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,16 @@ def __init__(
self._killed = False
self._session: Optional[ClientSession]

from .client_session import _SESSION

bound_session = _SESSION.get()

if session:
self._session = session
self._explicit_session = True
elif bound_session:
self._session = bound_session
self._explicit_session = True
else:
self._session = None
self._explicit_session = False
Expand Down
2 changes: 2 additions & 0 deletions pymongo/synchronous/mongo_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1360,6 +1360,7 @@ def start_session(
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[client_session.TransactionOptions] = None,
snapshot: Optional[bool] = False,
bind: Optional[bool] = False,
) -> client_session.ClientSession:
"""Start a logical session.

Expand All @@ -1382,6 +1383,7 @@ def start_session(
causal_consistency=causal_consistency,
default_transaction_options=default_transaction_options,
snapshot=snapshot,
bind=bind,
)

def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]:
Expand Down
24 changes: 24 additions & 0 deletions test/asynchronous/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,30 @@ async def test_cursor_clone(self):
await cursor.close()
await clone.close()

async def test_bind_session(self):
coll = self.client.pymongo_test.collection

# Explicit session via context variable.
async with self.client.start_session(bind=True) as s:
cursor = coll.find()
self.assertTrue(cursor.session is s)

# Nested sessions.

async def get_cursor(collection):
return collection.find()

session1 = self.client.start_session(bind=True)
async with session1:
session2 = self.client.start_session(bind=True)
async with session2:
cursor = await get_cursor(coll) # uses session2
self.assertEqual(cursor.session, session2)
cursor = await get_cursor(coll) # uses session1
self.assertEqual(cursor.session, session1)
cursor = await get_cursor(coll) # uses implicit session
self.assertEqual(cursor.session, None)

async def test_cursor(self):
listener = self.listener
client = self.client
Expand Down
24 changes: 24 additions & 0 deletions test/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,6 +394,30 @@ def test_cursor_clone(self):
cursor.close()
clone.close()

def test_bind_session(self):
coll = self.client.pymongo_test.collection

# Explicit session via context variable.
with self.client.start_session(bind=True) as s:
cursor = coll.find()
self.assertTrue(cursor.session is s)

# Nested sessions.

def get_cursor(collection):
return collection.find()

session1 = self.client.start_session(bind=True)
with session1:
session2 = self.client.start_session(bind=True)
with session2:
cursor = get_cursor(coll) # uses session2
self.assertEqual(cursor.session, session2)
cursor = get_cursor(coll) # uses session1
self.assertEqual(cursor.session, session1)
cursor = get_cursor(coll) # uses implicit session
self.assertEqual(cursor.session, None)

def test_cursor(self):
listener = self.listener
client = self.client
Expand Down
Loading