@@ -56,7 +56,7 @@ def connect(host="localhost", user=None, password="",
56
56
client_flag = 0 , cursorclass = Cursor , init_command = None ,
57
57
connect_timeout = None , read_default_group = None ,
58
58
no_delay = None , autocommit = False , echo = False ,
59
- local_infile = False , loop = None ):
59
+ local_infile = False , loop = None , ssl = None , auth_plugin = '' ):
60
60
"""See connections.Connection.__init__() for information about
61
61
defaults."""
62
62
coro = _connect (host = host , user = user , password = password , db = db ,
@@ -68,7 +68,8 @@ def connect(host="localhost", user=None, password="",
68
68
connect_timeout = connect_timeout ,
69
69
read_default_group = read_default_group ,
70
70
no_delay = no_delay , autocommit = autocommit , echo = echo ,
71
- local_infile = local_infile , loop = loop )
71
+ local_infile = local_infile , loop = loop , ssl = ssl ,
72
+ auth_plugin = auth_plugin )
72
73
return _ConnectionContextManager (coro )
73
74
74
75
@@ -93,7 +94,7 @@ def __init__(self, host="localhost", user=None, password="",
93
94
client_flag = 0 , cursorclass = Cursor , init_command = None ,
94
95
connect_timeout = None , read_default_group = None ,
95
96
no_delay = None , autocommit = False , echo = False ,
96
- local_infile = False , loop = None ):
97
+ local_infile = False , loop = None , ssl = None , auth_plugin = '' ):
97
98
"""
98
99
Establish a connection to the MySQL database. Accepts several
99
100
arguments:
@@ -164,6 +165,9 @@ def __init__(self, host="localhost", user=None, password="",
164
165
self ._no_delay = no_delay
165
166
self ._echo = echo
166
167
self ._last_usage = self ._loop .time ()
168
+ self ._client_auth_plugin = auth_plugin
169
+ self ._server_auth_plugin = ""
170
+ self ._auth_plugin_used = ""
167
171
168
172
self ._unix_socket = unix_socket
169
173
if charset :
@@ -176,6 +180,10 @@ def __init__(self, host="localhost", user=None, password="",
176
180
if use_unicode is not None :
177
181
self .use_unicode = use_unicode
178
182
183
+ self ._ssl_context = ssl
184
+ if ssl :
185
+ client_flag |= CLIENT .SSL
186
+
179
187
self ._encoding = charset_by_name (self ._charset ).encoding
180
188
181
189
if local_infile :
@@ -209,8 +217,6 @@ def __init__(self, host="localhost", user=None, password="",
209
217
# user
210
218
self ._close_reason = None
211
219
212
- self ._auth_plugin_name = ""
213
-
214
220
@property
215
221
def host (self ):
216
222
"""MySQL server IP address or name"""
@@ -663,6 +669,31 @@ def _request_authentication(self):
663
669
if self .user is None :
664
670
raise ValueError ("Did not specify a username" )
665
671
672
+ if self ._ssl_context :
673
+ # capablities, max packet, charset
674
+ data = struct .pack ('<IIB' , self .client_flag , 16777216 , 33 )
675
+ data += b'\x00 ' * (32 - len (data ))
676
+
677
+ self .write_packet (data )
678
+
679
+ # Stop sending events to data_received
680
+ self ._writer .transport .pause_reading ()
681
+
682
+ # Get the raw socket from the transport
683
+ raw_sock = self ._writer .transport .get_extra_info ('socket' ,
684
+ default = None )
685
+ if raw_sock is None :
686
+ raise RuntimeError ("Transport does not expose socket instance" )
687
+
688
+ # MySQL expects TLS negotiation to happen in the middle of a
689
+ # TCP connection not at start. Passing in a socket to
690
+ # open_connection will cause it to negotiate TLS on an existing
691
+ # connection not initiate a new one.
692
+ self ._reader , self ._writer = yield from asyncio .open_connection (
693
+ sock = raw_sock , ssl = self ._ssl_context , loop = self ._loop ,
694
+ server_hostname = self ._host
695
+ )
696
+
666
697
charset_id = charset_by_name (self .charset ).id
667
698
if isinstance (self .user , str ):
668
699
_user = self .user .encode (self .encoding )
@@ -673,8 +704,16 @@ def _request_authentication(self):
673
704
data = data_init + _user + b'\0 '
674
705
675
706
authresp = b''
676
- if self ._auth_plugin_name in ('' , 'mysql_native_password' ):
707
+
708
+ auth_plugin = self ._client_auth_plugin
709
+ if not self ._client_auth_plugin :
710
+ # Contains the auth plugin from handshake
711
+ auth_plugin = self ._server_auth_plugin
712
+
713
+ if auth_plugin in ('' , 'mysql_native_password' ):
677
714
authresp = _scramble (self ._password .encode ('latin1' ), self .salt )
715
+ elif auth_plugin in ('' , 'mysql_clear_password' ):
716
+ authresp = self ._password .encode ('latin1' ) + b'\0 '
678
717
679
718
if self .server_capabilities & CLIENT .PLUGIN_AUTH_LENENC_CLIENT_DATA :
680
719
data += lenenc_int (len (authresp )) + authresp
@@ -693,11 +732,13 @@ def _request_authentication(self):
693
732
data += db + b'\0 '
694
733
695
734
if self .server_capabilities & CLIENT .PLUGIN_AUTH :
696
- name = self . _auth_plugin_name
735
+ name = auth_plugin
697
736
if isinstance (name , str ):
698
737
name = name .encode ('ascii' )
699
738
data += name + b'\0 '
700
739
740
+ self ._auth_plugin_used = auth_plugin
741
+
701
742
self .write_packet (data )
702
743
auth_packet = yield from self ._read_packet ()
703
744
@@ -710,14 +751,45 @@ def _request_authentication(self):
710
751
plugin_name = auth_packet .read_string ()
711
752
if (self .server_capabilities & CLIENT .PLUGIN_AUTH and
712
753
plugin_name is not None ):
713
- auth_packet = self ._process_auth (plugin_name , auth_packet )
754
+ auth_packet = yield from self ._process_auth (
755
+ plugin_name , auth_packet )
714
756
else :
715
757
# send legacy handshake
716
758
data = _scramble_323 (self ._password .encode ('latin1' ),
717
759
self .salt ) + b'\0 '
718
760
self .write_packet (data )
719
761
auth_packet = yield from self ._read_packet ()
720
762
763
+ @asyncio .coroutine
764
+ def _process_auth (self , plugin_name , auth_packet ):
765
+ if plugin_name == b"mysql_native_password" :
766
+ # https://dev.mysql.com/doc/internals/en/
767
+ # secure-password-authentication.html#packet-Authentication::
768
+ # Native41
769
+ data = _scramble (self ._password .encode ('latin1' ),
770
+ auth_packet .read_all ())
771
+ elif plugin_name == b"mysql_old_password" :
772
+ # https://dev.mysql.com/doc/internals/en/
773
+ # old-password-authentication.html
774
+ data = _scramble_323 (self ._password .encode ('latin1' ),
775
+ auth_packet .read_all ()) + b'\0 '
776
+ elif plugin_name == b"mysql_clear_password" :
777
+ # https://dev.mysql.com/doc/internals/en/
778
+ # clear-text-authentication.html
779
+ data = self ._password .encode ('latin1' ) + b'\0 '
780
+ else :
781
+ raise OperationalError (
782
+ 2059 , "Authentication plugin '%s' not configured" % plugin_name
783
+ )
784
+
785
+ self .write_packet (data )
786
+ pkt = yield from self ._read_packet ()
787
+ pkt .check_error ()
788
+
789
+ self ._auth_plugin_used = plugin_name
790
+
791
+ return pkt
792
+
721
793
# _mysql support
722
794
def thread_id (self ):
723
795
return self .server_thread_id [0 ]
@@ -786,9 +858,9 @@ def _get_server_information(self):
786
858
server_end = data .find (b'\0 ' , i )
787
859
if server_end < 0 : # pragma: no cover - very specific upstream bug
788
860
# not found \0 and last field so take it all
789
- self ._auth_plugin_name = data [i :].decode ('latin1' )
861
+ self ._server_auth_plugin = data [i :].decode ('latin1' )
790
862
else :
791
- self ._auth_plugin_name = data [i :server_end ].decode ('latin1' )
863
+ self ._server_auth_plugin = data [i :server_end ].decode ('latin1' )
792
864
793
865
def get_transaction_status (self ):
794
866
return bool (self .server_status & SERVER_STATUS .SERVER_STATUS_IN_TRANS )
0 commit comments