Skip to content

Commit 64f7706

Browse files
authored
PYTHON-1438 Mark a server unknown when connection handshake fails with a network timeout error (#461)
1 parent 3c1dd61 commit 64f7706

File tree

4 files changed

+20
-9
lines changed

4 files changed

+20
-9
lines changed

pymongo/mongo_client.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2258,7 +2258,7 @@ def _add_retryable_write_error(exc, max_wire_version):
22582258
class _MongoClientErrorHandler(object):
22592259
"""Handle errors raised when executing an operation."""
22602260
__slots__ = ('client', 'server_address', 'session', 'max_wire_version',
2261-
'sock_generation')
2261+
'sock_generation', 'completed_handshake')
22622262

22632263
def __init__(self, client, server, session):
22642264
self.client = client
@@ -2270,11 +2270,13 @@ def __init__(self, client, server, session):
22702270
# completes then the error's generation number is the generation
22712271
# of the pool at the time the connection attempt was started."
22722272
self.sock_generation = server.pool.generation
2273+
self.completed_handshake = False
22732274

22742275
def contribute_socket(self, sock_info):
22752276
"""Provide socket information to the error handler."""
22762277
self.max_wire_version = sock_info.max_wire_version
22772278
self.sock_generation = sock_info.generation
2279+
self.completed_handshake = True
22782280

22792281
def __enter__(self):
22802282
return self
@@ -2295,5 +2297,6 @@ def __exit__(self, exc_type, exc_val, exc_tb):
22952297
self.session._unpin_mongos()
22962298

22972299
err_ctx = _ErrorContext(
2298-
exc_val, self.max_wire_version, self.sock_generation)
2300+
exc_val, self.max_wire_version, self.sock_generation,
2301+
self.completed_handshake)
22992302
self.client._topology.handle_error(self.server_address, err_ctx)

pymongo/topology.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,8 @@ def _handle_error(self, address, err_ctx):
567567
server = self._servers[address]
568568
error = err_ctx.error
569569
exc_type = type(error)
570-
if issubclass(exc_type, NetworkTimeout):
570+
if (issubclass(exc_type, NetworkTimeout) and
571+
err_ctx.completed_handshake):
571572
# The socket has been closed. Don't reset the server.
572573
# Server Discovery And Monitoring Spec: "When an application
573574
# operation fails because of any network error besides a socket
@@ -750,10 +751,12 @@ def __repr__(self):
750751

751752
class _ErrorContext(object):
752753
"""An error with context for SDAM error handling."""
753-
def __init__(self, error, max_wire_version, sock_generation):
754+
def __init__(self, error, max_wire_version, sock_generation,
755+
completed_handshake):
754756
self.error = error
755757
self.max_wire_version = max_wire_version
756758
self.sock_generation = sock_generation
759+
self.completed_handshake = completed_handshake
757760

758761

759762
def _is_stale_error_topology_version(current_tv, error_tv):

test/test_discovery_and_monitoring.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,16 @@ def got_app_error(topology, app_error):
119119
raise AssertionError('unknown error type: %s' % (error_type,))
120120
assert False
121121
except (AutoReconnect, NotMasterError, OperationFailure) as e:
122-
if when == 'beforeHandshakeCompletes' and error_type == 'timeout':
123-
raise unittest.SkipTest('PYTHON-2211')
122+
if when == 'beforeHandshakeCompletes':
123+
completed_handshake = False
124+
elif when == 'afterHandshakeCompletes':
125+
completed_handshake = True
126+
else:
127+
assert False, 'Unknown when field %s' % (when,)
124128

125129
topology.handle_error(
126-
server_address, _ErrorContext(e, max_wire_version, generation))
130+
server_address, _ErrorContext(e, max_wire_version, generation,
131+
completed_handshake))
127132

128133

129134
def get_type(topology, hostname):

test/test_topology.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -419,7 +419,7 @@ def test_handle_error(self):
419419
'setName': 'rs',
420420
'hosts': ['a', 'b']})
421421

422-
errctx = _ErrorContext(AutoReconnect('mock'), 0, 0)
422+
errctx = _ErrorContext(AutoReconnect('mock'), 0, 0, True)
423423
t.handle_error(('a', 27017), errctx)
424424
self.assertEqual(SERVER_TYPE.Unknown, get_type(t, 'a'))
425425
self.assertEqual(SERVER_TYPE.RSSecondary, get_type(t, 'b'))
@@ -480,7 +480,7 @@ def test_handle_error_removed_server(self):
480480
t = create_mock_topology(replica_set_name='rs')
481481

482482
# No error resetting a server not in the TopologyDescription.
483-
errctx = _ErrorContext(AutoReconnect('mock'), 0, 0)
483+
errctx = _ErrorContext(AutoReconnect('mock'), 0, 0, True)
484484
t.handle_error(('b', 27017), errctx)
485485

486486
# Server was *not* added as type Unknown.

0 commit comments

Comments
 (0)