Skip to content

Commit b737b84

Browse files
authored
PYTHON-2956 Drivers should check out an implicit session only after checking out a connection (#876)
1 parent 782c551 commit b737b84

File tree

4 files changed

+106
-11
lines changed

4 files changed

+106
-11
lines changed

pymongo/client_session.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -947,9 +947,16 @@ def _txn_read_preference(self):
947947
return self._transaction.opts.read_preference
948948
return None
949949

950+
def _materialize(self):
951+
if isinstance(self._server_session, _EmptyServerSession):
952+
old = self._server_session
953+
self._server_session = self._client._topology.get_server_session()
954+
if old.started_retryable_write:
955+
self._server_session.inc_transaction_id()
956+
950957
def _apply_to(self, command, is_retryable, read_preference, sock_info):
951958
self._check_ended()
952-
959+
self._materialize()
953960
if self.options.snapshot:
954961
self._update_read_concern(command, sock_info)
955962

@@ -1000,6 +1007,20 @@ def __copy__(self):
10001007
raise TypeError("A ClientSession cannot be copied, create a new session instead")
10011008

10021009

1010+
class _EmptyServerSession:
1011+
__slots__ = "dirty", "started_retryable_write"
1012+
1013+
def __init__(self):
1014+
self.dirty = False
1015+
self.started_retryable_write = False
1016+
1017+
def mark_dirty(self):
1018+
self.dirty = True
1019+
1020+
def inc_transaction_id(self):
1021+
self.started_retryable_write = True
1022+
1023+
10031024
class _ServerSession(object):
10041025
def __init__(self, generation):
10051026
# Ensure id is type 4, regardless of CodecOptions.uuid_representation.

pymongo/mongo_client.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
)
6767
from pymongo.change_stream import ChangeStream, ClusterChangeStream
6868
from pymongo.client_options import ClientOptions
69+
from pymongo.client_session import _EmptyServerSession
6970
from pymongo.command_cursor import CommandCursor
7071
from pymongo.errors import (
7172
AutoReconnect,
@@ -1601,7 +1602,11 @@ def _process_periodic_tasks(self):
16011602

16021603
def __start_session(self, implicit, **kwargs):
16031604
# Raises ConfigurationError if sessions are not supported.
1604-
server_session = self._get_server_session()
1605+
if implicit:
1606+
self._topology._check_implicit_session_support()
1607+
server_session = _EmptyServerSession()
1608+
else:
1609+
server_session = self._get_server_session()
16051610
opts = client_session.SessionOptions(**kwargs)
16061611
return client_session.ClientSession(self, server_session, opts, implicit)
16071612

@@ -1641,6 +1646,8 @@ def _get_server_session(self):
16411646

16421647
def _return_server_session(self, server_session, lock):
16431648
"""Internal: return a _ServerSession to the pool."""
1649+
if isinstance(server_session, _EmptyServerSession):
1650+
return
16441651
return self._topology.return_server_session(server_session, lock)
16451652

16461653
def _ensure_session(self, session=None):

pymongo/topology.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -514,8 +514,15 @@ def pop_all_sessions(self):
514514
with self._lock:
515515
return self._session_pool.pop_all()
516516

517+
def _check_implicit_session_support(self):
518+
with self._lock:
519+
self._check_session_support()
520+
517521
def _check_session_support(self):
518-
"""Internal check for session support on non-load balanced clusters."""
522+
"""Internal check for session support on clusters."""
523+
if self._settings.load_balanced:
524+
# Sessions never time out in load balanced mode.
525+
return float("inf")
519526
session_timeout = self._description.logical_session_timeout_minutes
520527
if session_timeout is None:
521528
# Maybe we need an initial scan? Can raise ServerSelectionError.
@@ -537,12 +544,7 @@ def _check_session_support(self):
537544
def get_server_session(self):
538545
"""Start or resume a server session, or raise ConfigurationError."""
539546
with self._lock:
540-
# Sessions are always supported in load balanced mode.
541-
if not self._settings.load_balanced:
542-
session_timeout = self._check_session_support()
543-
else:
544-
# Sessions never time out in load balanced mode.
545-
session_timeout = float("inf")
547+
session_timeout = self._check_session_support()
546548
return self._session_pool.get_server_session(session_timeout)
547549

548550
def return_server_session(self, server_session, lock):

test/test_session.py

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,28 @@
1818
import sys
1919
import time
2020
from io import BytesIO
21-
from typing import Set
21+
from typing import Any, Callable, List, Set, Tuple
2222

2323
from pymongo.mongo_client import MongoClient
2424

2525
sys.path[0:0] = [""]
2626

2727
from test import IntegrationTest, SkipTest, client_context, unittest
28-
from test.utils import EventListener, rs_or_single_client, wait_until
28+
from test.utils import (
29+
EventListener,
30+
ExceptionCatchingThread,
31+
rs_or_single_client,
32+
wait_until,
33+
)
2934

3035
from bson import DBRef
3136
from gridfs import GridFS, GridFSBucket
3237
from pymongo import ASCENDING, IndexModel, InsertOne, monitoring
38+
from pymongo.command_cursor import CommandCursor
3339
from pymongo.common import _MAX_END_SESSIONS
40+
from pymongo.cursor import Cursor
3441
from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
42+
from pymongo.operations import UpdateOne
3543
from pymongo.read_concern import ReadConcern
3644

3745

@@ -171,6 +179,63 @@ def _test_ops(self, client, *ops):
171179
"%s did not return implicit session to pool" % (f.__name__,),
172180
)
173181

182+
def test_implicit_sessions_checkout(self):
183+
# "To confirm that implicit sessions only allocate their server session after a
184+
# successful connection checkout" test from Driver Sessions Spec.
185+
succeeded = False
186+
failures = 0
187+
for _ in range(5):
188+
listener = EventListener()
189+
client = rs_or_single_client(
190+
event_listeners=[listener], maxPoolSize=1, retryWrites=True
191+
)
192+
cursor = client.db.test.find({})
193+
ops: List[Tuple[Callable, List[Any]]] = [
194+
(client.db.test.find_one, [{"_id": 1}]),
195+
(client.db.test.delete_one, [{}]),
196+
(client.db.test.update_one, [{}, {"$set": {"x": 2}}]),
197+
(client.db.test.bulk_write, [[UpdateOne({}, {"$set": {"x": 2}})]]),
198+
(client.db.test.find_one_and_delete, [{}]),
199+
(client.db.test.find_one_and_update, [{}, {"$set": {"x": 1}}]),
200+
(client.db.test.find_one_and_replace, [{}, {}]),
201+
(client.db.test.aggregate, [[{"$limit": 1}]]),
202+
(client.db.test.find, []),
203+
(client.server_info, [{}]),
204+
(client.db.aggregate, [[{"$listLocalSessions": {}}, {"$limit": 1}]]),
205+
(cursor.distinct, ["_id"]),
206+
(client.db.list_collections, []),
207+
]
208+
threads = []
209+
listener.results.clear()
210+
211+
def thread_target(op, *args):
212+
res = op(*args)
213+
if isinstance(res, (Cursor, CommandCursor)):
214+
list(res)
215+
216+
for op, args in ops:
217+
threads.append(
218+
ExceptionCatchingThread(
219+
target=thread_target, args=[op, *args], name=op.__name__
220+
)
221+
)
222+
threads[-1].start()
223+
self.assertEqual(len(threads), len(ops))
224+
for thread in threads:
225+
thread.join()
226+
self.assertIsNone(thread.exc)
227+
client.close()
228+
lsid_set = set()
229+
for i in listener.results["started"]:
230+
if i.command.get("lsid"):
231+
lsid_set.add(i.command.get("lsid")["id"])
232+
if len(lsid_set) == 1:
233+
succeeded = True
234+
else:
235+
failures += 1
236+
print(failures)
237+
self.assertTrue(succeeded)
238+
174239
def test_pool_lifo(self):
175240
# "Pool is LIFO" test from Driver Sessions Spec.
176241
a = self.client.start_session()

0 commit comments

Comments
 (0)