26
26
ProgrammingError )
27
27
28
28
from pymysql .connections import TEXT_TYPES , MAX_PACKET_LEN , DEFAULT_CHARSET
29
- # from pymysql.connections import dump_packet
30
- from pymysql .connections import _scramble
31
- from pymysql .connections import _scramble_323
29
+ from pymysql .connections import _auth
30
+
32
31
from pymysql .connections import pack_int24
33
32
34
33
from pymysql .connections import MysqlPacket
44
43
from .utils import _ConnectionContextManager , _ContextManager
45
44
# from .log import logger
46
45
46
+
47
47
DEFAULT_USER = getpass .getuser ()
48
48
49
49
@@ -54,7 +54,8 @@ def connect(host="localhost", user=None, password="",
54
54
client_flag = 0 , cursorclass = Cursor , init_command = None ,
55
55
connect_timeout = None , read_default_group = None ,
56
56
no_delay = None , autocommit = False , echo = False ,
57
- local_infile = False , loop = None , ssl = None , auth_plugin = '' ):
57
+ local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
58
+ program_name = '' ):
58
59
"""See connections.Connection.__init__() for information about
59
60
defaults."""
60
61
coro = _connect (host = host , user = user , password = password , db = db ,
@@ -67,7 +68,7 @@ def connect(host="localhost", user=None, password="",
67
68
read_default_group = read_default_group ,
68
69
no_delay = no_delay , autocommit = autocommit , echo = echo ,
69
70
local_infile = local_infile , loop = loop , ssl = ssl ,
70
- auth_plugin = auth_plugin )
71
+ auth_plugin = auth_plugin , program_name = program_name )
71
72
return _ConnectionContextManager (coro )
72
73
73
74
@@ -91,7 +92,8 @@ def __init__(self, host="localhost", user=None, password="",
91
92
client_flag = 0 , cursorclass = Cursor , init_command = None ,
92
93
connect_timeout = None , read_default_group = None ,
93
94
no_delay = None , autocommit = False , echo = False ,
94
- local_infile = False , loop = None , ssl = None , auth_plugin = '' ):
95
+ local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
96
+ program_name = '' ):
95
97
"""
96
98
Establish a connection to the MySQL database. Accepts several
97
99
arguments:
@@ -125,6 +127,13 @@ def __init__(self, host="localhost", user=None, password="",
125
127
(default: False)
126
128
:param local_infile: boolean to enable the use of LOAD DATA LOCAL
127
129
command. (default: False)
130
+ :param ssl: Optional SSL Context to force SSL
131
+ :param auth_plugin: String to manually specify the authentication
132
+ plugin to use, i.e you will want to use mysql_clear_password
133
+ when using IAM authentication with Amazon RDS.
134
+ (default: Server Default)
135
+ :param program_name: Program name string to provide when
136
+ handshaking with MySQL. (default: sys.argv[0])
128
137
:param loop: asyncio loop
129
138
"""
130
139
self ._loop = loop or asyncio .get_event_loop ()
@@ -166,6 +175,17 @@ def __init__(self, host="localhost", user=None, password="",
166
175
self ._server_auth_plugin = ""
167
176
self ._auth_plugin_used = ""
168
177
178
+ # TODO somehow import version from __init__.py
179
+ self ._connect_attrs = {
180
+ '_client_name' : 'aiomysql' ,
181
+ '_pid' : str (os .getpid ()),
182
+ '_client_version' : '0.0.16' ,
183
+ }
184
+ if program_name :
185
+ self ._connect_attrs ["program_name" ] = program_name
186
+ elif sys .argv :
187
+ self ._connect_attrs ["program_name" ] = sys .argv [0 ]
188
+
169
189
self ._unix_socket = unix_socket
170
190
if charset :
171
191
self ._charset = charset
@@ -673,8 +693,10 @@ async def _request_authentication(self):
673
693
charset_id = charset_by_name (self .charset ).id
674
694
if isinstance (self .user , str ):
675
695
_user = self .user .encode (self .encoding )
696
+ else :
697
+ _user = self .user
676
698
677
- data_init = struct .pack ('<iIB23s' , self .client_flag , 1 ,
699
+ data_init = struct .pack ('<iIB23s' , self .client_flag , MAX_PACKET_LEN ,
678
700
charset_id , b'' )
679
701
680
702
data = data_init + _user + b'\0 '
@@ -687,7 +709,8 @@ async def _request_authentication(self):
687
709
auth_plugin = self ._server_auth_plugin
688
710
689
711
if auth_plugin in ('' , 'mysql_native_password' ):
690
- authresp = _scramble (self ._password .encode ('latin1' ), self .salt )
712
+ authresp = _auth .scramble_native_password (
713
+ self ._password .encode ('latin1' ), self .salt )
691
714
elif auth_plugin in ('' , 'mysql_clear_password' ):
692
715
authresp = self ._password .encode ('latin1' ) + b'\0 '
693
716
@@ -715,6 +738,15 @@ async def _request_authentication(self):
715
738
716
739
self ._auth_plugin_used = auth_plugin
717
740
741
+ # Sends the server a few pieces of client info
742
+ if self .server_capabilities & CLIENT .CONNECT_ATTRS :
743
+ connect_attrs = b''
744
+ for k , v in self ._connect_attrs .items ():
745
+ k , v = k .encode ('utf8' ), v .encode ('utf8' )
746
+ connect_attrs += struct .pack ('B' , len (k )) + k
747
+ connect_attrs += struct .pack ('B' , len (v )) + v
748
+ data += struct .pack ('B' , len (connect_attrs )) + connect_attrs
749
+
718
750
self .write_packet (data )
719
751
auth_packet = await self ._read_packet ()
720
752
@@ -727,27 +759,28 @@ async def _request_authentication(self):
727
759
plugin_name = auth_packet .read_string ()
728
760
if (self .server_capabilities & CLIENT .PLUGIN_AUTH and
729
761
plugin_name is not None ):
730
- auth_packet = await self ._process_auth (
731
- plugin_name , auth_packet )
762
+ await self ._process_auth (plugin_name , auth_packet )
732
763
else :
733
764
# send legacy handshake
734
- data = _scramble_323 (self ._password .encode ('latin1' ),
735
- self .salt ) + b'\0 '
765
+ data = _auth .scramble_old_password (
766
+ self ._password .encode ('latin1' ),
767
+ auth_packet .read_all ()) + b'\0 '
736
768
self .write_packet (data )
737
- auth_packet = await self ._read_packet ()
769
+ await self ._read_packet ()
738
770
739
771
async def _process_auth (self , plugin_name , auth_packet ):
740
772
if plugin_name == b"mysql_native_password" :
741
773
# https://dev.mysql.com/doc/internals/en/
742
774
# secure-password-authentication.html#packet-Authentication::
743
775
# Native41
744
- data = _scramble (self ._password .encode ('latin1' ),
745
- auth_packet .read_all ())
776
+ data = _auth .scramble_native_password (
777
+ self ._password .encode ('latin1' ),
778
+ auth_packet .read_all ())
746
779
elif plugin_name == b"mysql_old_password" :
747
780
# https://dev.mysql.com/doc/internals/en/
748
781
# old-password-authentication.html
749
- data = _scramble_323 (self ._password .encode ('latin1' ),
750
- auth_packet .read_all ()) + b'\0 '
782
+ data = _auth . scramble_old_password (self ._password .encode ('latin1' ),
783
+ auth_packet .read_all ()) + b'\0 '
751
784
elif plugin_name == b"mysql_clear_password" :
752
785
# https://dev.mysql.com/doc/internals/en/
753
786
# clear-text-authentication.html
0 commit comments