Skip to content

Commit bf8b036

Browse files
committed
PYTHON-2674 Pool.reset only clears connections to the given serviceId (#628)
(cherry picked from commit 112ee69)
1 parent a2d687b commit bf8b036

File tree

8 files changed

+100
-25
lines changed

8 files changed

+100
-25
lines changed

pymongo/mongo_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2328,7 +2328,7 @@ def __init__(self, client, server, session):
23282328
# "Note that when a network error occurs before the handshake
23292329
# completes then the error's generation number is the generation
23302330
# of the pool at the time the connection attempt was started."
2331-
self.sock_generation = server.pool.generation
2331+
self.sock_generation = server.pool.gen.get_overall()
23322332
self.completed_handshake = False
23332333
self.service_id = None
23342334

pymongo/pool.py

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -549,7 +549,8 @@ def __init__(self, sock, pool, address, id):
549549

550550
# The pool's generation changes with each reset() so we can close
551551
# sockets created before the last reset.
552-
self.generation = pool.generation
552+
self.pool_gen = pool.gen
553+
self.generation = self.pool_gen.get_overall()
553554
self.ready = False
554555
self.cancel_context = None
555556
if not pool.handshake:
@@ -635,6 +636,7 @@ def _ismaster(self, cluster_time, topology_version,
635636
'Driver attempted to initialize in load balancing mode'
636637
' but the server does not support this mode')
637638
self.service_id = ismaster.service_id
639+
self.generation = self.pool_gen.get(self.service_id)
638640
return ismaster
639641

640642
def _next_reply(self):
@@ -1075,6 +1077,43 @@ class _PoolClosedError(PyMongoError):
10751077
pass
10761078

10771079

1080+
class _PoolGeneration(object):
1081+
def __init__(self):
1082+
# Maps service_id to generation.
1083+
self._generations = collections.defaultdict(int)
1084+
# Overall pool generation.
1085+
self._generation = 0
1086+
1087+
def get(self, service_id):
1088+
"""Get the generation for the given service_id."""
1089+
if service_id is None:
1090+
return self._generation
1091+
return self._generations[service_id]
1092+
1093+
def get_overall(self):
1094+
"""Get the Pool's overall generation."""
1095+
return self._generation
1096+
1097+
def inc(self, service_id):
1098+
"""Increment the generation for the given service_id."""
1099+
self._generation += 1
1100+
if service_id is None:
1101+
for service_id in self._generations:
1102+
self._generations[service_id] += 1
1103+
else:
1104+
self._generations[service_id] += 1
1105+
1106+
def stale(self, gen, service_id):
1107+
"""Return if the given generation for a given service_id is stale."""
1108+
return gen != self.get(service_id)
1109+
1110+
1111+
class PoolState(object):
1112+
PAUSED = 1
1113+
READY = 2
1114+
CLOSED = 3
1115+
1116+
10781117
# Do *not* explicitly inherit from object or Jython won't call __del__
10791118
# http://bugs.jython.org/issue1057
10801119
class Pool:
@@ -1102,7 +1141,8 @@ def __init__(self, address, options, handshake=True):
11021141

11031142
# Keep track of resets, so we notice sockets created before the most
11041143
# recent reset and close them.
1105-
self.generation = 0
1144+
# self.generation = 0
1145+
self.gen = _PoolGeneration()
11061146
self.pid = os.getpid()
11071147
self.address = address
11081148
self.opts = options
@@ -1130,10 +1170,22 @@ def _reset(self, close, service_id=None):
11301170
with self.lock:
11311171
if self.closed:
11321172
return
1133-
self.generation += 1
1173+
self.gen.inc(service_id)
11341174
self.pid = os.getpid()
1135-
sockets, self.sockets = self.sockets, collections.deque()
11361175
self.active_sockets = 0
1176+
if service_id is None:
1177+
sockets, self.sockets = self.sockets, collections.deque()
1178+
else:
1179+
discard = collections.deque()
1180+
keep = collections.deque()
1181+
for sock_info in self.sockets:
1182+
if sock_info.service_id == service_id:
1183+
discard.append(sock_info)
1184+
else:
1185+
keep.append(sock_info)
1186+
sockets = discard
1187+
self.sockets = keep
1188+
11371189
if close:
11381190
self.closed = True
11391191

@@ -1168,6 +1220,9 @@ def reset(self, service_id=None):
11681220
def close(self):
11691221
self._reset(close=True)
11701222

1223+
def stale_generation(self, gen, service_id):
1224+
return self.gen.stale(gen, service_id)
1225+
11711226
def remove_stale_sockets(self, reference_generation, all_credentials):
11721227
"""Removes stale sockets then adds new ones if pool is too small and
11731228
has not been reset. The `reference_generation` argument specifies the
@@ -1196,7 +1251,7 @@ def remove_stale_sockets(self, reference_generation, all_credentials):
11961251
with self.lock:
11971252
# Close connection and return if the pool was reset during
11981253
# socket creation or while acquiring the pool lock.
1199-
if self.generation != reference_generation:
1254+
if self.gen.get_overall() != reference_generation:
12001255
sock_info.close_socket(ConnectionClosedReason.STALE)
12011256
break
12021257
self.sockets.appendleft(sock_info)
@@ -1361,7 +1416,8 @@ def return_socket(self, sock_info):
13611416
with self.lock:
13621417
# Hold the lock to ensure this section does not race with
13631418
# Pool.reset().
1364-
if sock_info.generation != self.generation:
1419+
if self.stale_generation(sock_info.generation,
1420+
sock_info.service_id):
13651421
sock_info.close_socket(ConnectionClosedReason.STALE)
13661422
else:
13671423
sock_info.update_last_checkin_time()
@@ -1400,7 +1456,7 @@ def _perished(self, sock_info):
14001456
sock_info.close_socket(ConnectionClosedReason.ERROR)
14011457
return True
14021458

1403-
if sock_info.generation != self.generation:
1459+
if self.stale_generation(sock_info.generation, sock_info.service_id):
14041460
sock_info.close_socket(ConnectionClosedReason.STALE)
14051461
return True
14061462

pymongo/topology.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,8 @@ def update_pool(self, all_credentials):
449449
# Only update pools for data-bearing servers.
450450
for sd in self.data_bearing_servers():
451451
server = self._servers[sd.address]
452-
servers.append((server, server.pool.generation))
452+
servers.append((server,
453+
server.pool.gen.get_overall()))
453454

454455
for server, generation in servers:
455456
server.pool.remove_stale_sockets(generation, all_credentials)
@@ -574,7 +575,8 @@ def _is_stale_error(self, address, err_ctx):
574575
# Another thread removed this server from the topology.
575576
return True
576577

577-
if err_ctx.sock_generation != server._pool.generation:
578+
if server._pool.stale_generation(
579+
err_ctx.sock_generation, err_ctx.service_id):
578580
# This is an outdated error from a previous pool version.
579581
return True
580582

test/test_client.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,7 +1554,7 @@ def test_reset_during_update_pool(self):
15541554
self.addCleanup(client.close)
15551555
client.admin.command('ping')
15561556
pool = get_pool(client)
1557-
generation = pool.generation
1557+
generation = pool.gen.get_overall()
15581558

15591559
# Continuously reset the pool.
15601560
class ResetPoolThread(threading.Thread):
@@ -1568,7 +1568,10 @@ def stop(self):
15681568

15691569
def run(self):
15701570
while self.running:
1571-
self.pool.reset()
1571+
exc = AutoReconnect('mock pool error')
1572+
ctx = _ErrorContext(
1573+
exc, 0, pool.gen.get_overall(), False, None)
1574+
client._topology.handle_error(pool.address, ctx)
15721575
time.sleep(0.001)
15731576

15741577
t = ResetPoolThread(pool)
@@ -1581,7 +1584,7 @@ def run(self):
15811584
for _ in range(10):
15821585
client._topology.update_pool(
15831586
client._MongoClient__all_credentials)
1584-
if generation != pool.generation:
1587+
if generation != pool.gen.get_overall():
15851588
break
15861589
finally:
15871590
t.stop()

test/test_discovery_and_monitoring.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ def got_app_error(topology, app_error):
105105
server_address = common.partition_node(app_error['address'])
106106
server = topology.get_server_by_address(server_address)
107107
error_type = app_error['type']
108-
generation = app_error.get('generation', server.pool.generation)
108+
generation = app_error.get(
109+
'generation', server.pool.gen.get_overall())
109110
when = app_error['when']
110111
max_wire_version = app_error['maxWireVersion']
111112
# XXX: We could get better test coverage by mocking the errors on the
@@ -199,7 +200,7 @@ def check_outcome(self, topology, outcome):
199200
if expected_pool:
200201
self.assertEqual(
201202
expected_pool.get('generation'),
202-
actual_server.pool.generation)
203+
actual_server.pool.gen.get_overall())
203204

204205
self.assertEqual(outcome['setName'], topology.description.replica_set_name)
205206
self.assertEqual(outcome.get('logicalSessionTimeoutMinutes'),
@@ -288,7 +289,7 @@ def test_ignore_stale_connection_errors(self):
288289
# Wait for initial discovery.
289290
client.admin.command('ping')
290291
pool = get_pool(client)
291-
starting_generation = pool.generation
292+
starting_generation = pool.gen.get_overall()
292293
wait_until(lambda: len(pool.sockets) == N_THREADS, 'created sockets')
293294

294295
def mock_command(*args, **kwargs):
@@ -314,7 +315,8 @@ def insert_command(i):
314315
t.join()
315316

316317
# Expect a single pool reset for the network error
317-
self.assertEqual(starting_generation+1, pool.generation)
318+
self.assertEqual(
319+
starting_generation+1, pool.gen.get_overall())
318320

319321
# Server should be selectable.
320322
client.admin.command('ping')

test/test_topology.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -718,11 +718,11 @@ def _check_with_socket(self, *args, **kwargs):
718718
self.addCleanup(t.close)
719719
server = wait_for_master(t)
720720
self.assertEqual(1, ismaster_count[0])
721-
generation = server.pool.generation
721+
generation = server.pool.gen.get_overall()
722722

723723
# Pool is reset by ismaster failure.
724724
t.request_check_all()
725-
self.assertNotEqual(generation, server.pool.generation)
725+
self.assertNotEqual(generation, server.pool.gen.get_overall())
726726

727727
def test_ismaster_retry(self):
728728
# ismaster succeeds at first, then raises socket error, then succeeds.

test/unified_format.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,16 @@ def is_run_on_requirement_satisfied(requirement):
123123
elif client_context.server_parameters[param] != val:
124124
params_satisfied = False
125125

126+
auth_satisfied = True
127+
req_auth = requirement.get('auth')
128+
if req_auth is not None:
129+
if req_auth:
130+
auth_satisfied = client_context.auth_enabled
131+
else:
132+
auth_satisfied = not client_context.auth_enabled
133+
126134
return (topology_satisfied and min_version_satisfied and
127-
max_version_satisfied and params_satisfied)
135+
max_version_satisfied and params_satisfied and auth_satisfied)
128136

129137

130138
def parse_collection_or_database_options(options):

test/utils.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@
4141
from pymongo.errors import ConfigurationError, OperationFailure
4242
from pymongo.monitoring import _SENSITIVE_COMMANDS
4343
from pymongo.pool import (_CancellationContext,
44-
PoolOptions)
44+
PoolOptions,
45+
_PoolGeneration)
4546
from pymongo.read_concern import ReadConcern
4647
from pymongo.read_preferences import ReadPreference
4748
from pymongo.server_selectors import (any_server_selector,
@@ -264,20 +265,23 @@ def __exit__(self, exc_type, exc_val, exc_tb):
264265

265266

266267
class MockPool(object):
267-
def __init__(self, *args, **kwargs):
268-
self.generation = 0
268+
def __init__(self, address, options, handshake=True):
269+
self.gen = _PoolGeneration()
269270
self._lock = threading.Lock()
270271
self.opts = PoolOptions()
271272

273+
def stale_generation(self, gen, service_id):
274+
return self.gen.stale(gen, service_id)
275+
272276
def get_socket(self, all_credentials, checkout=False):
273277
return MockSocketInfo()
274278

275279
def return_socket(self, *args, **kwargs):
276280
pass
277281

278-
def _reset(self):
282+
def _reset(self, service_id=None):
279283
with self._lock:
280-
self.generation += 1
284+
self.gen.inc(service_id)
281285

282286
def reset(self, service_id=None):
283287
self._reset()

0 commit comments

Comments
 (0)