Skip to content

Commit d69c4a8

Browse files
committed
Provided client connect args
Supplied MAX_PACKET_LEN when handshaking with server For some reason 5.6+ didnt like it when we didnt supply client connect args PyMySQL now likes utf8mb4 not latin1 so that broke a few tests
1 parent 5715f94 commit d69c4a8

File tree

3 files changed

+41
-11
lines changed

3 files changed

+41
-11
lines changed

aiomysql/connection.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from .utils import _ConnectionContextManager, _ContextManager
4444
# from .log import logger
4545

46+
4647
DEFAULT_USER = getpass.getuser()
4748

4849

@@ -90,7 +91,8 @@ def __init__(self, host="localhost", user=None, password="",
9091
client_flag=0, cursorclass=Cursor, init_command=None,
9192
connect_timeout=None, read_default_group=None,
9293
no_delay=None, autocommit=False, echo=False,
93-
local_infile=False, loop=None, ssl=None, auth_plugin=''):
94+
local_infile=False, loop=None, ssl=None, auth_plugin='',
95+
program_name=''):
9496
"""
9597
Establish a connection to the MySQL database. Accepts several
9698
arguments:
@@ -165,6 +167,17 @@ def __init__(self, host="localhost", user=None, password="",
165167
self._server_auth_plugin = ""
166168
self._auth_plugin_used = ""
167169

170+
# TODO somehow import version from __init__.py
171+
self._connect_attrs = {
172+
'_client_name': 'aiomysql',
173+
'_pid': str(os.getpid()),
174+
'_client_version': '0.0.16',
175+
}
176+
if False or program_name:
177+
self._connect_attrs["program_name"] = program_name
178+
elif sys.argv:
179+
self._connect_attrs["program_name"] = sys.argv[0]
180+
168181
self._unix_socket = unix_socket
169182
if charset:
170183
self._charset = charset
@@ -672,8 +685,10 @@ async def _request_authentication(self):
672685
charset_id = charset_by_name(self.charset).id
673686
if isinstance(self.user, str):
674687
_user = self.user.encode(self.encoding)
688+
else:
689+
_user = self.user
675690

676-
data_init = struct.pack('<iIB23s', self.client_flag, 1,
691+
data_init = struct.pack('<iIB23s', self.client_flag, MAX_PACKET_LEN,
677692
charset_id, b'')
678693

679694
data = data_init + _user + b'\0'
@@ -686,8 +701,8 @@ async def _request_authentication(self):
686701
auth_plugin = self._server_auth_plugin
687702

688703
if auth_plugin in ('', 'mysql_native_password'):
689-
authresp = _auth.scramble_native_password(self._password,
690-
self.salt)
704+
authresp = _auth.scramble_native_password(
705+
self._password.encode('latin1'), self.salt)
691706
elif auth_plugin in ('', 'mysql_clear_password'):
692707
authresp = self._password.encode('latin1') + b'\0'
693708

@@ -715,6 +730,15 @@ async def _request_authentication(self):
715730

716731
self._auth_plugin_used = auth_plugin
717732

733+
if self.server_capabilities & CLIENT.CONNECT_ATTRS:
734+
connect_attrs = b''
735+
for k, v in self._connect_attrs.items():
736+
k = k.encode('utf8')
737+
connect_attrs += struct.pack('B', len(k)) + k
738+
v = v.encode('utf8')
739+
connect_attrs += struct.pack('B', len(v)) + v
740+
data += struct.pack('B', len(connect_attrs)) + connect_attrs
741+
718742
self.write_packet(data)
719743
auth_packet = await self._read_packet()
720744

@@ -732,7 +756,8 @@ async def _request_authentication(self):
732756
else:
733757
# send legacy handshake
734758
data = _auth.scramble_old_password(
735-
self._password, auth_packet.read_all()) + b'\0'
759+
self._password.encode('latin1'),
760+
auth_packet.read_all()) + b'\0'
736761
self.write_packet(data)
737762
auth_packet = await self._read_packet()
738763

@@ -741,12 +766,13 @@ async def _process_auth(self, plugin_name, auth_packet):
741766
# https://dev.mysql.com/doc/internals/en/
742767
# secure-password-authentication.html#packet-Authentication::
743768
# Native41
744-
data = _auth.scramble_native_password(self._password,
745-
auth_packet.read_all())
769+
data = _auth.scramble_native_password(
770+
self._password.encode('latin1'),
771+
auth_packet.read_all())
746772
elif plugin_name == b"mysql_old_password":
747773
# https://dev.mysql.com/doc/internals/en/
748774
# old-password-authentication.html
749-
data = _auth.scramble_old_password(self._password,
775+
data = _auth.scramble_old_password(self._password.encode('latin1'),
750776
auth_packet.read_all()) + b'\0'
751777
elif plugin_name == b"mysql_clear_password":
752778
# https://dev.mysql.com/doc/internals/en/

tests/test_basic.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,15 @@ async def f():
2222

2323
@pytest.mark.run_loop
2424
async def test_datatypes(connection, cursor, datatype_table):
25+
encoding = connection.charset
26+
if encoding == 'utf8mb4':
27+
encoding = 'utf8'
28+
2529
# insert values
2630
v = (
2731
True, -3, 123456789012, 5.7, "hello'\" world",
2832
u"Espa\xc3\xb1ol",
29-
"binary\x00data".encode(connection.charset),
33+
"binary\x00data".encode(encoding),
3034
datetime.date(1988, 2, 2),
3135
datetime.datetime.now().replace(microsecond=0),
3236
datetime.timedelta(5, 6), datetime.time(16, 32),

tests/test_connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def test_connection_info_methods(self):
170170
conn = yield from self.connect()
171171
# trhead id is int
172172
self.assertIsInstance(conn.thread_id(), int)
173-
self.assertEqual(conn.character_set_name(), 'latin1')
173+
self.assertIn(conn.character_set_name(), ('latin1', 'utf8mb4'))
174174
self.assertTrue(str(conn.port) in conn.get_host_info())
175175
self.assertIsInstance(conn.get_server_info(), str)
176176
# protocol id is int
@@ -180,7 +180,7 @@ def test_connection_info_methods(self):
180180
@run_until_complete
181181
def test_connection_set_charset(self):
182182
conn = yield from self.connect()
183-
self.assertEqual(conn.character_set_name(), 'latin1')
183+
self.assertIn(conn.character_set_name(), ('latin1', 'utf8mb4'))
184184
yield from conn.set_charset('utf8')
185185
self.assertEqual(conn.character_set_name(), 'utf8')
186186

0 commit comments

Comments
 (0)