Skip to content

Commit 3204290

Browse files
authored
PYTHON-2484 Added lock sanitization for MongoClient and ObjectId (#985)
1 parent 46673c3 commit 3204290

File tree

11 files changed

+219
-29
lines changed

11 files changed

+219
-29
lines changed

bson/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656

5757
import datetime
5858
import itertools
59+
import os
5960
import re
6061
import struct
6162
import sys
@@ -1336,3 +1337,16 @@ def decode(self, codec_options: "CodecOptions[_DocumentType]" = DEFAULT_CODEC_OP
13361337
def has_c() -> bool:
13371338
"""Is the C extension installed?"""
13381339
return _USE_C
1340+
1341+
1342+
def _after_fork():
1343+
"""Releases the ObjectID lock child."""
1344+
if ObjectId._inc_lock.locked():
1345+
ObjectId._inc_lock.release()
1346+
1347+
1348+
if hasattr(os, "register_at_fork"):
1349+
# This will run in the same thread as the fork was called.
1350+
# If we fork in a critical region on the same thread, it should break.
1351+
# This is fine since we would never call fork directly from a critical region.
1352+
os.register_at_fork(after_in_child=_after_fork)

pymongo/cursor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
"""Cursor class to iterate over Mongo query results."""
1616
import copy
17-
import threading
1817
import warnings
1918
from collections import deque
2019
from typing import (
@@ -45,6 +44,7 @@
4544
validate_is_mapping,
4645
)
4746
from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
47+
from pymongo.lock import _create_lock
4848
from pymongo.message import (
4949
_CursorAddress,
5050
_GetMore,
@@ -133,7 +133,7 @@ def __init__(self, sock, more_to_come):
133133
self.sock = sock
134134
self.more_to_come = more_to_come
135135
self.closed = False
136-
self.lock = threading.Lock()
136+
self.lock = _create_lock()
137137

138138
def update_exhaust(self, more_to_come):
139139
self.more_to_come = more_to_come

pymongo/lock.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright 2022-present MongoDB, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
import threading
17+
import weakref
18+
19+
_HAS_REGISTER_AT_FORK = hasattr(os, "register_at_fork")
20+
21+
# References to instances of _create_lock
22+
_forkable_locks: weakref.WeakSet = weakref.WeakSet()
23+
24+
25+
def _create_lock():
26+
"""Represents a lock that is tracked upon instantiation using a WeakSet and
27+
reset by pymongo upon forking.
28+
"""
29+
lock = threading.Lock()
30+
if _HAS_REGISTER_AT_FORK:
31+
_forkable_locks.add(lock)
32+
return lock
33+
34+
35+
def _release_locks() -> None:
36+
# Completed the fork, reset all the locks in the child.
37+
for lock in _forkable_locks:
38+
if lock.locked():
39+
lock.release()

pymongo/mongo_client.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232
"""
3333

3434
import contextlib
35-
import threading
35+
import os
3636
import weakref
3737
from collections import defaultdict
3838
from typing import (
@@ -82,6 +82,7 @@
8282
ServerSelectionTimeoutError,
8383
WaitQueueTimeoutError,
8484
)
85+
from pymongo.lock import _create_lock, _release_locks
8586
from pymongo.pool import ConnectionClosedReason
8687
from pymongo.read_preferences import ReadPreference, _ServerMode
8788
from pymongo.server_selectors import writable_server_selector
@@ -126,6 +127,7 @@ class MongoClient(common.BaseObject, Generic[_DocumentType]):
126127
# Define order to retrieve options from ClientOptions for __repr__.
127128
# No host/port; these are retrieved from TopologySettings.
128129
_constructor_args = ("document_class", "tz_aware", "connect")
130+
_clients: weakref.WeakValueDictionary = weakref.WeakValueDictionary()
129131

130132
def __init__(
131133
self,
@@ -788,7 +790,7 @@ def __init__(
788790
self.__options = options = ClientOptions(username, password, dbase, opts)
789791

790792
self.__default_database_name = dbase
791-
self.__lock = threading.Lock()
793+
self.__lock = _create_lock()
792794
self.__kill_cursors_queue: List = []
793795

794796
self._event_listeners = options.pool_options._event_listeners
@@ -817,6 +819,23 @@ def __init__(
817819
srv_max_hosts=srv_max_hosts,
818820
)
819821

822+
self._init_background()
823+
824+
if connect:
825+
self._get_topology()
826+
827+
self._encrypter = None
828+
if self.__options.auto_encryption_opts:
829+
from pymongo.encryption import _Encrypter
830+
831+
self._encrypter = _Encrypter(self, self.__options.auto_encryption_opts)
832+
self._timeout = self.__options.timeout
833+
834+
# Add this client to the list of weakly referenced items.
835+
# This will be used later if we fork.
836+
MongoClient._clients[self._topology._topology_id] = self
837+
838+
def _init_background(self):
820839
self._topology = Topology(self._topology_settings)
821840

822841
def target():
@@ -838,15 +857,9 @@ def target():
838857
self_ref: Any = weakref.ref(self, executor.close)
839858
self._kill_cursors_executor = executor
840859

841-
if connect:
842-
self._get_topology()
843-
844-
self._encrypter = None
845-
if self.__options.auto_encryption_opts:
846-
from pymongo.encryption import _Encrypter
847-
848-
self._encrypter = _Encrypter(self, self.__options.auto_encryption_opts)
849-
self._timeout = options.timeout
860+
def _after_fork(self):
861+
"""Resets topology in a child after successfully forking."""
862+
self._init_background()
850863

851864
def _duplicate(self, **kwargs):
852865
args = self.__init_kwargs.copy()
@@ -2150,3 +2163,22 @@ def __enter__(self):
21502163

21512164
def __exit__(self, exc_type, exc_val, exc_tb):
21522165
return self.handle(exc_type, exc_val)
2166+
2167+
2168+
def _after_fork_child():
2169+
"""Releases the locks in child process and resets the
2170+
topologies in all MongoClients.
2171+
"""
2172+
# Reinitialize locks
2173+
_release_locks()
2174+
2175+
# Perform cleanup in clients (i.e. get rid of topology)
2176+
for _, client in MongoClient._clients.items():
2177+
client._after_fork()
2178+
2179+
2180+
if hasattr(os, "register_at_fork"):
2181+
# This will run in the same thread as the fork was called.
2182+
# If we fork in a critical region on the same thread, it should break.
2183+
# This is fine since we would never call fork directly from a critical region.
2184+
os.register_at_fork(after_in_child=_after_fork_child)

pymongo/monitor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515
"""Class to monitor a MongoDB server on a background thread."""
1616

1717
import atexit
18-
import threading
1918
import time
2019
import weakref
2120
from typing import Any, Mapping, cast
2221

2322
from pymongo import common, periodic_executor
2423
from pymongo.errors import NotPrimaryError, OperationFailure, _OperationCancelled
2524
from pymongo.hello import Hello
25+
from pymongo.lock import _create_lock
2626
from pymongo.periodic_executor import _shutdown_executors
2727
from pymongo.read_preferences import MovingAverage
2828
from pymongo.server_description import ServerDescription
@@ -350,7 +350,7 @@ def __init__(self, topology, topology_settings, pool):
350350

351351
self._pool = pool
352352
self._moving_average = MovingAverage()
353-
self._lock = threading.Lock()
353+
self._lock = _create_lock()
354354

355355
def close(self):
356356
self.gc_safe_close()

pymongo/ocsp_cache.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616

1717
from collections import namedtuple
1818
from datetime import datetime as _datetime
19-
from threading import Lock
19+
20+
from pymongo.lock import _create_lock
2021

2122

2223
class _OCSPCache(object):
@@ -30,7 +31,7 @@ class _OCSPCache(object):
3031
def __init__(self):
3132
self._data = {}
3233
# Hold this lock when accessing _data.
33-
self._lock = Lock()
34+
self._lock = _create_lock()
3435

3536
def _get_cache_key(self, ocsp_request):
3637
return self.CACHE_KEY_TYPE(

pymongo/periodic_executor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
import weakref
2020
from typing import Any, Optional
2121

22+
from pymongo.lock import _create_lock
23+
2224

2325
class PeriodicExecutor(object):
2426
def __init__(self, interval, min_interval, target, name=None):
@@ -45,9 +47,8 @@ def __init__(self, interval, min_interval, target, name=None):
4547
self._thread: Optional[threading.Thread] = None
4648
self._name = name
4749
self._skip_sleep = False
48-
4950
self._thread_will_exit = False
50-
self._lock = threading.Lock()
51+
self._lock = _create_lock()
5152

5253
def __repr__(self):
5354
return "<%s(name=%s) object at 0x%x>" % (self.__class__.__name__, self._name, id(self))

pymongo/pool.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
_CertificateError,
5757
)
5858
from pymongo.hello import Hello, HelloCompat
59+
from pymongo.lock import _create_lock
5960
from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
6061
from pymongo.network import command, receive_message
6162
from pymongo.read_preferences import ReadPreference
@@ -1152,7 +1153,7 @@ def __init__(self, address, options, handshake=True):
11521153
# and returned to pool from the left side. Stale sockets removed
11531154
# from the right side.
11541155
self.sockets: collections.deque = collections.deque()
1155-
self.lock = threading.Lock()
1156+
self.lock = _create_lock()
11561157
self.active_sockets = 0
11571158
# Monotonically increasing connection ID required for CMAP Events.
11581159
self.next_connection_id = 1

pymongo/topology.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import os
1818
import queue
1919
import random
20-
import threading
2120
import time
2221
import warnings
2322
import weakref
@@ -37,6 +36,7 @@
3736
WriteError,
3837
)
3938
from pymongo.hello import Hello
39+
from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock
4040
from pymongo.monitor import SrvMonitor
4141
from pymongo.pool import PoolOptions
4242
from pymongo.server import Server
@@ -127,7 +127,7 @@ def __init__(self, topology_settings):
127127
self._seed_addresses = list(topology_description.server_descriptions())
128128
self._opened = False
129129
self._closed = False
130-
self._lock = threading.Lock()
130+
self._lock = _create_lock()
131131
self._condition = self._settings.condition_class(self._lock)
132132
self._servers = {}
133133
self._pid = None
@@ -174,12 +174,13 @@ def open(self):
174174
self._pid = pid
175175
elif pid != self._pid:
176176
self._pid = pid
177-
warnings.warn(
178-
"MongoClient opened before fork. Create MongoClient only "
179-
"after forking. See PyMongo's documentation for details: "
180-
"https://pymongo.readthedocs.io/en/stable/faq.html#"
181-
"is-pymongo-fork-safe"
182-
)
177+
if not _HAS_REGISTER_AT_FORK:
178+
warnings.warn(
179+
"MongoClient opened before fork. May not be entirely fork-safe, "
180+
"proceed with caution. See PyMongo's documentation for details: "
181+
"https://pymongo.readthedocs.io/en/stable/faq.html#"
182+
"is-pymongo-fork-safe"
183+
)
183184
with self._lock:
184185
# Close servers and clear the pools.
185186
for server in self._servers.values():

0 commit comments

Comments
 (0)