Skip to content

Commit 6a47fc6

Browse files
committed
address review
1 parent f1294dc commit 6a47fc6

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

pymongo/asynchronous/pool.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1075,8 +1075,13 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
10751075
if self._backoff:
10761076
await asyncio.sleep(_backoff(self._backoff))
10771077

1078+
# Pass a context to determine if we successfully create a configured socket.
1079+
context = dict(has_created_socket=False)
1080+
10781081
try:
1079-
networking_interface = await _configured_protocol_interface(self.address, self.opts)
1082+
networking_interface = await _configured_protocol_interface(
1083+
self.address, self.opts, context
1084+
)
10801085
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
10811086
except BaseException as error:
10821087
async with self.lock:
@@ -1097,7 +1102,8 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> A
10971102
reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
10981103
error=ConnectionClosedReason.ERROR,
10991104
)
1100-
self._handle_connection_error(error, "handshake")
1105+
if context["has_created_socket"]:
1106+
self._handle_connection_error(error, "handshake")
11011107
if isinstance(error, (IOError, OSError, *SSLErrors)):
11021108
details = _get_timeout_details(self.opts)
11031109
_raise_connection_failure(self.address, error, timeout_details=details)

pymongo/pool_shared.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,7 @@ async def _configured_protocol_interface(
250250
address: _Address,
251251
options: PoolOptions,
252252
protocol_kls: type[PyMongoBaseProtocol] = PyMongoProtocol,
253+
context: dict[str, bool] | None = None,
253254
) -> AsyncNetworkingInterface:
254255
"""Given (host, port) and PoolOptions, return a configured AsyncNetworkingInterface.
255256
@@ -261,6 +262,10 @@ async def _configured_protocol_interface(
261262
ssl_context = options._ssl_context
262263
timeout = options.socket_timeout
263264

265+
# Signal that we have created the socket successfully.
266+
if context:
267+
context["has_created_socket"] = True
268+
264269
if ssl_context is None:
265270
return AsyncNetworkingInterface(
266271
await asyncio.get_running_loop().create_connection(
@@ -374,7 +379,7 @@ def _create_connection(address: _Address, options: PoolOptions) -> socket.socket
374379

375380

376381
def _configured_socket_interface(
377-
address: _Address, options: PoolOptions, *args: Any
382+
address: _Address, options: PoolOptions, *args: Any, context: dict[str, bool] | None = None
378383
) -> NetworkingInterface:
379384
"""Given (host, port) and PoolOptions, return a NetworkingInterface wrapping a configured socket.
380385
@@ -385,6 +390,10 @@ def _configured_socket_interface(
385390
sock = _create_connection(address, options)
386391
ssl_context = options._ssl_context
387392

393+
# Signal that we have created the socket successfully.
394+
if context:
395+
context["has_created_socket"] = True
396+
388397
if ssl_context is None:
389398
sock.settimeout(options.socket_timeout)
390399
return NetworkingInterface(sock)

pymongo/synchronous/pool.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,8 +1071,11 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect
10711071
if self._backoff:
10721072
time.sleep(_backoff(self._backoff))
10731073

1074+
# Pass a context to determine if we successfully create a configured socket.
1075+
context = dict(has_created_socket=False)
1076+
10741077
try:
1075-
networking_interface = _configured_socket_interface(self.address, self.opts)
1078+
networking_interface = _configured_socket_interface(self.address, self.opts, context)
10761079
# Catch KeyboardInterrupt, CancelledError, etc. and cleanup.
10771080
except BaseException as error:
10781081
with self.lock:
@@ -1093,7 +1096,8 @@ def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connect
10931096
reason=_verbose_connection_error_reason(ConnectionClosedReason.ERROR),
10941097
error=ConnectionClosedReason.ERROR,
10951098
)
1096-
self._handle_connection_error(error, "handshake")
1099+
if context["has_created_socket"]:
1100+
self._handle_connection_error(error, "handshake")
10971101
if isinstance(error, (IOError, OSError, *SSLErrors)):
10981102
details = _get_timeout_details(self.opts)
10991103
_raise_connection_failure(self.address, error, timeout_details=details)

0 commit comments

Comments
 (0)