Skip to content

Commit dd068f4

Browse files
terrycainjettify
authored andcommitted
SSL Support (plus mysql_clear_password plugin for RDS) (#280)
* Added SSL Support again * Issue #265 - _process_auth implementation * Added cleartext plugin test
1 parent ac6267d commit dd068f4

File tree

2 files changed

+134
-10
lines changed

2 files changed

+134
-10
lines changed

aiomysql/connection.py

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def connect(host="localhost", user=None, password="",
5656
client_flag=0, cursorclass=Cursor, init_command=None,
5757
connect_timeout=None, read_default_group=None,
5858
no_delay=None, autocommit=False, echo=False,
59-
local_infile=False, loop=None):
59+
local_infile=False, loop=None, ssl=None, auth_plugin=''):
6060
"""See connections.Connection.__init__() for information about
6161
defaults."""
6262
coro = _connect(host=host, user=user, password=password, db=db,
@@ -68,7 +68,8 @@ def connect(host="localhost", user=None, password="",
6868
connect_timeout=connect_timeout,
6969
read_default_group=read_default_group,
7070
no_delay=no_delay, autocommit=autocommit, echo=echo,
71-
local_infile=local_infile, loop=loop)
71+
local_infile=local_infile, loop=loop, ssl=ssl,
72+
auth_plugin=auth_plugin)
7273
return _ConnectionContextManager(coro)
7374

7475

@@ -93,7 +94,7 @@ def __init__(self, host="localhost", user=None, password="",
9394
client_flag=0, cursorclass=Cursor, init_command=None,
9495
connect_timeout=None, read_default_group=None,
9596
no_delay=None, autocommit=False, echo=False,
96-
local_infile=False, loop=None):
97+
local_infile=False, loop=None, ssl=None, auth_plugin=''):
9798
"""
9899
Establish a connection to the MySQL database. Accepts several
99100
arguments:
@@ -164,6 +165,9 @@ def __init__(self, host="localhost", user=None, password="",
164165
self._no_delay = no_delay
165166
self._echo = echo
166167
self._last_usage = self._loop.time()
168+
self._client_auth_plugin = auth_plugin
169+
self._server_auth_plugin = ""
170+
self._auth_plugin_used = ""
167171

168172
self._unix_socket = unix_socket
169173
if charset:
@@ -176,6 +180,10 @@ def __init__(self, host="localhost", user=None, password="",
176180
if use_unicode is not None:
177181
self.use_unicode = use_unicode
178182

183+
self._ssl_context = ssl
184+
if ssl:
185+
client_flag |= CLIENT.SSL
186+
179187
self._encoding = charset_by_name(self._charset).encoding
180188

181189
if local_infile:
@@ -209,8 +217,6 @@ def __init__(self, host="localhost", user=None, password="",
209217
# user
210218
self._close_reason = None
211219

212-
self._auth_plugin_name = ""
213-
214220
@property
215221
def host(self):
216222
"""MySQL server IP address or name"""
@@ -663,6 +669,31 @@ def _request_authentication(self):
663669
if self.user is None:
664670
raise ValueError("Did not specify a username")
665671

672+
if self._ssl_context:
673+
# capablities, max packet, charset
674+
data = struct.pack('<IIB', self.client_flag, 16777216, 33)
675+
data += b'\x00' * (32 - len(data))
676+
677+
self.write_packet(data)
678+
679+
# Stop sending events to data_received
680+
self._writer.transport.pause_reading()
681+
682+
# Get the raw socket from the transport
683+
raw_sock = self._writer.transport.get_extra_info('socket',
684+
default=None)
685+
if raw_sock is None:
686+
raise RuntimeError("Transport does not expose socket instance")
687+
688+
# MySQL expects TLS negotiation to happen in the middle of a
689+
# TCP connection not at start. Passing in a socket to
690+
# open_connection will cause it to negotiate TLS on an existing
691+
# connection not initiate a new one.
692+
self._reader, self._writer = yield from asyncio.open_connection(
693+
sock=raw_sock, ssl=self._ssl_context, loop=self._loop,
694+
server_hostname=self._host
695+
)
696+
666697
charset_id = charset_by_name(self.charset).id
667698
if isinstance(self.user, str):
668699
_user = self.user.encode(self.encoding)
@@ -673,8 +704,16 @@ def _request_authentication(self):
673704
data = data_init + _user + b'\0'
674705

675706
authresp = b''
676-
if self._auth_plugin_name in ('', 'mysql_native_password'):
707+
708+
auth_plugin = self._client_auth_plugin
709+
if not self._client_auth_plugin:
710+
# Contains the auth plugin from handshake
711+
auth_plugin = self._server_auth_plugin
712+
713+
if auth_plugin in ('', 'mysql_native_password'):
677714
authresp = _scramble(self._password.encode('latin1'), self.salt)
715+
elif auth_plugin in ('', 'mysql_clear_password'):
716+
authresp = self._password.encode('latin1') + b'\0'
678717

679718
if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
680719
data += lenenc_int(len(authresp)) + authresp
@@ -693,11 +732,13 @@ def _request_authentication(self):
693732
data += db + b'\0'
694733

695734
if self.server_capabilities & CLIENT.PLUGIN_AUTH:
696-
name = self._auth_plugin_name
735+
name = auth_plugin
697736
if isinstance(name, str):
698737
name = name.encode('ascii')
699738
data += name + b'\0'
700739

740+
self._auth_plugin_used = auth_plugin
741+
701742
self.write_packet(data)
702743
auth_packet = yield from self._read_packet()
703744

@@ -710,14 +751,45 @@ def _request_authentication(self):
710751
plugin_name = auth_packet.read_string()
711752
if (self.server_capabilities & CLIENT.PLUGIN_AUTH and
712753
plugin_name is not None):
713-
auth_packet = self._process_auth(plugin_name, auth_packet)
754+
auth_packet = yield from self._process_auth(
755+
plugin_name, auth_packet)
714756
else:
715757
# send legacy handshake
716758
data = _scramble_323(self._password.encode('latin1'),
717759
self.salt) + b'\0'
718760
self.write_packet(data)
719761
auth_packet = yield from self._read_packet()
720762

763+
@asyncio.coroutine
764+
def _process_auth(self, plugin_name, auth_packet):
765+
if plugin_name == b"mysql_native_password":
766+
# https://dev.mysql.com/doc/internals/en/
767+
# secure-password-authentication.html#packet-Authentication::
768+
# Native41
769+
data = _scramble(self._password.encode('latin1'),
770+
auth_packet.read_all())
771+
elif plugin_name == b"mysql_old_password":
772+
# https://dev.mysql.com/doc/internals/en/
773+
# old-password-authentication.html
774+
data = _scramble_323(self._password.encode('latin1'),
775+
auth_packet.read_all()) + b'\0'
776+
elif plugin_name == b"mysql_clear_password":
777+
# https://dev.mysql.com/doc/internals/en/
778+
# clear-text-authentication.html
779+
data = self._password.encode('latin1') + b'\0'
780+
else:
781+
raise OperationalError(
782+
2059, "Authentication plugin '%s' not configured" % plugin_name
783+
)
784+
785+
self.write_packet(data)
786+
pkt = yield from self._read_packet()
787+
pkt.check_error()
788+
789+
self._auth_plugin_used = plugin_name
790+
791+
return pkt
792+
721793
# _mysql support
722794
def thread_id(self):
723795
return self.server_thread_id[0]
@@ -786,9 +858,9 @@ def _get_server_information(self):
786858
server_end = data.find(b'\0', i)
787859
if server_end < 0: # pragma: no cover - very specific upstream bug
788860
# not found \0 and last field so take it all
789-
self._auth_plugin_name = data[i:].decode('latin1')
861+
self._server_auth_plugin = data[i:].decode('latin1')
790862
else:
791-
self._auth_plugin_name = data[i:server_end].decode('latin1')
863+
self._server_auth_plugin = data[i:server_end].decode('latin1')
792864

793865
def get_transaction_status(self):
794866
return bool(self.server_status & SERVER_STATUS.SERVER_STATUS_IN_TRANS)

tests/test_ssl.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from aiomysql import create_pool
2+
3+
import pytest
4+
5+
6+
@pytest.mark.run_loop
7+
async def test_tls_connect(mysql_server, loop):
8+
async with create_pool(**mysql_server['conn_params'],
9+
loop=loop) as pool:
10+
async with pool.get() as conn:
11+
async with conn.cursor() as cur:
12+
# Run simple command
13+
await cur.execute("SHOW DATABASES;")
14+
value = await cur.fetchall()
15+
16+
values = [item[0] for item in value]
17+
# Spot check the answers, we should at least have mysql
18+
# and information_schema
19+
assert 'mysql' in values, \
20+
'Could not find the "mysql" table'
21+
assert 'information_schema' in values, \
22+
'Could not find the "mysql" table'
23+
24+
# Check TLS variables
25+
await cur.execute("SHOW STATUS LIKE '%Ssl_version%';")
26+
value = await cur.fetchone()
27+
28+
# The context has TLS
29+
assert value[1].startswith('TLS'), \
30+
'Not connected to the database with TLS'
31+
32+
33+
# MySQL will get you to renegotiate if sent a cleartext password
34+
@pytest.mark.run_loop
35+
async def test_auth_plugin_renegotiation(mysql_server, loop):
36+
async with create_pool(**mysql_server['conn_params'],
37+
auth_plugin='mysql_clear_password',
38+
loop=loop) as pool:
39+
async with pool.get() as conn:
40+
async with conn.cursor() as cur:
41+
# Run simple command
42+
await cur.execute("SHOW DATABASES;")
43+
value = await cur.fetchall()
44+
45+
assert len(value), 'No databases found'
46+
47+
assert conn._client_auth_plugin == 'mysql_clear_password', \
48+
'Client did not try clear password auth'
49+
assert conn._server_auth_plugin == 'mysql_native_password', \
50+
'Server did not ask for native auth'
51+
assert conn._auth_plugin_used == b'mysql_native_password', \
52+
'Client did not renegotiate with native auth'

0 commit comments

Comments
 (0)