Skip to content

Commit 4897f50

Browse files
authored
Merge pull request #309 from terrycain/pymysql_302
PyMySQL 0.9 Update
2 parents 0f9a280 + 2497aec commit 4897f50

File tree

6 files changed

+68
-23
lines changed

6 files changed

+68
-23
lines changed

aiomysql/connection.py

Lines changed: 50 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,8 @@
2626
ProgrammingError)
2727

2828
from pymysql.connections import TEXT_TYPES, MAX_PACKET_LEN, DEFAULT_CHARSET
29-
# from pymysql.connections import dump_packet
30-
from pymysql.connections import _scramble
31-
from pymysql.connections import _scramble_323
29+
from pymysql.connections import _auth
30+
3231
from pymysql.connections import pack_int24
3332

3433
from pymysql.connections import MysqlPacket
@@ -44,6 +43,7 @@
4443
from .utils import _ConnectionContextManager, _ContextManager
4544
# from .log import logger
4645

46+
4747
DEFAULT_USER = getpass.getuser()
4848

4949

@@ -54,7 +54,8 @@ def connect(host="localhost", user=None, password="",
5454
client_flag=0, cursorclass=Cursor, init_command=None,
5555
connect_timeout=None, read_default_group=None,
5656
no_delay=None, autocommit=False, echo=False,
57-
local_infile=False, loop=None, ssl=None, auth_plugin=''):
57+
local_infile=False, loop=None, ssl=None, auth_plugin='',
58+
program_name=''):
5859
"""See connections.Connection.__init__() for information about
5960
defaults."""
6061
coro = _connect(host=host, user=user, password=password, db=db,
@@ -67,7 +68,7 @@ def connect(host="localhost", user=None, password="",
6768
read_default_group=read_default_group,
6869
no_delay=no_delay, autocommit=autocommit, echo=echo,
6970
local_infile=local_infile, loop=loop, ssl=ssl,
70-
auth_plugin=auth_plugin)
71+
auth_plugin=auth_plugin, program_name=program_name)
7172
return _ConnectionContextManager(coro)
7273

7374

@@ -91,7 +92,8 @@ def __init__(self, host="localhost", user=None, password="",
9192
client_flag=0, cursorclass=Cursor, init_command=None,
9293
connect_timeout=None, read_default_group=None,
9394
no_delay=None, autocommit=False, echo=False,
94-
local_infile=False, loop=None, ssl=None, auth_plugin=''):
95+
local_infile=False, loop=None, ssl=None, auth_plugin='',
96+
program_name=''):
9597
"""
9698
Establish a connection to the MySQL database. Accepts several
9799
arguments:
@@ -125,6 +127,13 @@ def __init__(self, host="localhost", user=None, password="",
125127
(default: False)
126128
:param local_infile: boolean to enable the use of LOAD DATA LOCAL
127129
command. (default: False)
130+
:param ssl: Optional SSL Context to force SSL
131+
:param auth_plugin: String to manually specify the authentication
132+
plugin to use, i.e you will want to use mysql_clear_password
133+
when using IAM authentication with Amazon RDS.
134+
(default: Server Default)
135+
:param program_name: Program name string to provide when
136+
handshaking with MySQL. (default: sys.argv[0])
128137
:param loop: asyncio loop
129138
"""
130139
self._loop = loop or asyncio.get_event_loop()
@@ -166,6 +175,17 @@ def __init__(self, host="localhost", user=None, password="",
166175
self._server_auth_plugin = ""
167176
self._auth_plugin_used = ""
168177

178+
# TODO somehow import version from __init__.py
179+
self._connect_attrs = {
180+
'_client_name': 'aiomysql',
181+
'_pid': str(os.getpid()),
182+
'_client_version': '0.0.16',
183+
}
184+
if program_name:
185+
self._connect_attrs["program_name"] = program_name
186+
elif sys.argv:
187+
self._connect_attrs["program_name"] = sys.argv[0]
188+
169189
self._unix_socket = unix_socket
170190
if charset:
171191
self._charset = charset
@@ -673,8 +693,10 @@ async def _request_authentication(self):
673693
charset_id = charset_by_name(self.charset).id
674694
if isinstance(self.user, str):
675695
_user = self.user.encode(self.encoding)
696+
else:
697+
_user = self.user
676698

677-
data_init = struct.pack('<iIB23s', self.client_flag, 1,
699+
data_init = struct.pack('<iIB23s', self.client_flag, MAX_PACKET_LEN,
678700
charset_id, b'')
679701

680702
data = data_init + _user + b'\0'
@@ -687,7 +709,8 @@ async def _request_authentication(self):
687709
auth_plugin = self._server_auth_plugin
688710

689711
if auth_plugin in ('', 'mysql_native_password'):
690-
authresp = _scramble(self._password.encode('latin1'), self.salt)
712+
authresp = _auth.scramble_native_password(
713+
self._password.encode('latin1'), self.salt)
691714
elif auth_plugin in ('', 'mysql_clear_password'):
692715
authresp = self._password.encode('latin1') + b'\0'
693716

@@ -715,6 +738,15 @@ async def _request_authentication(self):
715738

716739
self._auth_plugin_used = auth_plugin
717740

741+
# Sends the server a few pieces of client info
742+
if self.server_capabilities & CLIENT.CONNECT_ATTRS:
743+
connect_attrs = b''
744+
for k, v in self._connect_attrs.items():
745+
k, v = k.encode('utf8'), v.encode('utf8')
746+
connect_attrs += struct.pack('B', len(k)) + k
747+
connect_attrs += struct.pack('B', len(v)) + v
748+
data += struct.pack('B', len(connect_attrs)) + connect_attrs
749+
718750
self.write_packet(data)
719751
auth_packet = await self._read_packet()
720752

@@ -727,27 +759,28 @@ async def _request_authentication(self):
727759
plugin_name = auth_packet.read_string()
728760
if (self.server_capabilities & CLIENT.PLUGIN_AUTH and
729761
plugin_name is not None):
730-
auth_packet = await self._process_auth(
731-
plugin_name, auth_packet)
762+
await self._process_auth(plugin_name, auth_packet)
732763
else:
733764
# send legacy handshake
734-
data = _scramble_323(self._password.encode('latin1'),
735-
self.salt) + b'\0'
765+
data = _auth.scramble_old_password(
766+
self._password.encode('latin1'),
767+
auth_packet.read_all()) + b'\0'
736768
self.write_packet(data)
737-
auth_packet = await self._read_packet()
769+
await self._read_packet()
738770

739771
async def _process_auth(self, plugin_name, auth_packet):
740772
if plugin_name == b"mysql_native_password":
741773
# https://dev.mysql.com/doc/internals/en/
742774
# secure-password-authentication.html#packet-Authentication::
743775
# Native41
744-
data = _scramble(self._password.encode('latin1'),
745-
auth_packet.read_all())
776+
data = _auth.scramble_native_password(
777+
self._password.encode('latin1'),
778+
auth_packet.read_all())
746779
elif plugin_name == b"mysql_old_password":
747780
# https://dev.mysql.com/doc/internals/en/
748781
# old-password-authentication.html
749-
data = _scramble_323(self._password.encode('latin1'),
750-
auth_packet.read_all()) + b'\0'
782+
data = _auth.scramble_old_password(self._password.encode('latin1'),
783+
auth_packet.read_all()) + b'\0'
751784
elif plugin_name == b"mysql_clear_password":
752785
# https://dev.mysql.com/doc/internals/en/
753786
# clear-text-authentication.html

docs/connection.rst

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ Example::
4646
read_default_file=None, conv=decoders, use_unicode=None,
4747
client_flag=0, cursorclass=Cursor, init_command=None,
4848
connect_timeout=None, read_default_group=None,
49-
no_delay=False, autocommit=False, echo=False, loop=None)
49+
no_delay=False, autocommit=False, echo=False,
50+
ssl=None, auth_plugin='', program_name='', loop=None)
5051

5152
A :ref:`coroutine <coroutine>` that connects to MySQL.
5253

@@ -81,6 +82,13 @@ Example::
8182
:param bool no_delay: disable Nagle's algorithm on the socket
8283
:param autocommit: Autocommit mode. None means use server default.
8384
(default: ``False``)
85+
:param ssl: Optional SSL Context to force SSL
86+
:param auth_plugin: String to manually specify the authentication
87+
plugin to use, i.e you will want to use mysql_clear_password
88+
when using IAM authentication with Amazon RDS.
89+
(default: Server Default)
90+
:param program_name: Program name string to provide when
91+
handshaking with MySQL. (default: sys.argv[0])
8492
:param loop: asyncio event loop instance or ``None`` for default one.
8593
:returns: :class:`Connection` instance.
8694

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ ipython==6.4.0
55
pytest==3.6.1
66
pytest-cov==2.5.1
77
pytest-sugar==0.9.1
8-
PyMySQL>=0.7.5,<0.9
8+
PyMySQL>=0.9,<=0.9.2
99
docker==3.3.0
1010
sphinx==1.7.5
1111
sphinxcontrib-asyncio==0.2.0

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from setuptools import setup, find_packages
55

66

7-
install_requires = ['PyMySQL>=0.7.5,<0.9']
7+
install_requires = ['PyMySQL>=0.9,<=0.9.2']
88

99
PY_VER = sys.version_info
1010

tests/test_basic.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,15 @@ async def f():
2222

2323
@pytest.mark.run_loop
2424
async def test_datatypes(connection, cursor, datatype_table):
25+
encoding = connection.charset
26+
if encoding == 'utf8mb4':
27+
encoding = 'utf8'
28+
2529
# insert values
2630
v = (
2731
True, -3, 123456789012, 5.7, "hello'\" world",
2832
u"Espa\xc3\xb1ol",
29-
"binary\x00data".encode(connection.charset),
33+
"binary\x00data".encode(encoding),
3034
datetime.date(1988, 2, 2),
3135
datetime.datetime.now().replace(microsecond=0),
3236
datetime.timedelta(5, 6), datetime.time(16, 32),

tests/test_connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def test_connection_info_methods(self):
170170
conn = yield from self.connect()
171171
# trhead id is int
172172
self.assertIsInstance(conn.thread_id(), int)
173-
self.assertEqual(conn.character_set_name(), 'latin1')
173+
self.assertIn(conn.character_set_name(), ('latin1', 'utf8mb4'))
174174
self.assertTrue(str(conn.port) in conn.get_host_info())
175175
self.assertIsInstance(conn.get_server_info(), str)
176176
# protocol id is int
@@ -180,7 +180,7 @@ def test_connection_info_methods(self):
180180
@run_until_complete
181181
def test_connection_set_charset(self):
182182
conn = yield from self.connect()
183-
self.assertEqual(conn.character_set_name(), 'latin1')
183+
self.assertIn(conn.character_set_name(), ('latin1', 'utf8mb4'))
184184
yield from conn.set_charset('utf8')
185185
self.assertEqual(conn.character_set_name(), 'utf8')
186186

0 commit comments

Comments
 (0)