Skip to content

Commit 4c727fd

Browse files
committed
PYTHON-2158 Support mechanism negotiation on the connection handshake
1 parent 71d1227 commit 4c727fd

File tree

7 files changed

+69
-40
lines changed

7 files changed

+69
-40
lines changed

pymongo/auth.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -562,12 +562,16 @@ def _authenticate_mongo_cr(credentials, sock_info):
562562

563563
def _authenticate_default(credentials, sock_info):
564564
if sock_info.max_wire_version >= 7:
565-
source = credentials.source
566-
cmd = SON([
567-
('ismaster', 1),
568-
('saslSupportedMechs', source + '.' + credentials.username)])
569-
mechs = sock_info.command(
570-
source, cmd, publish_events=False).get('saslSupportedMechs', [])
565+
if credentials in sock_info.negotiated_mechanisms:
566+
mechs = sock_info.negotiated_mechanisms[credentials]
567+
else:
568+
source = credentials.source
569+
cmd = SON([
570+
('ismaster', 1),
571+
('saslSupportedMechs', source + '.' + credentials.username)])
572+
mechs = sock_info.command(
573+
source, cmd, publish_events=False).get(
574+
'saslSupportedMechs', [])
571575
if 'SCRAM-SHA-256' in mechs:
572576
return _authenticate_scram(credentials, sock_info, 'SCRAM-SHA-256')
573577
else:

pymongo/ismaster.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,15 @@ def last_write_date(self):
156156
@property
157157
def compressors(self):
158158
return self._doc.get('compression')
159+
160+
@property
161+
def sasl_supported_mechs(self):
162+
"""Supported authentication mechanisms for the current user.
163+
164+
For example::
165+
166+
>>> ismaster.sasl_supported_mechs
167+
["SCRAM-SHA-1", "SCRAM-SHA-256"]
168+
169+
"""
170+
return self._doc.get('saslSupportedMechs', [])

pymongo/mongo_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1751,7 +1751,7 @@ def _process_periodic_tasks(self):
17511751
maintain connection pool parameters."""
17521752
self._process_kill_cursors()
17531753
try:
1754-
self._topology.update_pool()
1754+
self._topology.update_pool(self.__all_credentials)
17551755
except Exception:
17561756
helpers._handle_exception()
17571757

pymongo/pool.py

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,16 @@ def metadata(self):
458458
return self.__metadata.copy()
459459

460460

461+
def _negotiate_creds(all_credentials):
462+
"""Return one credential that needs mechanism negotiation, if any.
463+
"""
464+
if all_credentials:
465+
for creds in all_credentials.values():
466+
if creds.mechanism == 'DEFAULT' and creds.username:
467+
return creds
468+
return None
469+
470+
461471
class SocketInfo(object):
462472
"""Store a socket with some metadata.
463473
@@ -488,13 +498,16 @@ def __init__(self, sock, pool, address, id):
488498
self.compression_settings = pool.opts.compression_settings
489499
self.compression_context = None
490500
self.socket_checker = SocketChecker()
501+
# Support for mechanism negotiation on the initial handshake.
502+
# Maps credential to saslSupportedMechs.
503+
self.negotiated_mechanisms = {}
491504

492505
# The pool's generation changes with each reset() so we can close
493506
# sockets created before the last reset.
494507
self.generation = pool.generation
495508
self.ready = False
496509

497-
def ismaster(self, metadata, cluster_time):
510+
def ismaster(self, metadata, cluster_time, all_credentials=None):
498511
cmd = SON([('ismaster', 1)])
499512
if not self.performed_handshake:
500513
cmd['client'] = metadata
@@ -504,6 +517,12 @@ def ismaster(self, metadata, cluster_time):
504517
if self.max_wire_version >= 6 and cluster_time is not None:
505518
cmd['$clusterTime'] = cluster_time
506519

520+
# XXX: Simplify in PyMongo 4.0 when all_credentials is always a single
521+
# unchangeable value per MongoClient.
522+
creds = _negotiate_creds(all_credentials)
523+
if creds:
524+
cmd['saslSupportedMechs'] = creds.source + '.' + creds.username
525+
507526
ismaster = IsMaster(self.command('admin', cmd, publish_events=False))
508527
self.is_writable = ismaster.is_writable
509528
self.max_wire_version = ismaster.max_wire_version
@@ -520,6 +539,8 @@ def ismaster(self, metadata, cluster_time):
520539

521540
self.performed_handshake = True
522541
self.op_msg_enabled = ismaster.max_wire_version >= 6
542+
if creds:
543+
self.negotiated_mechanisms[creds] = ismaster.sasl_supported_mechs
523544
return ismaster
524545

525546
def command(self, dbname, spec, slave_ok=False,
@@ -701,8 +722,7 @@ def check_auth(self, all_credentials):
701722
self.authset.discard(credentials)
702723

703724
for credentials in cached - authset:
704-
auth.authenticate(credentials, self)
705-
self.authset.add(credentials)
725+
self.authenticate(credentials)
706726

707727
# CMAP spec says to publish the ready event only after authenticating
708728
# the connection.
@@ -721,6 +741,8 @@ def authenticate(self, credentials):
721741
"""
722742
auth.authenticate(credentials, self)
723743
self.authset.add(credentials)
744+
# negotiated_mechanisms are no longer needed.
745+
self.negotiated_mechanisms.pop(credentials, None)
724746

725747
def validate_session(self, client, session):
726748
"""Validate this session before use with client.
@@ -1026,7 +1048,7 @@ def reset(self):
10261048
def close(self):
10271049
self._reset(close=True)
10281050

1029-
def remove_stale_sockets(self, reference_generation):
1051+
def remove_stale_sockets(self, reference_generation, all_credentials):
10301052
"""Removes stale sockets then adds new ones if pool is too small and
10311053
has not been reset. The `reference_generation` argument specifies the
10321054
`generation` at the point in time this operation was requested on the
@@ -1050,7 +1072,7 @@ def remove_stale_sockets(self, reference_generation):
10501072
if not self._socket_semaphore.acquire(False):
10511073
break
10521074
try:
1053-
sock_info = self.connect()
1075+
sock_info = self.connect(all_credentials)
10541076
with self.lock:
10551077
# Close connection and return if the pool was reset during
10561078
# socket creation or while acquiring the pool lock.
@@ -1061,7 +1083,7 @@ def remove_stale_sockets(self, reference_generation):
10611083
finally:
10621084
self._socket_semaphore.release()
10631085

1064-
def connect(self):
1086+
def connect(self, all_credentials=None):
10651087
"""Connect to Mongo and return a new SocketInfo.
10661088
10671089
Can raise ConnectionFailure or CertificateError.
@@ -1081,9 +1103,6 @@ def connect(self):
10811103
try:
10821104
sock = _configured_socket(self.address, self.opts)
10831105
except socket.error as error:
1084-
if sock is not None:
1085-
sock.close()
1086-
10871106
if self.enabled_for_cmap:
10881107
listeners.publish_connection_closed(
10891108
self.address, conn_id, ConnectionClosedReason.ERROR)
@@ -1092,7 +1111,7 @@ def connect(self):
10921111

10931112
sock_info = SocketInfo(sock, self, self.address, conn_id)
10941113
if self.handshake:
1095-
sock_info.ismaster(self.opts.metadata, None)
1114+
sock_info.ismaster(self.opts.metadata, None, all_credentials)
10961115
self.is_writable = sock_info.is_writable
10971116

10981117
return sock_info
@@ -1123,29 +1142,23 @@ def get_socket(self, all_credentials, checkout=False):
11231142
listeners = self.opts.event_listeners
11241143
if self.enabled_for_cmap:
11251144
listeners.publish_connection_check_out_started(self.address)
1126-
# First get a socket, then attempt authentication. Simplifies
1127-
# semaphore management in the face of network errors during auth.
1128-
sock_info = self._get_socket_no_auth()
1129-
checked_auth = False
1145+
1146+
sock_info = self._get_socket(all_credentials)
1147+
1148+
if self.enabled_for_cmap:
1149+
listeners.publish_connection_checked_out(
1150+
self.address, sock_info.id)
11301151
try:
1131-
sock_info.check_auth(all_credentials)
1132-
checked_auth = True
1133-
if self.enabled_for_cmap:
1134-
listeners.publish_connection_checked_out(
1135-
self.address, sock_info.id)
11361152
yield sock_info
11371153
except:
11381154
# Exception in caller. Decrement semaphore.
1139-
self.return_socket(sock_info, publish_checkin=checked_auth)
1140-
if self.enabled_for_cmap and not checked_auth:
1141-
self.opts.event_listeners.publish_connection_check_out_failed(
1142-
self.address, ConnectionCheckOutFailedReason.CONN_ERROR)
1155+
self.return_socket(sock_info)
11431156
raise
11441157
else:
11451158
if not checkout:
11461159
self.return_socket(sock_info)
11471160

1148-
def _get_socket_no_auth(self):
1161+
def _get_socket(self, all_credentials):
11491162
"""Get or create a SocketInfo. Can raise ConnectionFailure."""
11501163
# We use the pid here to avoid issues with fork / multiprocessing.
11511164
# See test.test_client:TestClient.test_fork for an example of
@@ -1177,10 +1190,11 @@ def _get_socket_no_auth(self):
11771190
sock_info = self.sockets.popleft()
11781191
except IndexError:
11791192
# Can raise ConnectionFailure or CertificateError.
1180-
sock_info = self.connect()
1193+
sock_info = self.connect(all_credentials)
11811194
else:
11821195
if self._perished(sock_info):
11831196
sock_info = None
1197+
sock_info.check_auth(all_credentials)
11841198
except Exception:
11851199
self._socket_semaphore.release()
11861200
with self.lock:
@@ -1193,16 +1207,14 @@ def _get_socket_no_auth(self):
11931207

11941208
return sock_info
11951209

1196-
def return_socket(self, sock_info, publish_checkin=True):
1210+
def return_socket(self, sock_info):
11971211
"""Return the socket to the pool, or if it's closed discard it.
11981212
11991213
:Parameters:
12001214
- `sock_info`: The socket to check into the pool.
1201-
- `publish_checkin`: If False, a ConnectionCheckedInEvent will not
1202-
be published.
12031215
"""
12041216
listeners = self.opts.event_listeners
1205-
if self.enabled_for_cmap and publish_checkin:
1217+
if self.enabled_for_cmap:
12061218
listeners.publish_connection_checked_in(self.address, sock_info.id)
12071219
if self.pid != os.getpid():
12081220
self.reset()

pymongo/topology.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -430,15 +430,15 @@ def mark_server_unknown_and_request_check(self, address, error):
430430
self._reset_server(address, reset_pool=False, error=error)
431431
self._request_check(address)
432432

433-
def update_pool(self):
433+
def update_pool(self, all_credentials):
434434
# Remove any stale sockets and add new sockets if pool is too small.
435435
servers = []
436436
with self._lock:
437437
for server in self._servers.values():
438438
servers.append((server, server._pool.generation))
439439

440440
for server, generation in servers:
441-
server._pool.remove_stale_sockets(generation)
441+
server._pool.remove_stale_sockets(generation, all_credentials)
442442

443443
def close(self):
444444
"""Clear pools and terminate monitors. Topology reopens on demand."""

test/test_client.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1487,7 +1487,8 @@ def run(self):
14871487
try:
14881488
while True:
14891489
for _ in range(10):
1490-
client._topology.update_pool()
1490+
client._topology.update_pool(
1491+
client._MongoClient__all_credentials)
14911492
if generation != pool.generation:
14921493
break
14931494
finally:

test/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def close(self):
230230
def update_is_writable(self, is_writable):
231231
pass
232232

233-
def remove_stale_sockets(self, reference_generation):
233+
def remove_stale_sockets(self, *args, **kwargs):
234234
pass
235235

236236

0 commit comments

Comments
 (0)