Skip to content

Commit 884802c

Browse files
authored
Merge pull request #312 from terrycain/sha256
MySQL 8 Compatibility and SHA256 authentication plugin support
2 parents 425f81e + 332a249 commit 884802c

File tree

6 files changed

+357
-36
lines changed

6 files changed

+357
-36
lines changed

aiomysql/connection.py

Lines changed: 178 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
# from aiomysql.utils import _convert_to_str
4242
from .cursors import Cursor
4343
from .utils import _ConnectionContextManager, _ContextManager
44-
# from .log import logger
44+
from .log import logger
4545

4646

4747
DEFAULT_USER = getpass.getuser()
@@ -55,7 +55,7 @@ def connect(host="localhost", user=None, password="",
5555
connect_timeout=None, read_default_group=None,
5656
no_delay=None, autocommit=False, echo=False,
5757
local_infile=False, loop=None, ssl=None, auth_plugin='',
58-
program_name=''):
58+
program_name='', server_public_key=None):
5959
"""See connections.Connection.__init__() for information about
6060
defaults."""
6161
coro = _connect(host=host, user=user, password=password, db=db,
@@ -93,7 +93,7 @@ def __init__(self, host="localhost", user=None, password="",
9393
connect_timeout=None, read_default_group=None,
9494
no_delay=None, autocommit=False, echo=False,
9595
local_infile=False, loop=None, ssl=None, auth_plugin='',
96-
program_name=''):
96+
program_name='', server_public_key=None):
9797
"""
9898
Establish a connection to the MySQL database. Accepts several
9999
arguments:
@@ -134,6 +134,8 @@ def __init__(self, host="localhost", user=None, password="",
134134
(default: Server Default)
135135
:param program_name: Program name string to provide when
136136
handshaking with MySQL. (default: sys.argv[0])
137+
:param server_public_key: SHA256 authentication plugin public
138+
key value.
137139
:param loop: asyncio loop
138140
"""
139141
self._loop = loop or asyncio.get_event_loop()
@@ -174,6 +176,8 @@ def __init__(self, host="localhost", user=None, password="",
174176
self._client_auth_plugin = auth_plugin
175177
self._server_auth_plugin = ""
176178
self._auth_plugin_used = ""
179+
self.server_public_key = server_public_key
180+
self.salt = None
177181

178182
# TODO somehow import version from __init__.py
179183
self._connect_attrs = {
@@ -712,6 +716,20 @@ async def _request_authentication(self):
712716
if auth_plugin in ('', 'mysql_native_password'):
713717
authresp = _auth.scramble_native_password(
714718
self._password.encode('latin1'), self.salt)
719+
elif auth_plugin == 'caching_sha2_password':
720+
if self._password:
721+
authresp = _auth.scramble_caching_sha2(
722+
self._password.encode('latin1'), self.salt
723+
)
724+
# Else: empty password
725+
elif auth_plugin == 'sha256_password':
726+
if self._ssl_context and self.server_capabilities & CLIENT.SSL:
727+
authresp = self._password.encode('latin1') + b'\0'
728+
elif self._password:
729+
authresp = b'\1' # request public key
730+
else:
731+
authresp = b'\0' # empty password
732+
715733
elif auth_plugin in ('', 'mysql_clear_password'):
716734
authresp = self._password.encode('latin1') + b'\0'
717735

@@ -768,35 +786,174 @@ async def _request_authentication(self):
768786
auth_packet.read_all()) + b'\0'
769787
self.write_packet(data)
770788
await self._read_packet()
789+
elif auth_packet.is_extra_auth_data():
790+
if auth_plugin == "caching_sha2_password":
791+
await self.caching_sha2_password_auth(auth_packet)
792+
elif auth_plugin == "sha256_password":
793+
await self.sha256_password_auth(auth_packet)
794+
else:
795+
raise OperationalError("Received extra packet "
796+
"for auth method %r", auth_plugin)
771797

772798
async def _process_auth(self, plugin_name, auth_packet):
773-
if plugin_name == b"mysql_native_password":
774-
# https://dev.mysql.com/doc/internals/en/
775-
# secure-password-authentication.html#packet-Authentication::
776-
# Native41
777-
data = _auth.scramble_native_password(
778-
self._password.encode('latin1'),
779-
auth_packet.read_all())
780-
elif plugin_name == b"mysql_old_password":
781-
# https://dev.mysql.com/doc/internals/en/
782-
# old-password-authentication.html
783-
data = _auth.scramble_old_password(self._password.encode('latin1'),
784-
auth_packet.read_all()) + b'\0'
785-
elif plugin_name == b"mysql_clear_password":
786-
# https://dev.mysql.com/doc/internals/en/
787-
# clear-text-authentication.html
788-
data = self._password.encode('latin1') + b'\0'
799+
# These auth plugins do their own packet handling
800+
if plugin_name == b"caching_sha2_password":
801+
await self.caching_sha2_password_auth(auth_packet)
802+
self._auth_plugin_used = plugin_name.decode()
803+
elif plugin_name == b"sha256_password":
804+
await self.sha256_password_auth(auth_packet)
805+
self._auth_plugin_used = plugin_name.decode()
789806
else:
807+
808+
if plugin_name == b"mysql_native_password":
809+
# https://dev.mysql.com/doc/internals/en/
810+
# secure-password-authentication.html#packet-Authentication::
811+
# Native41
812+
data = _auth.scramble_native_password(
813+
self._password.encode('latin1'),
814+
auth_packet.read_all())
815+
elif plugin_name == b"mysql_old_password":
816+
# https://dev.mysql.com/doc/internals/en/
817+
# old-password-authentication.html
818+
data = _auth.scramble_old_password(
819+
self._password.encode('latin1'),
820+
auth_packet.read_all()
821+
) + b'\0'
822+
elif plugin_name == b"mysql_clear_password":
823+
# https://dev.mysql.com/doc/internals/en/
824+
# clear-text-authentication.html
825+
data = self._password.encode('latin1') + b'\0'
826+
else:
827+
raise OperationalError(
828+
2059, "Authentication plugin '{0}'"
829+
" not configured".format(plugin_name)
830+
)
831+
832+
self.write_packet(data)
833+
pkt = await self._read_packet()
834+
pkt.check_error()
835+
836+
self._auth_plugin_used = plugin_name.decode()
837+
838+
return pkt
839+
840+
async def caching_sha2_password_auth(self, pkt):
841+
# No password fast path
842+
if not self._password:
843+
self.write_packet(b'')
844+
pkt = await self._read_packet()
845+
pkt.check_error()
846+
return pkt
847+
848+
if pkt.is_auth_switch_request():
849+
# Try from fast auth
850+
logger.debug("caching sha2: Trying fast path")
851+
self.salt = pkt.read_all()
852+
scrambled = _auth.scramble_caching_sha2(
853+
self._password.encode('latin1'), self.salt
854+
)
855+
856+
self.write_packet(scrambled)
857+
pkt = await self._read_packet()
858+
pkt.check_error()
859+
860+
# else: fast auth is tried in initial handshake
861+
862+
if not pkt.is_extra_auth_data():
790863
raise OperationalError(
791-
2059, "Authentication plugin '%s' not configured" % plugin_name
864+
"caching sha2: Unknown packet "
865+
"for fast auth: {0}".format(pkt._data[:1])
792866
)
793867

868+
# magic numbers:
869+
# 2 - request public key
870+
# 3 - fast auth succeeded
871+
# 4 - need full auth
872+
873+
pkt.advance(1)
874+
n = pkt.read_uint8()
875+
876+
if n == 3:
877+
logger.debug("caching sha2: succeeded by fast path.")
878+
pkt = await self._read_packet()
879+
pkt.check_error() # pkt must be OK packet
880+
return pkt
881+
882+
if n != 4:
883+
raise OperationalError("caching sha2: Unknown "
884+
"result for fast auth: {0}".format(n))
885+
886+
logger.debug("caching sha2: Trying full auth...")
887+
888+
if self._ssl_context:
889+
logger.debug("caching sha2: Sending plain "
890+
"password via secure connection")
891+
self.write_packet(self._password.encode('latin1') + b'\0')
892+
pkt = await self._read_packet()
893+
pkt.check_error()
894+
return pkt
895+
896+
if not self.server_public_key:
897+
self.write_packet(b'\x02')
898+
pkt = await self._read_packet() # Request public key
899+
pkt.check_error()
900+
901+
if not pkt.is_extra_auth_data():
902+
raise OperationalError(
903+
"caching sha2: Unknown packet "
904+
"for public key: {0}".format(pkt._data[:1])
905+
)
906+
907+
self.server_public_key = pkt._data[1:]
908+
logger.debug(self.server_public_key.decode('ascii'))
909+
910+
data = _auth.sha2_rsa_encrypt(
911+
self._password.encode('latin1'), self.salt,
912+
self.server_public_key
913+
)
794914
self.write_packet(data)
795915
pkt = await self._read_packet()
796916
pkt.check_error()
797917

798-
self._auth_plugin_used = plugin_name
918+
async def sha256_password_auth(self, pkt):
919+
if self._ssl_context:
920+
logger.debug("sha256: Sending plain password")
921+
data = self._password.encode('latin1') + b'\0'
922+
self.write_packet(data)
923+
pkt = await self._read_packet()
924+
pkt.check_error()
925+
return pkt
926+
927+
if pkt.is_auth_switch_request():
928+
self.salt = pkt.read_all()
929+
if not self.server_public_key and self._password:
930+
# Request server public key
931+
logger.debug("sha256: Requesting server public key")
932+
self.write_packet(b'\1')
933+
pkt = await self._read_packet()
934+
pkt.check_error()
935+
936+
if pkt.is_extra_auth_data():
937+
self.server_public_key = pkt._data[1:]
938+
logger.debug(
939+
"Received public key:\n",
940+
self.server_public_key.decode('ascii')
941+
)
942+
943+
if self._password:
944+
if not self.server_public_key:
945+
raise OperationalError("Couldn't receive server's public key")
946+
947+
data = _auth.sha2_rsa_encrypt(
948+
self._password.encode('latin1'), self.salt,
949+
self.server_public_key
950+
)
951+
else:
952+
data = b''
799953

954+
self.write_packet(data)
955+
pkt = await self._read_packet()
956+
pkt.check_error()
800957
return pkt
801958

802959
# _mysql support

docs/connection.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ Example::
4747
client_flag=0, cursorclass=Cursor, init_command=None,
4848
connect_timeout=None, read_default_group=None,
4949
no_delay=False, autocommit=False, echo=False,
50-
ssl=None, auth_plugin='', program_name='', loop=None)
50+
ssl=None, auth_plugin='', program_name='',
51+
server_public_key=None, loop=None)
5152

5253
A :ref:`coroutine <coroutine>` that connects to MySQL.
5354

@@ -89,6 +90,7 @@ Example::
8990
(default: Server Default)
9091
:param program_name: Program name string to provide when
9192
handshaking with MySQL. (default: sys.argv[0])
93+
:param server_public_key: SHA256 authenticaiton plugin public key value.
9294
:param loop: asyncio event loop instance or ``None`` for default one.
9395
:returns: :class:`Connection` instance.
9496

examples/example_ssl.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import asyncio
2+
import ssl
3+
import aiomysql
4+
5+
ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
6+
ctx.check_hostname = False
7+
ctx.load_verify_locations(cafile='../tests/ssl_resources/ssl/ca.pem')
8+
9+
10+
async def main():
11+
async with aiomysql.create_pool(
12+
host='localhost', port=3306, user='root',
13+
password='rootpw', ssl=ctx,
14+
auth_plugin='mysql_clear_password') as pool:
15+
16+
async with pool.get() as conn:
17+
async with conn.cursor() as cur:
18+
# Run simple command
19+
await cur.execute("SHOW DATABASES;")
20+
value = await cur.fetchall()
21+
22+
values = [item[0] for item in value]
23+
# Spot check the answers, we should at least have mysql
24+
# and information_schema
25+
assert 'mysql' in values, \
26+
'Could not find the "mysql" table'
27+
assert 'information_schema' in values, \
28+
'Could not find the "mysql" table'
29+
30+
# Check TLS variables
31+
await cur.execute("SHOW STATUS LIKE 'Ssl_version%';")
32+
value = await cur.fetchone()
33+
34+
# The context has TLS
35+
assert value[1].startswith('TLS'), \
36+
'Not connected to the database with TLS'
37+
38+
asyncio.get_event_loop().run_until_complete(main())

0 commit comments

Comments
 (0)