Skip to content

Commit 0eace78

Browse files
committed
PYTHON-2158 Support speculative authentication attempts in connection handshake
1 parent 45a7963 commit 0eace78

File tree

6 files changed

+177
-39
lines changed

6 files changed

+177
-39
lines changed

pymongo/auth.py

Lines changed: 86 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -254,9 +254,22 @@ def _parse_scram_response(response):
254254
return dict(item.split(b"=", 1) for item in response.split(b","))
255255

256256

257+
def _authenticate_scram_start(credentials, mechanism):
258+
username = credentials.username
259+
user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
260+
nonce = standard_b64encode(os.urandom(32))
261+
first_bare = b"n=" + user + b",r=" + nonce
262+
263+
cmd = SON([('saslStart', 1),
264+
('mechanism', mechanism),
265+
('payload', Binary(b"n,," + first_bare)),
266+
('autoAuthorize', 1),
267+
('options', {'skipEmptyExchange': True})])
268+
return nonce, first_bare, cmd
269+
270+
257271
def _authenticate_scram(credentials, sock_info, mechanism):
258272
"""Authenticate using SCRAM."""
259-
260273
username = credentials.username
261274
if mechanism == 'SCRAM-SHA-256':
262275
digest = "sha256"
@@ -272,16 +285,14 @@ def _authenticate_scram(credentials, sock_info, mechanism):
272285
# Make local
273286
_hmac = hmac.HMAC
274287

275-
user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
276-
nonce = standard_b64encode(os.urandom(32))
277-
first_bare = b"n=" + user + b",r=" + nonce
278-
279-
cmd = SON([('saslStart', 1),
280-
('mechanism', mechanism),
281-
('payload', Binary(b"n,," + first_bare)),
282-
('autoAuthorize', 1),
283-
('options', {'skipEmptyExchange': True})])
284-
res = sock_info.command(source, cmd)
288+
ctx = sock_info.auth_ctx.get(credentials)
289+
if ctx and ctx.speculate_succeeded():
290+
nonce, first_bare = ctx.scram_data
291+
res = ctx.speculative_authenticate
292+
else:
293+
nonce, first_bare, cmd = _authenticate_scram_start(
294+
credentials, mechanism)
295+
res = sock_info.command(source, cmd)
285296

286297
server_first = res['payload']
287298
parsed = _parse_scram_response(server_first)
@@ -516,15 +527,17 @@ def _authenticate_cram_md5(credentials, sock_info):
516527
def _authenticate_x509(credentials, sock_info):
517528
"""Authenticate using MONGODB-X509.
518529
"""
519-
query = SON([('authenticate', 1),
520-
('mechanism', 'MONGODB-X509')])
521-
if credentials.username is not None:
522-
query['user'] = credentials.username
523-
elif sock_info.max_wire_version < 5:
530+
ctx = sock_info.auth_ctx.get(credentials)
531+
if ctx and ctx.speculate_succeeded():
532+
# MONGODB-X509 is done after the speculative auth step.
533+
return
534+
535+
cmd = _X509Context(credentials).speculate_command()
536+
if credentials.username is None and sock_info.max_wire_version < 5:
524537
raise ConfigurationError(
525538
"A username is required for MONGODB-X509 authentication "
526539
"when connected to MongoDB versions older than 3.4.")
527-
sock_info.command('$external', query)
540+
sock_info.command('$external', cmd)
528541

529542

530543
def _authenticate_aws(credentials, sock_info):
@@ -597,6 +610,62 @@ def _authenticate_default(credentials, sock_info):
597610
}
598611

599612

613+
class _AuthContext(object):
614+
def __init__(self, credentials):
615+
self.credentials = credentials
616+
self.speculative_authenticate = None
617+
618+
@staticmethod
619+
def from_credentials(creds):
620+
spec_cls = _SPECULATIVE_AUTH_MAP.get(creds.mechanism)
621+
if spec_cls:
622+
return spec_cls(creds)
623+
return None
624+
625+
def speculate_command(self):
626+
raise NotImplementedError
627+
628+
def parse_response(self, ismaster):
629+
self.speculative_authenticate = ismaster.speculative_authenticate
630+
631+
def speculate_succeeded(self):
632+
return bool(self.speculative_authenticate)
633+
634+
635+
class _ScramContext(_AuthContext):
636+
def __init__(self, credentials, mechanism):
637+
super(_ScramContext, self).__init__(credentials)
638+
self.scram_data = None
639+
self.mechanism = mechanism
640+
641+
def speculate_command(self):
642+
nonce, first_bare, cmd = _authenticate_scram_start(
643+
self.credentials, self.mechanism)
644+
# The 'db' field is included only on the speculative command.
645+
cmd['db'] = self.credentials.source
646+
# Save for later use.
647+
self.scram_data = (nonce, first_bare)
648+
return cmd
649+
650+
651+
class _X509Context(_AuthContext):
652+
def speculate_command(self):
653+
cmd = SON([('authenticate', 1),
654+
('mechanism', 'MONGODB-X509')])
655+
if self.credentials.username is not None:
656+
cmd['user'] = self.credentials.username
657+
return cmd
658+
659+
660+
_SPECULATIVE_AUTH_MAP = {
661+
'MONGODB-X509': _X509Context,
662+
'SCRAM-SHA-1': functools.partial(_ScramContext, mechanism='SCRAM-SHA-1'),
663+
'SCRAM-SHA-256': functools.partial(_ScramContext,
664+
mechanism='SCRAM-SHA-256'),
665+
'DEFAULT': functools.partial(_ScramContext, mechanism='SCRAM-SHA-256'),
666+
}
667+
668+
600669
def authenticate(credentials, sock_info):
601670
"""Authenticate sock_info."""
602671
mechanism = credentials.mechanism

pymongo/ismaster.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,11 @@ def sasl_supported_mechs(self):
169169
"""
170170
return self._doc.get('saslSupportedMechs', [])
171171

172+
@property
173+
def speculative_authenticate(self):
174+
"""The speculativeAuthenticate field."""
175+
return self._doc.get('speculativeAuthenticate')
176+
172177
@property
173178
def topology_version(self):
174179
return self._doc.get('topologyVersion')

pymongo/pool.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,6 +468,15 @@ def _negotiate_creds(all_credentials):
468468
return None
469469

470470

471+
def _speculative_context(all_credentials):
472+
"""Return the _AuthContext to use for speculative auth, if any.
473+
"""
474+
if all_credentials and len(all_credentials) == 1:
475+
creds = next(itervalues(all_credentials))
476+
return auth._AuthContext.from_credentials(creds)
477+
return None
478+
479+
471480
class SocketInfo(object):
472481
"""Store a socket with some metadata.
473482
@@ -501,6 +510,7 @@ def __init__(self, sock, pool, address, id):
501510
# Support for mechanism negotiation on the initial handshake.
502511
# Maps credential to saslSupportedMechs.
503512
self.negotiated_mechanisms = {}
513+
self.auth_ctx = {}
504514

505515
# The pool's generation changes with each reset() so we can close
506516
# sockets created before the last reset.
@@ -522,6 +532,9 @@ def ismaster(self, metadata, cluster_time, all_credentials=None):
522532
creds = _negotiate_creds(all_credentials)
523533
if creds:
524534
cmd['saslSupportedMechs'] = creds.source + '.' + creds.username
535+
auth_ctx = _speculative_context(all_credentials)
536+
if auth_ctx:
537+
cmd['speculativeAuthenticate'] = auth_ctx.speculate_command()
525538

526539
ismaster = IsMaster(self.command('admin', cmd, publish_events=False))
527540
self.is_writable = ismaster.is_writable
@@ -541,6 +554,10 @@ def ismaster(self, metadata, cluster_time, all_credentials=None):
541554
self.op_msg_enabled = ismaster.max_wire_version >= 6
542555
if creds:
543556
self.negotiated_mechanisms[creds] = ismaster.sasl_supported_mechs
557+
if auth_ctx:
558+
auth_ctx.parse_response(ismaster)
559+
if auth_ctx.speculate_succeeded():
560+
self.auth_ctx[auth_ctx.credentials] = auth_ctx
544561
return ismaster
545562

546563
def command(self, dbname, spec, slave_ok=False,
@@ -743,6 +760,7 @@ def authenticate(self, credentials):
743760
self.authset.add(credentials)
744761
# negotiated_mechanisms are no longer needed.
745762
self.negotiated_mechanisms.pop(credentials, None)
763+
self.auth_ctx.pop(credentials, None)
746764

747765
def validate_session(self, client, session):
748766
"""Validate this session before use with client.

test/test_auth.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -418,18 +418,20 @@ def test_scram_skip_empty_exchange(self):
418418
client = rs_or_single_client_noauth(
419419
username='sha256', password='pwd', authSource='testscram',
420420
event_listeners=[listener])
421-
client.admin.command('isMaster')
421+
client.testscram.command('dbstats')
422422

423-
# Assert we sent the skipEmptyExchange option.
424-
first_event = listener.results['started'][0]
425-
self.assertEqual(first_event.command_name, 'saslStart')
426-
self.assertEqual(
427-
first_event.command['options'], {'skipEmptyExchange': True})
423+
if client_context.version < (4, 4, -1):
424+
# Assert we sent the skipEmptyExchange option.
425+
first_event = listener.results['started'][0]
426+
self.assertEqual(first_event.command_name, 'saslStart')
427+
self.assertEqual(
428+
first_event.command['options'], {'skipEmptyExchange': True})
428429

429430
# Assert the third exchange was skipped on servers that support it.
431+
# Note that the first exchange occurs on the connection handshake.
430432
started = listener.started_command_names()
431-
if client_context.version.at_least(4, 3, 3):
432-
self.assertEqual(started, ['saslStart', 'saslContinue'])
433+
if client_context.version.at_least(4, 4, -1):
434+
self.assertEqual(started, ['saslContinue'])
433435
else:
434436
self.assertEqual(
435437
started, ['saslStart', 'saslContinue', 'saslContinue'])
@@ -578,8 +580,13 @@ def test_scram(self):
578580
'mongodb://both:pwd@%s:%d/testscram' % (host, port),
579581
event_listeners=[self.listener])
580582
client.testscram.command('dbstats')
581-
started = self.listener.results['started'][0]
582-
self.assertEqual(started.command.get('mechanism'), 'SCRAM-SHA-256')
583+
if client_context.version.at_least(4, 4, -1):
584+
# Speculative authentication in 4.4+ sends saslStart with the
585+
# handshake.
586+
self.assertEqual(self.listener.results['started'], [])
587+
else:
588+
started = self.listener.results['started'][0]
589+
self.assertEqual(started.command.get('mechanism'), 'SCRAM-SHA-256')
583590

584591
client = rs_or_single_client_noauth(
585592
'mongodb://both:pwd@%s:%d/testscram?authMechanism=SCRAM-SHA-1'

test/test_database.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@
5151
SkipTest,
5252
unittest,
5353
IntegrationTest)
54-
from test.utils import (ignore_deprecations,
54+
from test.utils import (EventListener,
55+
ignore_deprecations,
5556
remove_all_users,
5657
rs_or_single_client_noauth,
5758
rs_or_single_client,
@@ -677,14 +678,6 @@ def test_authenticate_multiple(self):
677678
admin_db_auth = self.client.admin
678679
users_db_auth = self.client.pymongo_test
679680

680-
# Non-root client.
681-
client = rs_or_single_client_noauth()
682-
admin_db = client.admin
683-
users_db = client.pymongo_test
684-
other_db = client.pymongo_test1
685-
686-
self.assertRaises(OperationFailure, users_db.test.find_one)
687-
688681
admin_db_auth.add_user(
689682
'ro-admin',
690683
'pass',
@@ -695,24 +688,60 @@ def test_authenticate_multiple(self):
695688
'user', 'pass', roles=["userAdmin", "readWrite"])
696689
self.addCleanup(remove_all_users, users_db_auth)
697690

691+
# Non-root client.
692+
listener = EventListener()
693+
client = rs_or_single_client_noauth(event_listeners=[listener])
694+
admin_db = client.admin
695+
users_db = client.pymongo_test
696+
other_db = client.pymongo_test1
697+
698+
self.assertRaises(OperationFailure, users_db.test.find_one)
699+
self.assertEqual(listener.started_command_names(), ['find'])
700+
listener.reset()
701+
698702
# Regular user should be able to query its own db, but
699703
# no other.
700704
users_db.authenticate('user', 'pass')
705+
if client_context.version.at_least(3, 0):
706+
self.assertEqual(listener.started_command_names()[0], 'saslStart')
707+
else:
708+
self.assertEqual(listener.started_command_names()[0], 'getnonce')
709+
701710
self.assertEqual(0, users_db.test.count_documents({}))
702711
self.assertRaises(OperationFailure, other_db.test.find_one)
703712

713+
listener.reset()
704714
# Admin read-only user should be able to query any db,
705715
# but not write.
706716
admin_db.authenticate('ro-admin', 'pass')
717+
if client_context.version.at_least(3, 0):
718+
self.assertEqual(listener.started_command_names()[0], 'saslStart')
719+
else:
720+
self.assertEqual(listener.started_command_names()[0], 'getnonce')
707721
self.assertEqual(None, other_db.test.find_one())
708722
self.assertRaises(OperationFailure,
709723
other_db.test.insert_one, {})
710724

711725
# Close all sockets.
712726
client.close()
713727

728+
listener.reset()
714729
# We should still be able to write to the regular user's db.
715730
self.assertTrue(users_db.test.delete_many({}))
731+
names = listener.started_command_names()
732+
if client_context.version.at_least(4, 4, -1):
733+
# No speculation with multiple users (but we do skipEmptyExchange).
734+
self.assertEqual(
735+
names, ['saslStart', 'saslContinue', 'saslStart',
736+
'saslContinue', 'delete'])
737+
elif client_context.version.at_least(3, 0):
738+
self.assertEqual(
739+
names, ['saslStart', 'saslContinue', 'saslContinue',
740+
'saslStart', 'saslContinue', 'saslContinue', 'delete'])
741+
else:
742+
self.assertEqual(
743+
names, ['getnonce', 'authenticate',
744+
'getnonce', 'authenticate', 'delete'])
716745

717746
# And read from other dbs...
718747
self.assertEqual(0, other_db.test.count_documents({}))

test/test_ssl.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,11 @@
3939
SkipTest,
4040
unittest,
4141
HAVE_IPADDRESS)
42-
from test.utils import (remove_all_users,
42+
from test.utils import (EventListener,
4343
cat_files,
44-
connected)
44+
connected,
45+
remove_all_users)
46+
4547

4648
_HAVE_PYOPENSSL = False
4749
try:
@@ -582,16 +584,24 @@ def test_mongodb_x509_auth(self):
582584

583585
self.assertRaises(OperationFailure, noauth.pymongo_test.test.count)
584586

587+
listener = EventListener()
585588
auth = MongoClient(
586589
client_context.pair,
587590
authMechanism='MONGODB-X509',
588591
ssl=True,
589592
ssl_cert_reqs=ssl.CERT_NONE,
590-
ssl_certfile=CLIENT_PEM)
593+
ssl_certfile=CLIENT_PEM,
594+
event_listeners=[listener])
591595

592596
if client_context.version.at_least(3, 3, 12):
593597
# No error
594598
auth.pymongo_test.test.find_one()
599+
names = listener.started_command_names()
600+
if client_context.version.at_least(4, 4, -1):
601+
# Speculative auth skips the authenticate command.
602+
self.assertEqual(names, ['find'])
603+
else:
604+
self.assertEqual(names, ['authenticate', 'find'])
595605
else:
596606
# Should require a username
597607
with self.assertRaises(ConfigurationError):

0 commit comments

Comments
 (0)