Skip to content

Commit f3fb683

Browse files
chore: add metadata exchange to tests
1 parent 608250a commit f3fb683

File tree

13 files changed

+211
-34
lines changed

13 files changed

+211
-34
lines changed

google/api/field_behavior_pb2.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# type: ignore
16+
# -*- coding: utf-8 -*-
17+
# Generated by the protocol buffer compiler. DO NOT EDIT!
18+
# source: google/api/field_behavior.proto
19+
# isort: skip_file
20+
"""Generated protocol buffer code."""
21+
from google.protobuf import descriptor as _descriptor
22+
from google.protobuf import descriptor_pool as _descriptor_pool
23+
from google.protobuf import symbol_database as _symbol_database
24+
from google.protobuf.internal import builder as _builder
25+
26+
# @@protoc_insertion_point(imports)
27+
28+
_sym_db = _symbol_database.Default()
29+
30+
31+
from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2
32+
33+
34+
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
35+
b"\n\x1fgoogle/api/field_behavior.proto\x12\ngoogle.api\x1a google/protobuf/descriptor.proto*\xb6\x01\n\rFieldBehavior\x12\x1e\n\x1a\x46IELD_BEHAVIOR_UNSPECIFIED\x10\x00\x12\x0c\n\x08OPTIONAL\x10\x01\x12\x0c\n\x08REQUIRED\x10\x02\x12\x0f\n\x0bOUTPUT_ONLY\x10\x03\x12\x0e\n\nINPUT_ONLY\x10\x04\x12\r\n\tIMMUTABLE\x10\x05\x12\x12\n\x0eUNORDERED_LIST\x10\x06\x12\x15\n\x11NON_EMPTY_DEFAULT\x10\x07\x12\x0e\n\nIDENTIFIER\x10\x08:Q\n\x0e\x66ield_behavior\x12\x1d.google.protobuf.FieldOptions\x18\x9c\x08 \x03(\x0e\x32\x19.google.api.FieldBehaviorBp\n\x0e\x63om.google.apiB\x12\x46ieldBehaviorProtoP\x01ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\xa2\x02\x04GAPIb\x06proto3"
36+
)
37+
38+
_globals = globals()
39+
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
40+
_builder.BuildTopDescriptorsAndMessages(
41+
DESCRIPTOR, "google.api.field_behavior_pb2", _globals
42+
)
43+
if _descriptor._USE_C_DESCRIPTORS == False:
44+
google_dot_protobuf_dot_descriptor__pb2.FieldOptions.RegisterExtension(
45+
field_behavior
46+
)
47+
48+
DESCRIPTOR._options = None
49+
DESCRIPTOR._serialized_options = b"\n\016com.google.apiB\022FieldBehaviorProtoP\001ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\242\002\004GAPI"
50+
_globals["_FIELDBEHAVIOR"]._serialized_start = 82
51+
_globals["_FIELDBEHAVIOR"]._serialized_end = 264
52+
# @@protoc_insertion_point(module_scope)

google/cloud/sql/connector/client.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ def __init__(
9191
self._sqladmin_api_endpoint = DEFAULT_SERVICE_ENDPOINT
9292
else:
9393
self._sqladmin_api_endpoint = sqladmin_api_endpoint
94+
# asyncpg does not currently support using metadata exchange
95+
# only use metadata exchange for sync drivers
96+
self._use_metadata = False if driver == "asyncpg" else True
9497
self._user_agent = user_agent
9598

9699
async def _get_metadata(
@@ -204,7 +207,10 @@ async def _get_ephemeral(
204207

205208
url = f"{self._sqladmin_api_endpoint}/sql/{API_VERSION}/projects/{project}/instances/{instance}:generateEphemeralCert"
206209

207-
data = {"public_key": pub_key}
210+
data = {
211+
"public_key": pub_key,
212+
"use_metadata_exchange": self._use_metadata,
213+
}
208214

209215
if enable_iam_auth:
210216
# down-scope credentials with only IAM login scope (refreshes them too)

google/cloud/sql/proto/cloud_sql_metadata_exchange_pb2.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,10 @@
2626
_sym_db = _symbol_database.Default()
2727

2828

29+
from google.api import field_behavior_pb2 as google_dot_api_dot_field__behavior__pb2
30+
2931
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
30-
b'\n:google/cloud/sql/v1beta4/cloud_sql_metadata_exchange.proto\x12\x18google.cloud.sql.v1beta4"\x90\x02\n\x16\x43loudSQLConnectRequest\x12\x14\n\x07version\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x17\n\nuser_agent\x18\x02 \x01(\tH\x01\x88\x01\x01\x12Y\n\rprotocol_type\x18\x03 \x01(\x0e\x32=.google.cloud.sql.v1beta4.CloudSQLConnectRequest.ProtocolTypeH\x02\x88\x01\x01"?\n\x0cProtocolType\x12\x1d\n\x19PROTOCOL_TYPE_UNSPECIFIED\x10\x00\x12\x07\n\x03TCP\x10\x01\x12\x07\n\x03UDS\x10\x02\x42\n\n\x08_versionB\r\n\x0b_user_agentB\x10\n\x0e_protocol_type"\x8d\x01\n\x17\x43loudSQLConnectResponse\x12\x42\n\rresponse_code\x18\x01 \x01(\x0e\x32&.google.cloud.sql.v1beta4.ResponseCodeH\x00\x88\x01\x01\x12\x12\n\x05\x65rror\x18\x02 \x01(\tH\x01\x88\x01\x01\x42\x10\n\x0e_response_codeB\x08\n\x06_error*@\n\x0cResponseCode\x12\x1d\n\x19RESPONSE_CODE_UNSPECIFIED\x10\x00\x12\x06\n\x02OK\x10\x01\x12\t\n\x05\x45RROR\x10\x02\x62\x06proto3'
32+
b'\n:google/cloud/sql/v1beta4/cloud_sql_metadata_exchange.proto\x12\x18google.cloud.sql.v1beta4\x1a\x1fgoogle/api/field_behavior.proto"\xc8\x01\n\x16\x43loudSQLConnectRequest\x12\x17\n\nuser_agent\x18\x01 \x01(\tB\x03\xe0\x41\x01\x12T\n\rprotocol_type\x18\x02 \x01(\x0e\x32=.google.cloud.sql.v1beta4.CloudSQLConnectRequest.ProtocolType"?\n\x0cProtocolType\x12\x1d\n\x19PROTOCOL_TYPE_UNSPECIFIED\x10\x00\x12\x07\n\x03TCP\x10\x01\x12\x07\n\x03UDS\x10\x02"\xc6\x01\n\x17\x43loudSQLConnectResponse\x12U\n\rresponse_code\x18\x01 \x01(\x0e\x32>.google.cloud.sql.v1beta4.CloudSQLConnectResponse.ResponseCode\x12\x12\n\x05\x65rror\x18\x02 \x01(\tB\x03\xe0\x41\x01"@\n\x0cResponseCode\x12\x1d\n\x19RESPONSE_CODE_UNSPECIFIED\x10\x00\x12\x06\n\x02OK\x10\x01\x12\t\n\x05\x45RROR\x10\x02\x62\x06proto3'
3133
)
3234

3335
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -37,12 +39,18 @@
3739
if _descriptor._USE_C_DESCRIPTORS == False:
3840

3941
DESCRIPTOR._options = None
40-
_RESPONSECODE._serialized_start = 507
41-
_RESPONSECODE._serialized_end = 571
42-
_CLOUDSQLCONNECTREQUEST._serialized_start = 89
43-
_CLOUDSQLCONNECTREQUEST._serialized_end = 361
44-
_CLOUDSQLCONNECTREQUEST_PROTOCOLTYPE._serialized_start = 253
45-
_CLOUDSQLCONNECTREQUEST_PROTOCOLTYPE._serialized_end = 316
46-
_CLOUDSQLCONNECTRESPONSE._serialized_start = 364
47-
_CLOUDSQLCONNECTRESPONSE._serialized_end = 505
42+
_CLOUDSQLCONNECTREQUEST.fields_by_name["user_agent"]._options = None
43+
_CLOUDSQLCONNECTREQUEST.fields_by_name["user_agent"]._serialized_options = (
44+
b"\340A\001"
45+
)
46+
_CLOUDSQLCONNECTRESPONSE.fields_by_name["error"]._options = None
47+
_CLOUDSQLCONNECTRESPONSE.fields_by_name["error"]._serialized_options = b"\340A\001"
48+
_CLOUDSQLCONNECTREQUEST._serialized_start = 122
49+
_CLOUDSQLCONNECTREQUEST._serialized_end = 322
50+
_CLOUDSQLCONNECTREQUEST_PROTOCOLTYPE._serialized_start = 259
51+
_CLOUDSQLCONNECTREQUEST_PROTOCOLTYPE._serialized_end = 322
52+
_CLOUDSQLCONNECTRESPONSE._serialized_start = 325
53+
_CLOUDSQLCONNECTRESPONSE._serialized_end = 523
54+
_CLOUDSQLCONNECTRESPONSE_RESPONSECODE._serialized_start = 459
55+
_CLOUDSQLCONNECTRESPONSE_RESPONSECODE._serialized_end = 523
4856
# @@protoc_insertion_point(module_scope)

google/cloud/sql/proto/cloud_sql_metadata_exchange_pb2.pyi

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@ from google.protobuf import descriptor as _descriptor
2020
from google.protobuf import message as _message
2121
from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
2222

23+
from google.api import field_behavior_pb2 as _field_behavior_pb2
24+
2325
DESCRIPTOR: _descriptor.FileDescriptor
24-
ERROR: ResponseCode
25-
OK: ResponseCode
26-
RESPONSE_CODE_UNSPECIFIED: ResponseCode
2726

2827
class CloudSQLConnectRequest(_message.Message):
29-
__slots__ = ["protocol_type", "user_agent", "version"]
28+
__slots__ = ["protocol_type", "user_agent"]
3029

3130
class ProtocolType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
3231
__slots__ = []
@@ -36,13 +35,10 @@ class CloudSQLConnectRequest(_message.Message):
3635
TCP: CloudSQLConnectRequest.ProtocolType
3736
UDS: CloudSQLConnectRequest.ProtocolType
3837
USER_AGENT_FIELD_NUMBER: _ClassVar[int]
39-
VERSION_FIELD_NUMBER: _ClassVar[int]
4038
protocol_type: CloudSQLConnectRequest.ProtocolType
4139
user_agent: str
42-
version: str
4340
def __init__(
4441
self,
45-
version: _Optional[str] = ...,
4642
user_agent: _Optional[str] = ...,
4743
protocol_type: _Optional[
4844
_Union[CloudSQLConnectRequest.ProtocolType, str]
@@ -51,15 +47,21 @@ class CloudSQLConnectRequest(_message.Message):
5147

5248
class CloudSQLConnectResponse(_message.Message):
5349
__slots__ = ["error", "response_code"]
50+
51+
class ResponseCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
52+
__slots__ = []
53+
54+
ERROR: CloudSQLConnectResponse.ResponseCode
5455
ERROR_FIELD_NUMBER: _ClassVar[int]
56+
OK: CloudSQLConnectResponse.ResponseCode
5557
RESPONSE_CODE_FIELD_NUMBER: _ClassVar[int]
58+
RESPONSE_CODE_UNSPECIFIED: CloudSQLConnectResponse.ResponseCode
5659
error: str
57-
response_code: ResponseCode
60+
response_code: CloudSQLConnectResponse.ResponseCode
5861
def __init__(
5962
self,
60-
response_code: _Optional[_Union[ResponseCode, str]] = ...,
63+
response_code: _Optional[
64+
_Union[CloudSQLConnectResponse.ResponseCode, str]
65+
] = ...,
6166
error: _Optional[str] = ...,
6267
) -> None: ...
63-
64-
class ResponseCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
65-
__slots__ = []

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ dependencies = [
4747
"dnspython>=2.0.0",
4848
"Requests",
4949
"google-auth>=2.28.0",
50+
"protobuf",
5051
]
5152
dynamic = ["version"]
5253

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@ cryptography==44.0.2
44
dnspython==2.7.0
55
Requests==2.32.3
66
google-auth==2.38.0
7+
protobuf==6.30.0

tests/conftest.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from unit.mocks import create_ssl_context # type: ignore
2929
from unit.mocks import FakeCredentials # type: ignore
3030
from unit.mocks import FakeCSQLInstance # type: ignore
31+
from unit.mocks import metadata_exchange
3132

3233
from google.cloud.sql.connector.client import CloudSQLClient
3334
from google.cloud.sql.connector.connection_name import ConnectionName
@@ -84,10 +85,11 @@ def fake_credentials() -> FakeCredentials:
8485
return FakeCredentials()
8586

8687

87-
async def start_proxy_server(instance: FakeCSQLInstance) -> None:
88-
"""Run local proxy server capable of performing mTLS"""
88+
async def start_proxy_server(
89+
instance: FakeCSQLInstance, port: int = 3307, use_metadata: bool = True
90+
) -> None:
91+
"""Run local proxy server capable of performing mTLS and metadata exchange"""
8992
ip_address = "127.0.0.1"
90-
port = 3307
9193
# create socket
9294
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
9395
# create SSL/TLS context
@@ -116,12 +118,15 @@ async def start_proxy_server(instance: FakeCSQLInstance) -> None:
116118
with context.wrap_socket(sock, server_side=True) as ssock:
117119
while True:
118120
conn, _ = ssock.accept()
121+
if use_metadata:
122+
metadata_exchange(conn)
123+
conn.sendall(instance.name.encode("utf-8"))
119124
conn.close()
120125

121126

122127
@pytest.fixture(scope="session")
123-
def proxy_server(fake_instance: FakeCSQLInstance) -> None:
124-
"""Run local proxy server capable of performing mTLS"""
128+
def proxy_server_with_metadata(fake_instance: FakeCSQLInstance) -> None:
129+
"""Run local proxy server capable of performing metadata exchange"""
125130
thread = Thread(
126131
target=asyncio.run,
127132
args=(
@@ -135,6 +140,18 @@ def proxy_server(fake_instance: FakeCSQLInstance) -> None:
135140
thread.join(1.0) # add a delay to allow the proxy server to start
136141

137142

143+
@pytest.fixture(scope="session")
144+
def proxy_server(fake_instance: FakeCSQLInstance) -> None:
145+
"""Run local proxy server capable of performing mTLS"""
146+
thread = Thread(
147+
target=asyncio.run,
148+
args=(start_proxy_server(fake_instance, 3308, False),),
149+
daemon=True,
150+
)
151+
thread.start()
152+
thread.join(1.0) # add a delay to allow the proxy server to start
153+
154+
138155
@pytest.fixture
139156
async def context(fake_instance: FakeCSQLInstance) -> ssl.SSLContext:
140157
return await create_ssl_context(fake_instance)

tests/unit/mocks.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import datetime
2222
import json
2323
import ssl
24+
import struct
2425
from typing import Any, Callable, Literal, Optional
2526

2627
from aiofiles.tempfile import TemporaryDirectory
@@ -38,6 +39,7 @@
3839
from google.cloud.sql.connector.connector import _DEFAULT_UNIVERSE_DOMAIN
3940
from google.cloud.sql.connector.utils import generate_keys
4041
from google.cloud.sql.connector.utils import write_to_file
42+
import google.cloud.sql.proto.cloud_sql_metadata_exchange_pb2 as connectorspb
4143

4244

4345
class FakeCredentials:
@@ -298,3 +300,65 @@ async def generate_ephemeral(self, request: Any) -> web.Response:
298300
}
299301
}
300302
return web.Response(content_type="application/json", body=json.dumps(response))
303+
304+
305+
def metadata_exchange(sock: ssl.SSLSocket) -> None:
306+
"""
307+
Mimics server side metadata exchange behavior in four steps:
308+
309+
1. Read a big endian uint32 (4 bytes) from the client. This is the number of
310+
bytes the message consumes. The length does not include the initial four
311+
bytes.
312+
313+
2. Read the message from the client using the message length and serialize
314+
it into a MetadataExchangeResponse message.
315+
316+
The real server implementation will then validate the client has connection
317+
permissions using the provided OAuth2 token based on the auth type. Here in
318+
the test implementation, the server does nothing.
319+
320+
3. Prepare a response and write the size of the response as a big endian
321+
uint32 (4 bytes)
322+
323+
4. Parse the response to bytes and write those to the client as well.
324+
325+
Subsequent interactions with the test server use the database protocol.
326+
"""
327+
# read metadata message length (4 bytes)
328+
message_len_buffer_size = struct.Struct("I").size
329+
message_len_buffer = b""
330+
while message_len_buffer_size > 0:
331+
chunk = sock.recv(message_len_buffer_size)
332+
if not chunk:
333+
raise RuntimeError(
334+
"Connection closed while getting metadata exchange length!"
335+
)
336+
message_len_buffer += chunk
337+
message_len_buffer_size -= len(chunk)
338+
339+
(message_len,) = struct.unpack(">I", message_len_buffer)
340+
341+
# read metadata exchange message
342+
buffer = b""
343+
while message_len > 0:
344+
chunk = sock.recv(message_len)
345+
if not chunk:
346+
raise RuntimeError("Connection closed while performing metadata exchange!")
347+
buffer += chunk
348+
message_len -= len(chunk)
349+
350+
# form metadata exchange request to be received from client
351+
message = connectorspb.CloudSQLConnectRequest()
352+
# parse metadata exchange request from buffer
353+
message.ParseFromString(buffer)
354+
355+
# form metadata exchange response to send to client
356+
resp = connectorspb.CloudSQLConnectResponse(
357+
response_code=connectorspb.CloudSQLConnectResponse.OK
358+
)
359+
360+
# pack big-endian unsigned integer (4 bytes)
361+
resp_len = struct.pack(">I", resp.ByteSize())
362+
363+
# send metadata response length and response message
364+
sock.sendall(resp_len + resp.SerializeToString())

tests/unit/test_connector.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,32 @@ def test_Connector_connect_bad_ip_type(
257257
)
258258

259259

260+
@pytest.mark.usefixtures("proxy_server_with_metadata")
261+
async def test_connect(
262+
fake_credentials: Credentials, fake_client: CloudSQLClient
263+
) -> None:
264+
"""
265+
Test that connector.connect returns connection object.
266+
"""
267+
client = fake_client
268+
async with Connector(
269+
credentials=fake_credentials, loop=asyncio.get_running_loop()
270+
) as connector:
271+
connector._client = client
272+
# patch db connection creation
273+
with patch("google.cloud.sql.connector.pg8000.connect") as mock_connect:
274+
mock_connect.return_value = True
275+
connection = await connector.connect_async(
276+
"test-project:test-region:test-instance",
277+
"pg8000",
278+
user="my-user",
279+
password="my-pass",
280+
db="my-db",
281+
)
282+
# check connection is returned
283+
assert connection is True
284+
285+
260286
@pytest.mark.asyncio
261287
async def test_Connector_connect_async(
262288
fake_credentials: Credentials, fake_client: CloudSQLClient

tests/unit/test_monitored_cache.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ async def test_MonitoredCache_check_domain_name(
180180
# configure a local socket
181181
ip_addr = "127.0.0.1"
182182
sock = context.wrap_socket(
183-
socket.create_connection((ip_addr, 3307)),
183+
socket.create_connection((ip_addr, 3308)),
184184
server_hostname=ip_addr,
185185
)
186186
# verify socket is open
@@ -218,7 +218,7 @@ async def test_MonitoredCache_purge_closed_sockets(
218218
# configure a local socket
219219
ip_addr = "127.0.0.1"
220220
sock = context.wrap_socket(
221-
socket.create_connection((ip_addr, 3307)),
221+
socket.create_connection((ip_addr, 3308)),
222222
server_hostname=ip_addr,
223223
)
224224

0 commit comments

Comments
 (0)