Skip to content

Commit 19ab680

Browse files
committed
Initial SHA256 Implementation
1 parent 4897f50 commit 19ab680

File tree

3 files changed

+165
-14
lines changed

3 files changed

+165
-14
lines changed

aiomysql/connection.py

Lines changed: 153 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
# from aiomysql.utils import _convert_to_str
4242
from .cursors import Cursor
4343
from .utils import _ConnectionContextManager, _ContextManager
44-
# from .log import logger
44+
from .log import logger
4545

4646

4747
DEFAULT_USER = getpass.getuser()
@@ -55,7 +55,7 @@ def connect(host="localhost", user=None, password="",
5555
connect_timeout=None, read_default_group=None,
5656
no_delay=None, autocommit=False, echo=False,
5757
local_infile=False, loop=None, ssl=None, auth_plugin='',
58-
program_name=''):
58+
program_name='', server_public_key=None):
5959
"""See connections.Connection.__init__() for information about
6060
defaults."""
6161
coro = _connect(host=host, user=user, password=password, db=db,
@@ -93,7 +93,7 @@ def __init__(self, host="localhost", user=None, password="",
9393
connect_timeout=None, read_default_group=None,
9494
no_delay=None, autocommit=False, echo=False,
9595
local_infile=False, loop=None, ssl=None, auth_plugin='',
96-
program_name=''):
96+
program_name='', server_public_key=None):
9797
"""
9898
Establish a connection to the MySQL database. Accepts several
9999
arguments:
@@ -134,6 +134,8 @@ def __init__(self, host="localhost", user=None, password="",
134134
(default: Server Default)
135135
:param program_name: Program name string to provide when
136136
handshaking with MySQL. (default: sys.argv[0])
137+
:param server_public_key: SHA256 authentication plugin public
138+
key value.
137139
:param loop: asyncio loop
138140
"""
139141
self._loop = loop or asyncio.get_event_loop()
@@ -174,6 +176,8 @@ def __init__(self, host="localhost", user=None, password="",
174176
self._client_auth_plugin = auth_plugin
175177
self._server_auth_plugin = ""
176178
self._auth_plugin_used = ""
179+
self.server_public_key = server_public_key
180+
self.salt = None
177181

178182
# TODO somehow import version from __init__.py
179183
self._connect_attrs = {
@@ -711,6 +715,20 @@ async def _request_authentication(self):
711715
if auth_plugin in ('', 'mysql_native_password'):
712716
authresp = _auth.scramble_native_password(
713717
self._password.encode('latin1'), self.salt)
718+
elif auth_plugin == 'caching_sha2_password':
719+
if self._password:
720+
authresp = _auth.scramble_caching_sha2(
721+
self._password.encode('latin1'), self.salt
722+
)
723+
# Else: empty password
724+
elif auth_plugin == 'sha256_password':
725+
if self._ssl_context and self.server_capabilities & CLIENT.SSL:
726+
authresp = self._password.encode('latin1') + b'\0'
727+
elif self._password:
728+
authresp = b'\1' # request public key
729+
else:
730+
authresp = b'\0' # empty password
731+
714732
elif auth_plugin in ('', 'mysql_clear_password'):
715733
authresp = self._password.encode('latin1') + b'\0'
716734

@@ -767,9 +785,21 @@ async def _request_authentication(self):
767785
auth_packet.read_all()) + b'\0'
768786
self.write_packet(data)
769787
await self._read_packet()
788+
elif auth_packet.is_extra_auth_data():
789+
if auth_plugin == "caching_sha2_password":
790+
await self.caching_sha2_password_auth(auth_packet)
791+
elif auth_plugin == "sha256_password":
792+
await self.sha256_password_auth(auth_packet)
793+
else:
794+
raise OperationalError("Received extra packet "
795+
"for auth method %r", auth_plugin)
770796

771797
async def _process_auth(self, plugin_name, auth_packet):
772-
if plugin_name == b"mysql_native_password":
798+
if plugin_name == b"caching_sha2_password":
799+
return self.caching_sha2_password_auth(auth_packet)
800+
elif plugin_name == b"sha256_password":
801+
return self.sha256_password_auth(auth_packet)
802+
elif plugin_name == b"mysql_native_password":
773803
# https://dev.mysql.com/doc/internals/en/
774804
# secure-password-authentication.html#packet-Authentication::
775805
# Native41
@@ -798,6 +828,125 @@ async def _process_auth(self, plugin_name, auth_packet):
798828

799829
return pkt
800830

831+
async def caching_sha2_password_auth(self, pkt):
832+
# No password fast path
833+
if not self._password:
834+
self.write_packet(b'')
835+
pkt = await self._read_packet()
836+
pkt.check_error()
837+
return pkt
838+
839+
if pkt.is_auth_switch_request():
840+
# Try from fast auth
841+
logger.debug("caching sha2: Trying fast path")
842+
self.salt = pkt.read_all()
843+
scrambled = _auth.scramble_caching_sha2(
844+
self._password.encode('latin1'), self.salt
845+
)
846+
847+
self.write_packet(scrambled)
848+
pkt = await self._read_packet()
849+
pkt.check_error()
850+
851+
# else: fast auth is tried in initial handshake
852+
853+
if not pkt.is_extra_auth_data():
854+
raise OperationalError(
855+
"caching sha2: Unknown packet "
856+
"for fast auth: {0}".format(pkt._data[:1])
857+
)
858+
859+
# magic numbers:
860+
# 2 - request public key
861+
# 3 - fast auth succeeded
862+
# 4 - need full auth
863+
864+
pkt.advance(1)
865+
n = pkt.read_uint8()
866+
867+
if n == 3:
868+
logger.debug("caching sha2: succeeded by fast path.")
869+
pkt = await self._read_packet()
870+
pkt.check_error() # pkt must be OK packet
871+
return pkt
872+
873+
if n != 4:
874+
raise OperationalError("caching sha2: Unknown "
875+
"result for fast auth: {0}".format(n))
876+
877+
logger.debug("caching sha2: Trying full auth...")
878+
879+
if self._ssl_context:
880+
logger.debug("caching sha2: Sending plain "
881+
"password via secure connection")
882+
self.write_packet(self._password.encode('latin1') + b'\0')
883+
pkt = await self._read_packet()
884+
pkt.check_error()
885+
return pkt
886+
887+
if not self.server_public_key:
888+
self.write_packet(b'\x02')
889+
pkt = await self._read_packet() # Request public key
890+
pkt.check_error()
891+
892+
if not pkt.is_extra_auth_data():
893+
raise OperationalError(
894+
"caching sha2: Unknown packet "
895+
"for public key: {0}".format(pkt._data[:1])
896+
)
897+
898+
self.server_public_key = pkt._data[1:]
899+
logger.debug(self.server_public_key.decode('ascii'))
900+
901+
data = _auth.sha2_rsa_encrypt(
902+
self._password.encode('latin1'), self.salt,
903+
self.server_public_key
904+
)
905+
self.write_packet(data)
906+
pkt = await self._read_packet()
907+
pkt.check_error()
908+
909+
async def sha256_password_auth(self, pkt):
910+
if self._ssl_context:
911+
logger.debug("sha256: Sending plain password")
912+
data = self._password.encode('latin1') + b'\0'
913+
self.write_packet(data)
914+
pkt = await self._read_packet()
915+
pkt.check_error()
916+
return pkt
917+
918+
if pkt.is_auth_switch_request():
919+
self.salt = pkt.read_all()
920+
if not self.server_public_key and self._password:
921+
# Request server public key
922+
logger.debug("sha256: Requesting server public key")
923+
self.write_packet(b'\1')
924+
pkt = await self._read_packet()
925+
pkt.check_error()
926+
927+
if pkt.is_extra_auth_data():
928+
self.server_public_key = pkt._data[1:]
929+
logger.debug(
930+
"Received public key:\n",
931+
self.server_public_key.decode('ascii')
932+
)
933+
934+
if self._password:
935+
if not self.server_public_key:
936+
raise OperationalError("Couldn't receive server's public key")
937+
938+
data = _auth.sha2_rsa_encrypt(
939+
self._password.encode('latin1'), self.salt,
940+
self.server_public_key
941+
)
942+
else:
943+
data = b''
944+
945+
self.write_packet(data)
946+
pkt = await self._read_packet()
947+
pkt.check_error()
948+
return pkt
949+
801950
# _mysql support
802951
def thread_id(self):
803952
return self.server_thread_id[0]

docs/connection.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ Example::
4747
client_flag=0, cursorclass=Cursor, init_command=None,
4848
connect_timeout=None, read_default_group=None,
4949
no_delay=False, autocommit=False, echo=False,
50-
ssl=None, auth_plugin='', program_name='', loop=None)
50+
ssl=None, auth_plugin='', program_name='',
51+
server_public_key=None, loop=None)
5152

5253
A :ref:`coroutine <coroutine>` that connects to MySQL.
5354

@@ -89,6 +90,7 @@ Example::
8990
(default: Server Default)
9091
:param program_name: Program name string to provide when
9192
handshaking with MySQL. (default: sys.argv[0])
93+
:param server_public_key: SHA256 authenticaiton plugin public key value.
9294
:param loop: asyncio event loop instance or ``None`` for default one.
9395
:returns: :class:`Connection` instance.
9496

tests/conftest.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ def pytest_generate_tests(metafunc):
3535
loop_type = ['asyncio', 'uvloop'] if uvloop else ['asyncio']
3636
metafunc.parametrize("loop_type", loop_type)
3737

38-
# if 'mysql_tag' in metafunc.fixturenames:
39-
# tags = set(metafunc.config.option.mysql_tag)
40-
# if not tags:
41-
# tags = ['5.7']
42-
# elif 'all' in tags:
43-
# tags = ['5.6', '5.7', '8.0']
44-
# else:
45-
# tags = list(tags)
46-
# metafunc.parametrize("mysql_tag", tags, scope='session')
38+
if 'mysql_tag' in metafunc.fixturenames:
39+
# tags = set(metafunc.config.option.mysql_tag)
40+
# if not tags:
41+
# tags = ['5.7']
42+
# elif 'all' in tags:
43+
# tags = ['5.6', '5.7', '8.0']
44+
# else:
45+
# tags = list(tags)
46+
metafunc.parametrize("mysql_tag", ['5.6', '8.0'], scope='session')
4747

4848

4949
# This is here unless someone fixes the generate_tests bit

0 commit comments

Comments
 (0)