41
41
# from aiomysql.utils import _convert_to_str
42
42
from .cursors import Cursor
43
43
from .utils import _ConnectionContextManager , _ContextManager
44
- # from .log import logger
44
+ from .log import logger
45
45
46
46
47
47
DEFAULT_USER = getpass .getuser ()
@@ -55,7 +55,7 @@ def connect(host="localhost", user=None, password="",
55
55
connect_timeout = None , read_default_group = None ,
56
56
no_delay = None , autocommit = False , echo = False ,
57
57
local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
58
- program_name = '' ):
58
+ program_name = '' , server_public_key = None ):
59
59
"""See connections.Connection.__init__() for information about
60
60
defaults."""
61
61
coro = _connect (host = host , user = user , password = password , db = db ,
@@ -93,7 +93,7 @@ def __init__(self, host="localhost", user=None, password="",
93
93
connect_timeout = None , read_default_group = None ,
94
94
no_delay = None , autocommit = False , echo = False ,
95
95
local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
96
- program_name = '' ):
96
+ program_name = '' , server_public_key = None ):
97
97
"""
98
98
Establish a connection to the MySQL database. Accepts several
99
99
arguments:
@@ -134,6 +134,8 @@ def __init__(self, host="localhost", user=None, password="",
134
134
(default: Server Default)
135
135
:param program_name: Program name string to provide when
136
136
handshaking with MySQL. (default: sys.argv[0])
137
+ :param server_public_key: SHA256 authentication plugin public
138
+ key value.
137
139
:param loop: asyncio loop
138
140
"""
139
141
self ._loop = loop or asyncio .get_event_loop ()
@@ -174,6 +176,8 @@ def __init__(self, host="localhost", user=None, password="",
174
176
self ._client_auth_plugin = auth_plugin
175
177
self ._server_auth_plugin = ""
176
178
self ._auth_plugin_used = ""
179
+ self .server_public_key = server_public_key
180
+ self .salt = None
177
181
178
182
# TODO somehow import version from __init__.py
179
183
self ._connect_attrs = {
@@ -711,6 +715,20 @@ async def _request_authentication(self):
711
715
if auth_plugin in ('' , 'mysql_native_password' ):
712
716
authresp = _auth .scramble_native_password (
713
717
self ._password .encode ('latin1' ), self .salt )
718
+ elif auth_plugin == 'caching_sha2_password' :
719
+ if self ._password :
720
+ authresp = _auth .scramble_caching_sha2 (
721
+ self ._password .encode ('latin1' ), self .salt
722
+ )
723
+ # Else: empty password
724
+ elif auth_plugin == 'sha256_password' :
725
+ if self ._ssl_context and self .server_capabilities & CLIENT .SSL :
726
+ authresp = self ._password .encode ('latin1' ) + b'\0 '
727
+ elif self ._password :
728
+ authresp = b'\1 ' # request public key
729
+ else :
730
+ authresp = b'\0 ' # empty password
731
+
714
732
elif auth_plugin in ('' , 'mysql_clear_password' ):
715
733
authresp = self ._password .encode ('latin1' ) + b'\0 '
716
734
@@ -767,9 +785,21 @@ async def _request_authentication(self):
767
785
auth_packet .read_all ()) + b'\0 '
768
786
self .write_packet (data )
769
787
await self ._read_packet ()
788
+ elif auth_packet .is_extra_auth_data ():
789
+ if auth_plugin == "caching_sha2_password" :
790
+ await self .caching_sha2_password_auth (auth_packet )
791
+ elif auth_plugin == "sha256_password" :
792
+ await self .sha256_password_auth (auth_packet )
793
+ else :
794
+ raise OperationalError ("Received extra packet "
795
+ "for auth method %r" , auth_plugin )
770
796
771
797
async def _process_auth (self , plugin_name , auth_packet ):
772
- if plugin_name == b"mysql_native_password" :
798
+ if plugin_name == b"caching_sha2_password" :
799
+ return self .caching_sha2_password_auth (auth_packet )
800
+ elif plugin_name == b"sha256_password" :
801
+ return self .sha256_password_auth (auth_packet )
802
+ elif plugin_name == b"mysql_native_password" :
773
803
# https://dev.mysql.com/doc/internals/en/
774
804
# secure-password-authentication.html#packet-Authentication::
775
805
# Native41
@@ -798,6 +828,125 @@ async def _process_auth(self, plugin_name, auth_packet):
798
828
799
829
return pkt
800
830
831
+ async def caching_sha2_password_auth (self , pkt ):
832
+ # No password fast path
833
+ if not self ._password :
834
+ self .write_packet (b'' )
835
+ pkt = await self ._read_packet ()
836
+ pkt .check_error ()
837
+ return pkt
838
+
839
+ if pkt .is_auth_switch_request ():
840
+ # Try from fast auth
841
+ logger .debug ("caching sha2: Trying fast path" )
842
+ self .salt = pkt .read_all ()
843
+ scrambled = _auth .scramble_caching_sha2 (
844
+ self ._password .encode ('latin1' ), self .salt
845
+ )
846
+
847
+ self .write_packet (scrambled )
848
+ pkt = await self ._read_packet ()
849
+ pkt .check_error ()
850
+
851
+ # else: fast auth is tried in initial handshake
852
+
853
+ if not pkt .is_extra_auth_data ():
854
+ raise OperationalError (
855
+ "caching sha2: Unknown packet "
856
+ "for fast auth: {0}" .format (pkt ._data [:1 ])
857
+ )
858
+
859
+ # magic numbers:
860
+ # 2 - request public key
861
+ # 3 - fast auth succeeded
862
+ # 4 - need full auth
863
+
864
+ pkt .advance (1 )
865
+ n = pkt .read_uint8 ()
866
+
867
+ if n == 3 :
868
+ logger .debug ("caching sha2: succeeded by fast path." )
869
+ pkt = await self ._read_packet ()
870
+ pkt .check_error () # pkt must be OK packet
871
+ return pkt
872
+
873
+ if n != 4 :
874
+ raise OperationalError ("caching sha2: Unknown "
875
+ "result for fast auth: {0}" .format (n ))
876
+
877
+ logger .debug ("caching sha2: Trying full auth..." )
878
+
879
+ if self ._ssl_context :
880
+ logger .debug ("caching sha2: Sending plain "
881
+ "password via secure connection" )
882
+ self .write_packet (self ._password .encode ('latin1' ) + b'\0 ' )
883
+ pkt = await self ._read_packet ()
884
+ pkt .check_error ()
885
+ return pkt
886
+
887
+ if not self .server_public_key :
888
+ self .write_packet (b'\x02 ' )
889
+ pkt = await self ._read_packet () # Request public key
890
+ pkt .check_error ()
891
+
892
+ if not pkt .is_extra_auth_data ():
893
+ raise OperationalError (
894
+ "caching sha2: Unknown packet "
895
+ "for public key: {0}" .format (pkt ._data [:1 ])
896
+ )
897
+
898
+ self .server_public_key = pkt ._data [1 :]
899
+ logger .debug (self .server_public_key .decode ('ascii' ))
900
+
901
+ data = _auth .sha2_rsa_encrypt (
902
+ self ._password .encode ('latin1' ), self .salt ,
903
+ self .server_public_key
904
+ )
905
+ self .write_packet (data )
906
+ pkt = await self ._read_packet ()
907
+ pkt .check_error ()
908
+
909
+ async def sha256_password_auth (self , pkt ):
910
+ if self ._ssl_context :
911
+ logger .debug ("sha256: Sending plain password" )
912
+ data = self ._password .encode ('latin1' ) + b'\0 '
913
+ self .write_packet (data )
914
+ pkt = await self ._read_packet ()
915
+ pkt .check_error ()
916
+ return pkt
917
+
918
+ if pkt .is_auth_switch_request ():
919
+ self .salt = pkt .read_all ()
920
+ if not self .server_public_key and self ._password :
921
+ # Request server public key
922
+ logger .debug ("sha256: Requesting server public key" )
923
+ self .write_packet (b'\1 ' )
924
+ pkt = await self ._read_packet ()
925
+ pkt .check_error ()
926
+
927
+ if pkt .is_extra_auth_data ():
928
+ self .server_public_key = pkt ._data [1 :]
929
+ logger .debug (
930
+ "Received public key:\n " ,
931
+ self .server_public_key .decode ('ascii' )
932
+ )
933
+
934
+ if self ._password :
935
+ if not self .server_public_key :
936
+ raise OperationalError ("Couldn't receive server's public key" )
937
+
938
+ data = _auth .sha2_rsa_encrypt (
939
+ self ._password .encode ('latin1' ), self .salt ,
940
+ self .server_public_key
941
+ )
942
+ else :
943
+ data = b''
944
+
945
+ self .write_packet (data )
946
+ pkt = await self ._read_packet ()
947
+ pkt .check_error ()
948
+ return pkt
949
+
801
950
# _mysql support
802
951
def thread_id (self ):
803
952
return self .server_thread_id [0 ]
0 commit comments