Skip to content

Commit bbd29ad

Browse files
committed
refactor: Use new ConnectSettings.DnsNames field to validate the server TLS certificate.
1 parent fb8c21c commit bbd29ad

File tree

4 files changed

+55
-5
lines changed

4 files changed

+55
-5
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,6 @@ venv
55
.python-version
66
cloud_sql_python_connector.egg-info/
77
dist/
8+
.idea
9+
.coverage
10+
sponge_log.xml

google/cloud/sql/connector/client.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,23 @@ async def _get_metadata(
156156
# resolve dnsName into IP address for PSC
157157
# Note that we have to check for PSC enablement also because CAS
158158
# instances also set the dnsName field.
159-
# Remove trailing period from DNS name. Required for SSL in Python
160-
dns_name = ret_dict.get("dnsName", "").rstrip(".")
161-
if dns_name and ret_dict.get("pscEnabled"):
162-
ip_addresses["PSC"] = dns_name
159+
if ret_dict.get("pscEnabled"):
160+
# Find PSC instance DNS name in the dns_names field
161+
psc_dns_names = [
162+
d["name"]
163+
for d in ret_dict.get("dnsNames", [])
164+
if d["connectionType"] == "PRIVATE_SERVICE_CONNECT"
165+
and d["dnsScope"] == "INSTANCE"
166+
]
167+
dns_name = psc_dns_names[0] if psc_dns_names else None
168+
169+
# Fall back do dns_name field if dns_names is not set
170+
if dns_name is None:
171+
dns_name = ret_dict.get("dnsName", None)
172+
173+
# Remove trailing period from DNS name. Required for SSL in Python
174+
if dns_name:
175+
ip_addresses["PSC"] = dns_name.rstrip(".")
163176

164177
return {
165178
"ip_addresses": ip_addresses,

tests/unit/mocks.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ def __init__(
225225
"PRIMARY": "127.0.0.1",
226226
"PRIVATE": "10.0.0.1",
227227
},
228+
legacy_dns_name: bool = False,
228229
cert_before: datetime = datetime.datetime.now(datetime.timezone.utc),
229230
cert_expiration: datetime = datetime.datetime.now(datetime.timezone.utc)
230231
+ datetime.timedelta(hours=1),
@@ -237,6 +238,7 @@ def __init__(
237238
self.psc_enabled = False
238239
self.cert_before = cert_before
239240
self.cert_expiration = cert_expiration
241+
self.legacy_dns_name = legacy_dns_name
240242
# create self signed CA cert
241243
self.server_ca, self.server_key = generate_cert(
242244
self.project, self.name, cert_before, cert_expiration
@@ -255,12 +257,22 @@ async def connect_settings(self, request: Any) -> web.Response:
255257
"instance": self.name,
256258
"expirationTime": str(self.cert_expiration),
257259
},
258-
"dnsName": "abcde.12345.us-central1.sql.goog",
259260
"pscEnabled": self.psc_enabled,
260261
"ipAddresses": ip_addrs,
261262
"region": self.region,
262263
"databaseVersion": self.db_version,
263264
}
265+
if self.legacy_dns_name:
266+
response["dnsName"] = "abcde.12345.us-central1.sql.goog"
267+
else:
268+
response["dnsNames"] = [
269+
{
270+
"name": "abcde.12345.us-central1.sql.goog",
271+
"connectionType": "PRIVATE_SERVICE_CONNECT",
272+
"dnsScope": "INSTANCE",
273+
}
274+
]
275+
264276
return web.Response(content_type="application/json", body=json.dumps(response))
265277

266278
async def generate_ephemeral(self, request: Any) -> web.Response:

tests/unit/test_client.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,28 @@ async def test_get_metadata_with_psc(fake_client: CloudSQLClient) -> None:
6565
assert isinstance(resp["server_ca_cert"], str)
6666

6767

68+
@pytest.mark.asyncio
69+
async def test_get_metadata_legacy_dns_with_psc(fake_client: CloudSQLClient) -> None:
70+
"""
71+
Test _get_metadata returns successfully with PSC IP type.
72+
"""
73+
# set PSC to enabled on test instance
74+
fake_client.instance.psc_enabled = True
75+
fake_client.instance.legacy_dns_name = True
76+
resp = await fake_client._get_metadata(
77+
"test-project",
78+
"test-region",
79+
"test-instance",
80+
)
81+
assert resp["database_version"] == "POSTGRES_15"
82+
assert resp["ip_addresses"] == {
83+
"PRIMARY": "127.0.0.1",
84+
"PRIVATE": "10.0.0.1",
85+
"PSC": "abcde.12345.us-central1.sql.goog",
86+
}
87+
assert isinstance(resp["server_ca_cert"], str)
88+
89+
6890
@pytest.mark.asyncio
6991
async def test_get_ephemeral(fake_client: CloudSQLClient) -> None:
7092
"""

0 commit comments

Comments
 (0)