@@ -53,7 +53,7 @@ def connect(host="localhost", user=None, password="",
5353 connect_timeout = None , read_default_group = None ,
5454 autocommit = False , echo = False ,
5555 local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
56- program_name = '' , server_public_key = None ):
56+ program_name = '' , server_public_key = None , implicit_tls = False ):
5757 """See connections.Connection.__init__() for information about
5858 defaults."""
5959 coro = _connect (host = host , user = user , password = password , db = db ,
@@ -66,7 +66,8 @@ def connect(host="localhost", user=None, password="",
6666 read_default_group = read_default_group ,
6767 autocommit = autocommit , echo = echo ,
6868 local_infile = local_infile , loop = loop , ssl = ssl ,
69- auth_plugin = auth_plugin , program_name = program_name )
69+ auth_plugin = auth_plugin , program_name = program_name ,
70+ implicit_tls = implicit_tls )
7071 return _ConnectionContextManager (coro )
7172
7273
@@ -142,7 +143,7 @@ def __init__(self, host="localhost", user=None, password="",
142143 connect_timeout = None , read_default_group = None ,
143144 autocommit = False , echo = False ,
144145 local_infile = False , loop = None , ssl = None , auth_plugin = '' ,
145- program_name = '' , server_public_key = None ):
146+ program_name = '' , server_public_key = None , implicit_tls = False ):
146147 """
147148 Establish a connection to the MySQL database. Accepts several
148149 arguments:
@@ -184,6 +185,9 @@ def __init__(self, host="localhost", user=None, password="",
184185 handshaking with MySQL. (omitted by default)
185186 :param server_public_key: SHA256 authentication plugin public
186187 key value.
188+ :param implicit_tls: Establish TLS immediately, skipping non-TLS
189+ preamble before upgrading to TLS.
190+ (default: False)
187191 :param loop: asyncio loop
188192 """
189193 self ._loop = loop or asyncio .get_event_loop ()
@@ -218,6 +222,7 @@ def __init__(self, host="localhost", user=None, password="",
218222 self ._auth_plugin_used = ""
219223 self ._secure = False
220224 self .server_public_key = server_public_key
225+ self ._implicit_tls = implicit_tls
221226 self .salt = None
222227
223228 from . import __version__
@@ -241,7 +246,7 @@ def __init__(self, host="localhost", user=None, password="",
241246 self .use_unicode = use_unicode
242247
243248 self ._ssl_context = ssl
244- if ssl :
249+ if ssl and not implicit_tls :
245250 client_flag |= CLIENT .SSL
246251
247252 self ._encoding = charset_by_name (self ._charset ).encoding
@@ -536,7 +541,8 @@ async def _connect(self):
536541
537542 self ._next_seq_id = 0
538543
539- await self ._get_server_information ()
544+ if not self ._implicit_tls :
545+ await self ._get_server_information ()
540546 await self ._request_authentication ()
541547
542548 self .connected_time = self ._loop .time ()
@@ -727,7 +733,8 @@ async def _execute_command(self, command, sql):
727733
728734 async def _request_authentication (self ):
729735 # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
730- if int (self .server_version .split ('.' , 1 )[0 ]) >= 5 :
736+ # FIXME: change this before merge
737+ if self ._implicit_tls or int (self .server_version .split ('.' , 1 )[0 ]) >= 5 :
731738 self .client_flag |= CLIENT .MULTI_RESULTS
732739
733740 if self .user is None :
@@ -737,8 +744,10 @@ async def _request_authentication(self):
737744 data_init = struct .pack ('<iIB23s' , self .client_flag , MAX_PACKET_LEN ,
738745 charset_id , b'' )
739746
740- if self ._ssl_context and self .server_capabilities & CLIENT .SSL :
741- self .write_packet (data_init )
747+ if self ._ssl_context and \
748+ (self ._implicit_tls or self .server_capabilities & CLIENT .SSL ):
749+ if not self ._implicit_tls :
750+ self .write_packet (data_init )
742751
743752 # Stop sending events to data_received
744753 self ._writer .transport .pause_reading ()
@@ -760,6 +769,9 @@ async def _request_authentication(self):
760769 server_hostname = self ._host
761770 )
762771
772+ if self ._implicit_tls :
773+ await self ._get_server_information ()
774+
763775 self ._secure = True
764776
765777 if isinstance (self .user , str ):
0 commit comments