2121import logging
2222import os
2323import socket
24+ import struct
2425from threading import Thread
2526from types import TracebackType
26- from typing import Any , Callable , Optional , Union
27+ from typing import Any , Callable , Optional , TYPE_CHECKING , Union
2728
2829import google .auth
2930from google .auth .credentials import Credentials
4445from google .cloud .sql .connector .resolver import DnsResolver
4546from google .cloud .sql .connector .utils import format_database_user
4647from google .cloud .sql .connector .utils import generate_keys
48+ import google .cloud .sql .proto .cloud_sql_metadata_exchange_pb2 as connectorspb
49+
50+ if TYPE_CHECKING :
51+ import ssl
4752
4853logger = logging .getLogger (name = __name__ )
4954
5055ASYNC_DRIVERS = ["asyncpg" ]
5156SERVER_PROXY_PORT = 3307
57+ # the maximum amount of time to wait before aborting a metadata exchange
58+ IO_TIMEOUT = 30
5259_DEFAULT_SCHEME = "https://"
5360_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
5461_SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}"
@@ -391,6 +398,9 @@ async def connect_async(
391398 socket .create_connection ((ip_address , SERVER_PROXY_PORT )),
392399 server_hostname = ip_address ,
393400 )
401+ # Perform Metadata Exchange Protocol
402+ metadata_partial = partial (self .metadata_exchange , sock )
403+ sock = await self ._loop .run_in_executor (None , metadata_partial )
394404 # If this connection was opened using a domain name, then store it
395405 # for later in case we need to forcibly close it on failover.
396406 if conn_info .conn_name .domain_name :
@@ -409,6 +419,86 @@ async def connect_async(
409419 await monitored_cache .force_refresh ()
410420 raise
411421
422+ def metadata_exchange (self , sock : ssl .SSLSocket ) -> ssl .SSLSocket :
423+ """
424+ Sends metadata about the connection prior to the database
425+ protocol taking over.
426+ The exchange consists of four steps:
427+ 1. Prepare a CloudSQLConnectRequest including the socket protocol and
428+ the user agent.
429+ 2. Write the size of the message as a big endian uint32 (4 bytes) to
430+ the server followed by the serialized message. The length does not
431+ include the initial four bytes.
432+ 3. Read a big endian uint32 (4 bytes) from the server. This is the
433+ CloudSQLConnectResponse message length and does not include the
434+ initial four bytes.
435+ 4. Parse the response using the message length in step 3. If the
436+ response is not OK, return the response's error. If there is no error,
437+ the metadata exchange has succeeded and the connection is complete.
438+ Args:
439+ sock (ssl.SSLSocket): The mTLS/SSL socket to perform metadata
440+ exchange on.
441+ Returns:
442+ sock (ssl.SSLSocket): mTLS/SSL socket connected to Cloud SQL Proxy
443+ server.
444+ """
445+ # form metadata exchange request
446+ req = connectorspb .CloudSQLConnectRequest (
447+ user_agent = f"{ self ._client ._user_agent } " , # type: ignore
448+ protocol_type = connectorspb .CloudSQLConnectRequest .TCP ,
449+ )
450+
451+ # set I/O timeout
452+ sock .settimeout (IO_TIMEOUT )
453+
454+ # pack big-endian unsigned integer (4 bytes)
455+ packed_len = struct .pack (">I" , req .ByteSize ())
456+
457+ # send metadata message length and request message
458+ sock .sendall (packed_len + req .SerializeToString ())
459+
460+ # form metadata exchange response
461+ resp = connectorspb .CloudSQLConnectResponse ()
462+
463+ # read metadata message length (4 bytes)
464+ message_len_buffer_size = struct .Struct (">I" ).size
465+ message_len_buffer = b""
466+ while message_len_buffer_size > 0 :
467+ chunk = sock .recv (message_len_buffer_size )
468+ if not chunk :
469+ raise RuntimeError (
470+ "Connection closed while getting metadata exchange length!"
471+ )
472+ message_len_buffer += chunk
473+ message_len_buffer_size -= len (chunk )
474+
475+ (message_len ,) = struct .unpack (">I" , message_len_buffer )
476+
477+ # read metadata exchange message
478+ buffer = b""
479+ while message_len > 0 :
480+ chunk = sock .recv (message_len )
481+ if not chunk :
482+ raise RuntimeError (
483+ "Connection closed while performing metadata exchange!"
484+ )
485+ buffer += chunk
486+ message_len -= len (chunk )
487+
488+ # parse metadata exchange response from buffer
489+ resp .ParseFromString (buffer )
490+
491+ # reset socket back to blocking mode
492+ sock .setblocking (True )
493+
494+ # validate metadata exchange response
495+ if resp .response_code != connectorspb .CloudSQLConnectResponse .OK :
496+ raise ValueError (
497+ f"Metadata Exchange request has failed with error: { resp .error } "
498+ )
499+
500+ return sock
501+
412502 async def _remove_cached (
413503 self , instance_connection_string : str , enable_iam_auth : bool
414504 ) -> None :
0 commit comments