@@ -458,6 +458,16 @@ def metadata(self):
458
458
return self .__metadata .copy ()
459
459
460
460
461
+ def _negotiate_creds (all_credentials ):
462
+ """Return one credential that needs mechanism negotiation, if any.
463
+ """
464
+ if all_credentials :
465
+ for creds in all_credentials .values ():
466
+ if creds .mechanism == 'DEFAULT' and creds .username :
467
+ return creds
468
+ return None
469
+
470
+
461
471
class SocketInfo (object ):
462
472
"""Store a socket with some metadata.
463
473
@@ -488,13 +498,16 @@ def __init__(self, sock, pool, address, id):
488
498
self .compression_settings = pool .opts .compression_settings
489
499
self .compression_context = None
490
500
self .socket_checker = SocketChecker ()
501
+ # Support for mechanism negotiation on the initial handshake.
502
+ # Maps credential to saslSupportedMechs.
503
+ self .negotiated_mechanisms = {}
491
504
492
505
# The pool's generation changes with each reset() so we can close
493
506
# sockets created before the last reset.
494
507
self .generation = pool .generation
495
508
self .ready = False
496
509
497
- def ismaster (self , metadata , cluster_time ):
510
+ def ismaster (self , metadata , cluster_time , all_credentials = None ):
498
511
cmd = SON ([('ismaster' , 1 )])
499
512
if not self .performed_handshake :
500
513
cmd ['client' ] = metadata
@@ -504,6 +517,12 @@ def ismaster(self, metadata, cluster_time):
504
517
if self .max_wire_version >= 6 and cluster_time is not None :
505
518
cmd ['$clusterTime' ] = cluster_time
506
519
520
+ # XXX: Simplify in PyMongo 4.0 when all_credentials is always a single
521
+ # unchangeable value per MongoClient.
522
+ creds = _negotiate_creds (all_credentials )
523
+ if creds :
524
+ cmd ['saslSupportedMechs' ] = creds .source + '.' + creds .username
525
+
507
526
ismaster = IsMaster (self .command ('admin' , cmd , publish_events = False ))
508
527
self .is_writable = ismaster .is_writable
509
528
self .max_wire_version = ismaster .max_wire_version
@@ -520,6 +539,8 @@ def ismaster(self, metadata, cluster_time):
520
539
521
540
self .performed_handshake = True
522
541
self .op_msg_enabled = ismaster .max_wire_version >= 6
542
+ if creds :
543
+ self .negotiated_mechanisms [creds ] = ismaster .sasl_supported_mechs
523
544
return ismaster
524
545
525
546
def command (self , dbname , spec , slave_ok = False ,
@@ -701,8 +722,7 @@ def check_auth(self, all_credentials):
701
722
self .authset .discard (credentials )
702
723
703
724
for credentials in cached - authset :
704
- auth .authenticate (credentials , self )
705
- self .authset .add (credentials )
725
+ self .authenticate (credentials )
706
726
707
727
# CMAP spec says to publish the ready event only after authenticating
708
728
# the connection.
@@ -721,6 +741,8 @@ def authenticate(self, credentials):
721
741
"""
722
742
auth .authenticate (credentials , self )
723
743
self .authset .add (credentials )
744
+ # negotiated_mechanisms are no longer needed.
745
+ self .negotiated_mechanisms .pop (credentials , None )
724
746
725
747
def validate_session (self , client , session ):
726
748
"""Validate this session before use with client.
@@ -1026,7 +1048,7 @@ def reset(self):
1026
1048
def close (self ):
1027
1049
self ._reset (close = True )
1028
1050
1029
- def remove_stale_sockets (self , reference_generation ):
1051
+ def remove_stale_sockets (self , reference_generation , all_credentials ):
1030
1052
"""Removes stale sockets then adds new ones if pool is too small and
1031
1053
has not been reset. The `reference_generation` argument specifies the
1032
1054
`generation` at the point in time this operation was requested on the
@@ -1050,7 +1072,7 @@ def remove_stale_sockets(self, reference_generation):
1050
1072
if not self ._socket_semaphore .acquire (False ):
1051
1073
break
1052
1074
try :
1053
- sock_info = self .connect ()
1075
+ sock_info = self .connect (all_credentials )
1054
1076
with self .lock :
1055
1077
# Close connection and return if the pool was reset during
1056
1078
# socket creation or while acquiring the pool lock.
@@ -1061,7 +1083,7 @@ def remove_stale_sockets(self, reference_generation):
1061
1083
finally :
1062
1084
self ._socket_semaphore .release ()
1063
1085
1064
- def connect (self ):
1086
+ def connect (self , all_credentials = None ):
1065
1087
"""Connect to Mongo and return a new SocketInfo.
1066
1088
1067
1089
Can raise ConnectionFailure or CertificateError.
@@ -1081,9 +1103,6 @@ def connect(self):
1081
1103
try :
1082
1104
sock = _configured_socket (self .address , self .opts )
1083
1105
except socket .error as error :
1084
- if sock is not None :
1085
- sock .close ()
1086
-
1087
1106
if self .enabled_for_cmap :
1088
1107
listeners .publish_connection_closed (
1089
1108
self .address , conn_id , ConnectionClosedReason .ERROR )
@@ -1092,7 +1111,7 @@ def connect(self):
1092
1111
1093
1112
sock_info = SocketInfo (sock , self , self .address , conn_id )
1094
1113
if self .handshake :
1095
- sock_info .ismaster (self .opts .metadata , None )
1114
+ sock_info .ismaster (self .opts .metadata , None , all_credentials )
1096
1115
self .is_writable = sock_info .is_writable
1097
1116
1098
1117
return sock_info
@@ -1123,29 +1142,23 @@ def get_socket(self, all_credentials, checkout=False):
1123
1142
listeners = self .opts .event_listeners
1124
1143
if self .enabled_for_cmap :
1125
1144
listeners .publish_connection_check_out_started (self .address )
1126
- # First get a socket, then attempt authentication. Simplifies
1127
- # semaphore management in the face of network errors during auth.
1128
- sock_info = self ._get_socket_no_auth ()
1129
- checked_auth = False
1145
+
1146
+ sock_info = self ._get_socket (all_credentials )
1147
+
1148
+ if self .enabled_for_cmap :
1149
+ listeners .publish_connection_checked_out (
1150
+ self .address , sock_info .id )
1130
1151
try :
1131
- sock_info .check_auth (all_credentials )
1132
- checked_auth = True
1133
- if self .enabled_for_cmap :
1134
- listeners .publish_connection_checked_out (
1135
- self .address , sock_info .id )
1136
1152
yield sock_info
1137
1153
except :
1138
1154
# Exception in caller. Decrement semaphore.
1139
- self .return_socket (sock_info , publish_checkin = checked_auth )
1140
- if self .enabled_for_cmap and not checked_auth :
1141
- self .opts .event_listeners .publish_connection_check_out_failed (
1142
- self .address , ConnectionCheckOutFailedReason .CONN_ERROR )
1155
+ self .return_socket (sock_info )
1143
1156
raise
1144
1157
else :
1145
1158
if not checkout :
1146
1159
self .return_socket (sock_info )
1147
1160
1148
- def _get_socket_no_auth (self ):
1161
+ def _get_socket (self , all_credentials ):
1149
1162
"""Get or create a SocketInfo. Can raise ConnectionFailure."""
1150
1163
# We use the pid here to avoid issues with fork / multiprocessing.
1151
1164
# See test.test_client:TestClient.test_fork for an example of
@@ -1177,10 +1190,11 @@ def _get_socket_no_auth(self):
1177
1190
sock_info = self .sockets .popleft ()
1178
1191
except IndexError :
1179
1192
# Can raise ConnectionFailure or CertificateError.
1180
- sock_info = self .connect ()
1193
+ sock_info = self .connect (all_credentials )
1181
1194
else :
1182
1195
if self ._perished (sock_info ):
1183
1196
sock_info = None
1197
+ sock_info .check_auth (all_credentials )
1184
1198
except Exception :
1185
1199
self ._socket_semaphore .release ()
1186
1200
with self .lock :
@@ -1193,16 +1207,14 @@ def _get_socket_no_auth(self):
1193
1207
1194
1208
return sock_info
1195
1209
1196
- def return_socket (self , sock_info , publish_checkin = True ):
1210
+ def return_socket (self , sock_info ):
1197
1211
"""Return the socket to the pool, or if it's closed discard it.
1198
1212
1199
1213
:Parameters:
1200
1214
- `sock_info`: The socket to check into the pool.
1201
- - `publish_checkin`: If False, a ConnectionCheckedInEvent will not
1202
- be published.
1203
1215
"""
1204
1216
listeners = self .opts .event_listeners
1205
- if self .enabled_for_cmap and publish_checkin :
1217
+ if self .enabled_for_cmap :
1206
1218
listeners .publish_connection_checked_in (self .address , sock_info .id )
1207
1219
if self .pid != os .getpid ():
1208
1220
self .reset ()
0 commit comments