Skip to content

Commit f87cb73

Browse files
feat: add metadata exchange to Python
1 parent e465eea commit f87cb73

File tree

3 files changed

+204
-1
lines changed

3 files changed

+204
-1
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 91 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@
2121
import logging
2222
import os
2323
import socket
24+
import struct
2425
from threading import Thread
2526
from types import TracebackType
26-
from typing import Any, Callable, Optional, Union
27+
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
2728

2829
import google.auth
2930
from google.auth.credentials import Credentials
@@ -44,11 +45,17 @@
4445
from google.cloud.sql.connector.resolver import DnsResolver
4546
from google.cloud.sql.connector.utils import format_database_user
4647
from google.cloud.sql.connector.utils import generate_keys
48+
import google.cloud.sql.proto.cloud_sql_metadata_exchange_pb2 as connectorspb
49+
50+
if TYPE_CHECKING:
51+
import ssl
4752

4853
logger = logging.getLogger(name=__name__)
4954

5055
ASYNC_DRIVERS = ["asyncpg"]
5156
SERVER_PROXY_PORT = 3307
57+
# the maximum amount of time to wait before aborting a metadata exchange
58+
IO_TIMEOUT = 30
5259
_DEFAULT_SCHEME = "https://"
5360
_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
5461
_SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}"
@@ -391,6 +398,9 @@ async def connect_async(
391398
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
392399
server_hostname=ip_address,
393400
)
401+
# Perform Metadata Exchange Protocol
402+
metadata_partial = partial(self.metadata_exchange, sock)
403+
sock = await self._loop.run_in_executor(None, metadata_partial)
394404
# If this connection was opened using a domain name, then store it
395405
# for later in case we need to forcibly close it on failover.
396406
if conn_info.conn_name.domain_name:
@@ -409,6 +419,86 @@ async def connect_async(
409419
await monitored_cache.force_refresh()
410420
raise
411421

422+
def metadata_exchange(self, sock: ssl.SSLSocket) -> ssl.SSLSocket:
423+
"""
424+
Sends metadata about the connection prior to the database
425+
protocol taking over.
426+
The exchange consists of four steps:
427+
1. Prepare a CloudSQLConnectRequest including the socket protocol and
428+
the user agent.
429+
2. Write the size of the message as a big endian uint32 (4 bytes) to
430+
the server followed by the serialized message. The length does not
431+
include the initial four bytes.
432+
3. Read a big endian uint32 (4 bytes) from the server. This is the
433+
CloudSQLConnectResponse message length and does not include the
434+
initial four bytes.
435+
4. Parse the response using the message length in step 3. If the
436+
response is not OK, return the response's error. If there is no error,
437+
the metadata exchange has succeeded and the connection is complete.
438+
Args:
439+
sock (ssl.SSLSocket): The mTLS/SSL socket to perform metadata
440+
exchange on.
441+
Returns:
442+
sock (ssl.SSLSocket): mTLS/SSL socket connected to Cloud SQL Proxy
443+
server.
444+
"""
445+
# form metadata exchange request
446+
req = connectorspb.CloudSQLConnectRequest(
447+
user_agent=f"{self._client._user_agent}", # type: ignore
448+
protocol_type=connectorspb.CloudSQLConnectRequest.TCP,
449+
)
450+
451+
# set I/O timeout
452+
sock.settimeout(IO_TIMEOUT)
453+
454+
# pack big-endian unsigned integer (4 bytes)
455+
packed_len = struct.pack(">I", req.ByteSize())
456+
457+
# send metadata message length and request message
458+
sock.sendall(packed_len + req.SerializeToString())
459+
460+
# form metadata exchange response
461+
resp = connectorspb.CloudSQLConnectResponse()
462+
463+
# read metadata message length (4 bytes)
464+
message_len_buffer_size = struct.Struct(">I").size
465+
message_len_buffer = b""
466+
while message_len_buffer_size > 0:
467+
chunk = sock.recv(message_len_buffer_size)
468+
if not chunk:
469+
raise RuntimeError(
470+
"Connection closed while getting metadata exchange length!"
471+
)
472+
message_len_buffer += chunk
473+
message_len_buffer_size -= len(chunk)
474+
475+
(message_len,) = struct.unpack(">I", message_len_buffer)
476+
477+
# read metadata exchange message
478+
buffer = b""
479+
while message_len > 0:
480+
chunk = sock.recv(message_len)
481+
if not chunk:
482+
raise RuntimeError(
483+
"Connection closed while performing metadata exchange!"
484+
)
485+
buffer += chunk
486+
message_len -= len(chunk)
487+
488+
# parse metadata exchange response from buffer
489+
resp.ParseFromString(buffer)
490+
491+
# reset socket back to blocking mode
492+
sock.setblocking(True)
493+
494+
# validate metadata exchange response
495+
if resp.response_code != connectorspb.CloudSQLConnectResponse.OK:
496+
raise ValueError(
497+
f"Metadata Exchange request has failed with error: {resp.error}"
498+
)
499+
500+
return sock
501+
412502
async def _remove_cached(
413503
self, instance_connection_string: str, enable_iam_auth: bool
414504
) -> None:
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# -*- coding: utf-8 -*-
16+
# Generated by the protocol buffer compiler. DO NOT EDIT!
17+
# source: google/cloud/sql/v1beta4/cloud_sql_metadata_exchange.proto
18+
"""Generated protocol buffer code."""
19+
from google.protobuf import descriptor as _descriptor
20+
from google.protobuf import descriptor_pool as _descriptor_pool
21+
from google.protobuf import symbol_database as _symbol_database
22+
from google.protobuf.internal import builder as _builder
23+
24+
# @@protoc_insertion_point(imports)
25+
26+
_sym_db = _symbol_database.Default()
27+
28+
29+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
30+
b'\n:google/cloud/sql/v1beta4/cloud_sql_metadata_exchange.proto\x12\x18google.cloud.sql.v1beta4"\x90\x02\n\x16\x43loudSQLConnectRequest\x12\x14\n\x07version\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nuser_agent\x18\x02 \x01(\tH\x01\x88\x01\x01\x12Y\n\rprotocol_type\x18\x03 \x01(\x0e\x32=.google.cloud.sql.v1beta4.CloudSQLConnectRequest.ProtocolTypeH\x02\x88\x01\x01"?\n\x0cProtocolType\x12\x1d\n\x19PROTOCOL_TYPE_UNSPECIFIED\x10\x00\x12\x07\n\x03TCP\x10\x01\x12\x07\n\x03UDS\x10\x02\x42\n\n\x08_versionB\r\n\x0b_user_agentB\x10\n\x0e_protocol_type"\x8d\x01\n\x17\x43loudSQLConnectResponse\x12\x42\n\rresponse_code\x18\x01 \x01(\x0e\x32&.google.cloud.sql.v1beta4.ResponseCodeH\x00\x88\x01\x01\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x10\n\x0e_response_codeB\x08\n\x06_error*@\n\x0cResponseCode\x12\x1d\n\x19RESPONSE_CODE_UNSPECIFIED\x10\x00\x12\x06\n\x02OK\x10\x01\x12\t\n\x05\x45RROR\x10\x02\x62\x06proto3'
31+
)
32+
33+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
34+
_builder.BuildTopDescriptorsAndMessages(
35+
DESCRIPTOR, "google.cloud.sql.v1beta4.cloud_sql_metadata_exchange_pb2", globals()
36+
)
37+
if _descriptor._USE_C_DESCRIPTORS == False:
38+
39+
DESCRIPTOR._options = None
40+
_RESPONSECODE._serialized_start = 507
41+
_RESPONSECODE._serialized_end = 571
42+
_CLOUDSQLCONNECTREQUEST._serialized_start = 89
43+
_CLOUDSQLCONNECTREQUEST._serialized_end = 361
44+
_CLOUDSQLCONNECTREQUEST_PROTOCOLTYPE._serialized_start = 253
45+
_CLOUDSQLCONNECTREQUEST_PROTOCOLTYPE._serialized_end = 316
46+
_CLOUDSQLCONNECTRESPONSE._serialized_start = 364
47+
_CLOUDSQLCONNECTRESPONSE._serialized_end = 505
48+
# @@protoc_insertion_point(module_scope)
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import ClassVar as _ClassVar
16+
from typing import Optional as _Optional
17+
from typing import Union as _Union
18+
19+
from google.protobuf import descriptor as _descriptor
20+
from google.protobuf import message as _message
21+
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
22+
23+
DESCRIPTOR: _descriptor.FileDescriptor
24+
ERROR: ResponseCode
25+
OK: ResponseCode
26+
RESPONSE_CODE_UNSPECIFIED: ResponseCode
27+
28+
class CloudSQLConnectRequest(_message.Message):
29+
__slots__ = ["protocol_type", "user_agent", "version"]
30+
31+
class ProtocolType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
32+
__slots__ = []
33+
34+
PROTOCOL_TYPE_FIELD_NUMBER: _ClassVar[int]
35+
PROTOCOL_TYPE_UNSPECIFIED: CloudSQLConnectRequest.ProtocolType
36+
TCP: CloudSQLConnectRequest.ProtocolType
37+
UDS: CloudSQLConnectRequest.ProtocolType
38+
USER_AGENT_FIELD_NUMBER: _ClassVar[int]
39+
VERSION_FIELD_NUMBER: _ClassVar[int]
40+
protocol_type: CloudSQLConnectRequest.ProtocolType
41+
user_agent: str
42+
version: str
43+
def __init__(
44+
self,
45+
version: _Optional[str] = ...,
46+
user_agent: _Optional[str] = ...,
47+
protocol_type: _Optional[
48+
_Union[CloudSQLConnectRequest.ProtocolType, str]
49+
] = ...,
50+
) -> None: ...
51+
52+
class CloudSQLConnectResponse(_message.Message):
53+
__slots__ = ["error", "response_code"]
54+
ERROR_FIELD_NUMBER: _ClassVar[int]
55+
RESPONSE_CODE_FIELD_NUMBER: _ClassVar[int]
56+
error: str
57+
response_code: ResponseCode
58+
def __init__(
59+
self,
60+
response_code: _Optional[_Union[ResponseCode, str]] = ...,
61+
error: _Optional[str] = ...,
62+
) -> None: ...
63+
64+
class ResponseCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
65+
__slots__ = []

0 commit comments

Comments
 (0)