41
41
# from aiomysql.utils import _convert_to_str
42
42
from .cursors import Cursor
43
43
from .utils import _ConnectionContextManager , _ContextManager
44
- # from .log import logger
44
+ from .log import logger
45
45
46
46
47
47
DEFAULT_USER = getpass .getuser ()
@@ -55,7 +55,7 @@ def connect(host="localhost", user=None, password="",
55
55
connect_timeout = None , read_default_group = None ,
56
56
no_delay = None , autocommit = False , echo = False ,
57
57
local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
58
- program_name = '' ):
58
+ program_name = '' , server_public_key = None ):
59
59
"""See connections.Connection.__init__() for information about
60
60
defaults."""
61
61
coro = _connect (host = host , user = user , password = password , db = db ,
@@ -93,7 +93,7 @@ def __init__(self, host="localhost", user=None, password="",
93
93
connect_timeout = None , read_default_group = None ,
94
94
no_delay = None , autocommit = False , echo = False ,
95
95
local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
96
- program_name = '' ):
96
+ program_name = '' , server_public_key = None ):
97
97
"""
98
98
Establish a connection to the MySQL database. Accepts several
99
99
arguments:
@@ -134,6 +134,8 @@ def __init__(self, host="localhost", user=None, password="",
134
134
(default: Server Default)
135
135
:param program_name: Program name string to provide when
136
136
handshaking with MySQL. (default: sys.argv[0])
137
+ :param server_public_key: SHA256 authentication plugin public
138
+ key value.
137
139
:param loop: asyncio loop
138
140
"""
139
141
self ._loop = loop or asyncio .get_event_loop ()
@@ -174,6 +176,8 @@ def __init__(self, host="localhost", user=None, password="",
174
176
self ._client_auth_plugin = auth_plugin
175
177
self ._server_auth_plugin = ""
176
178
self ._auth_plugin_used = ""
179
+ self .server_public_key = server_public_key
180
+ self .salt = None
177
181
178
182
# TODO somehow import version from __init__.py
179
183
self ._connect_attrs = {
@@ -712,6 +716,20 @@ async def _request_authentication(self):
712
716
if auth_plugin in ('' , 'mysql_native_password' ):
713
717
authresp = _auth .scramble_native_password (
714
718
self ._password .encode ('latin1' ), self .salt )
719
+ elif auth_plugin == 'caching_sha2_password' :
720
+ if self ._password :
721
+ authresp = _auth .scramble_caching_sha2 (
722
+ self ._password .encode ('latin1' ), self .salt
723
+ )
724
+ # Else: empty password
725
+ elif auth_plugin == 'sha256_password' :
726
+ if self ._ssl_context and self .server_capabilities & CLIENT .SSL :
727
+ authresp = self ._password .encode ('latin1' ) + b'\0 '
728
+ elif self ._password :
729
+ authresp = b'\1 ' # request public key
730
+ else :
731
+ authresp = b'\0 ' # empty password
732
+
715
733
elif auth_plugin in ('' , 'mysql_clear_password' ):
716
734
authresp = self ._password .encode ('latin1' ) + b'\0 '
717
735
@@ -768,35 +786,174 @@ async def _request_authentication(self):
768
786
auth_packet .read_all ()) + b'\0 '
769
787
self .write_packet (data )
770
788
await self ._read_packet ()
789
+ elif auth_packet .is_extra_auth_data ():
790
+ if auth_plugin == "caching_sha2_password" :
791
+ await self .caching_sha2_password_auth (auth_packet )
792
+ elif auth_plugin == "sha256_password" :
793
+ await self .sha256_password_auth (auth_packet )
794
+ else :
795
+ raise OperationalError ("Received extra packet "
796
+ "for auth method %r" , auth_plugin )
771
797
772
798
async def _process_auth (self , plugin_name , auth_packet ):
773
- if plugin_name == b"mysql_native_password" :
774
- # https://dev.mysql.com/doc/internals/en/
775
- # secure-password-authentication.html#packet-Authentication::
776
- # Native41
777
- data = _auth .scramble_native_password (
778
- self ._password .encode ('latin1' ),
779
- auth_packet .read_all ())
780
- elif plugin_name == b"mysql_old_password" :
781
- # https://dev.mysql.com/doc/internals/en/
782
- # old-password-authentication.html
783
- data = _auth .scramble_old_password (self ._password .encode ('latin1' ),
784
- auth_packet .read_all ()) + b'\0 '
785
- elif plugin_name == b"mysql_clear_password" :
786
- # https://dev.mysql.com/doc/internals/en/
787
- # clear-text-authentication.html
788
- data = self ._password .encode ('latin1' ) + b'\0 '
799
+ # These auth plugins do their own packet handling
800
+ if plugin_name == b"caching_sha2_password" :
801
+ await self .caching_sha2_password_auth (auth_packet )
802
+ self ._auth_plugin_used = plugin_name .decode ()
803
+ elif plugin_name == b"sha256_password" :
804
+ await self .sha256_password_auth (auth_packet )
805
+ self ._auth_plugin_used = plugin_name .decode ()
789
806
else :
807
+
808
+ if plugin_name == b"mysql_native_password" :
809
+ # https://dev.mysql.com/doc/internals/en/
810
+ # secure-password-authentication.html#packet-Authentication::
811
+ # Native41
812
+ data = _auth .scramble_native_password (
813
+ self ._password .encode ('latin1' ),
814
+ auth_packet .read_all ())
815
+ elif plugin_name == b"mysql_old_password" :
816
+ # https://dev.mysql.com/doc/internals/en/
817
+ # old-password-authentication.html
818
+ data = _auth .scramble_old_password (
819
+ self ._password .encode ('latin1' ),
820
+ auth_packet .read_all ()
821
+ ) + b'\0 '
822
+ elif plugin_name == b"mysql_clear_password" :
823
+ # https://dev.mysql.com/doc/internals/en/
824
+ # clear-text-authentication.html
825
+ data = self ._password .encode ('latin1' ) + b'\0 '
826
+ else :
827
+ raise OperationalError (
828
+ 2059 , "Authentication plugin '{0}'"
829
+ " not configured" .format (plugin_name )
830
+ )
831
+
832
+ self .write_packet (data )
833
+ pkt = await self ._read_packet ()
834
+ pkt .check_error ()
835
+
836
+ self ._auth_plugin_used = plugin_name .decode ()
837
+
838
+ return pkt
839
+
840
+ async def caching_sha2_password_auth (self , pkt ):
841
+ # No password fast path
842
+ if not self ._password :
843
+ self .write_packet (b'' )
844
+ pkt = await self ._read_packet ()
845
+ pkt .check_error ()
846
+ return pkt
847
+
848
+ if pkt .is_auth_switch_request ():
849
+ # Try from fast auth
850
+ logger .debug ("caching sha2: Trying fast path" )
851
+ self .salt = pkt .read_all ()
852
+ scrambled = _auth .scramble_caching_sha2 (
853
+ self ._password .encode ('latin1' ), self .salt
854
+ )
855
+
856
+ self .write_packet (scrambled )
857
+ pkt = await self ._read_packet ()
858
+ pkt .check_error ()
859
+
860
+ # else: fast auth is tried in initial handshake
861
+
862
+ if not pkt .is_extra_auth_data ():
790
863
raise OperationalError (
791
- 2059 , "Authentication plugin '%s' not configured" % plugin_name
864
+ "caching sha2: Unknown packet "
865
+ "for fast auth: {0}" .format (pkt ._data [:1 ])
792
866
)
793
867
868
+ # magic numbers:
869
+ # 2 - request public key
870
+ # 3 - fast auth succeeded
871
+ # 4 - need full auth
872
+
873
+ pkt .advance (1 )
874
+ n = pkt .read_uint8 ()
875
+
876
+ if n == 3 :
877
+ logger .debug ("caching sha2: succeeded by fast path." )
878
+ pkt = await self ._read_packet ()
879
+ pkt .check_error () # pkt must be OK packet
880
+ return pkt
881
+
882
+ if n != 4 :
883
+ raise OperationalError ("caching sha2: Unknown "
884
+ "result for fast auth: {0}" .format (n ))
885
+
886
+ logger .debug ("caching sha2: Trying full auth..." )
887
+
888
+ if self ._ssl_context :
889
+ logger .debug ("caching sha2: Sending plain "
890
+ "password via secure connection" )
891
+ self .write_packet (self ._password .encode ('latin1' ) + b'\0 ' )
892
+ pkt = await self ._read_packet ()
893
+ pkt .check_error ()
894
+ return pkt
895
+
896
+ if not self .server_public_key :
897
+ self .write_packet (b'\x02 ' )
898
+ pkt = await self ._read_packet () # Request public key
899
+ pkt .check_error ()
900
+
901
+ if not pkt .is_extra_auth_data ():
902
+ raise OperationalError (
903
+ "caching sha2: Unknown packet "
904
+ "for public key: {0}" .format (pkt ._data [:1 ])
905
+ )
906
+
907
+ self .server_public_key = pkt ._data [1 :]
908
+ logger .debug (self .server_public_key .decode ('ascii' ))
909
+
910
+ data = _auth .sha2_rsa_encrypt (
911
+ self ._password .encode ('latin1' ), self .salt ,
912
+ self .server_public_key
913
+ )
794
914
self .write_packet (data )
795
915
pkt = await self ._read_packet ()
796
916
pkt .check_error ()
797
917
798
- self ._auth_plugin_used = plugin_name
918
+ async def sha256_password_auth (self , pkt ):
919
+ if self ._ssl_context :
920
+ logger .debug ("sha256: Sending plain password" )
921
+ data = self ._password .encode ('latin1' ) + b'\0 '
922
+ self .write_packet (data )
923
+ pkt = await self ._read_packet ()
924
+ pkt .check_error ()
925
+ return pkt
926
+
927
+ if pkt .is_auth_switch_request ():
928
+ self .salt = pkt .read_all ()
929
+ if not self .server_public_key and self ._password :
930
+ # Request server public key
931
+ logger .debug ("sha256: Requesting server public key" )
932
+ self .write_packet (b'\1 ' )
933
+ pkt = await self ._read_packet ()
934
+ pkt .check_error ()
935
+
936
+ if pkt .is_extra_auth_data ():
937
+ self .server_public_key = pkt ._data [1 :]
938
+ logger .debug (
939
+ "Received public key:\n " ,
940
+ self .server_public_key .decode ('ascii' )
941
+ )
942
+
943
+ if self ._password :
944
+ if not self .server_public_key :
945
+ raise OperationalError ("Couldn't receive server's public key" )
946
+
947
+ data = _auth .sha2_rsa_encrypt (
948
+ self ._password .encode ('latin1' ), self .salt ,
949
+ self .server_public_key
950
+ )
951
+ else :
952
+ data = b''
799
953
954
+ self .write_packet (data )
955
+ pkt = await self ._read_packet ()
956
+ pkt .check_error ()
800
957
return pkt
801
958
802
959
# _mysql support
0 commit comments