Skip to content

Commit 0a1ca17

Browse files
chore: merge main
2 parents e8702a2 + 15934bd commit 0a1ca17

File tree

5 files changed

+99
-14
lines changed

5 files changed

+99
-14
lines changed

google/cloud/sql/connector/connection_name.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
# Additionally, we have to support legacy "domain-scoped" projects
2020
# (e.g. "google.com:PROJECT")
2121
CONN_NAME_REGEX = re.compile(("([^:]+(:[^:]+)?):([^:]+):([^:]+)"))
22+
# The domain name pattern in accordance with RFC 1035, RFC 1123 and RFC 2181.
23+
DOMAIN_NAME_REGEX = re.compile(
24+
r"^(?:[_a-z0-9](?:[_a-z0-9-]{0,61}[a-z0-9])?\.)+(?:[a-z](?:[a-z0-9-]{0,61}[a-z0-9])?)?$"
25+
)
2226

2327

2428
@dataclass
@@ -43,6 +47,12 @@ def get_connection_string(self) -> str:
4347
return f"{self.project}:{self.region}:{self.instance_name}"
4448

4549

50+
def _is_valid_domain(domain_name: str) -> bool:
51+
if DOMAIN_NAME_REGEX.fullmatch(domain_name) is None:
52+
return False
53+
return True
54+
55+
4656
def _parse_connection_name(connection_name: str) -> ConnectionName:
4757
return _parse_connection_name_with_domain_name(connection_name, "")
4858

google/cloud/sql/connector/resolver.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from google.cloud.sql.connector.connection_name import (
1818
_parse_connection_name_with_domain_name,
1919
)
20+
from google.cloud.sql.connector.connection_name import _is_valid_domain
2021
from google.cloud.sql.connector.connection_name import _parse_connection_name
2122
from google.cloud.sql.connector.connection_name import ConnectionName
2223
from google.cloud.sql.connector.exceptions import DnsResolutionError
@@ -40,8 +41,16 @@ async def resolve(self, dns: str) -> ConnectionName: # type: ignore
4041
conn_name = _parse_connection_name(dns)
4142
except ValueError:
4243
# The connection name was not project:region:instance format.
43-
# Attempt to query a TXT record to get connection name.
44-
conn_name = await self.query_dns(dns)
44+
# Check if connection name is a valid DNS domain name
45+
if _is_valid_domain(dns):
46+
# Attempt to query a TXT record to get connection name.
47+
conn_name = await self.query_dns(dns)
48+
else:
49+
raise ValueError(
50+
"Arg `instance_connection_string` must have "
51+
"format: PROJECT:REGION:INSTANCE or be a valid DNS domain "
52+
f"name, got {dns}."
53+
)
4554
return conn_name
4655

4756
async def query_dns(self, dns: str) -> ConnectionName:

tests/system/test_asyncpg_connection.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,15 @@
1616

1717
import asyncio
1818
import os
19-
from typing import Any
19+
from typing import Any, Union
2020

2121
import asyncpg
2222
import sqlalchemy
2323
import sqlalchemy.ext.asyncio
2424

2525
from google.cloud.sql.connector import Connector
26+
from google.cloud.sql.connector import DefaultResolver
27+
from google.cloud.sql.connector import DnsResolver
2628

2729

2830
async def create_sqlalchemy_engine(
@@ -31,6 +33,7 @@ async def create_sqlalchemy_engine(
3133
password: str,
3234
db: str,
3335
refresh_strategy: str = "background",
36+
resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver,
3437
) -> tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, Connector]:
3538
"""Creates a connection pool for a Cloud SQL instance and returns the pool
3639
and the connector. Callers are responsible for closing the pool and the
@@ -64,9 +67,16 @@ async def create_sqlalchemy_engine(
6467
Refresh strategy for the Cloud SQL Connector. Can be one of "lazy"
6568
or "background". For serverless environments use "lazy" to avoid
6669
errors resulting from CPU being throttled.
70+
resolver (Optional[google.cloud.sql.connector.DefaultResolver]):
71+
Resolver class for resolving instance connection name. Use
72+
google.cloud.sql.connector.DnsResolver when resolving DNS domain
73+
names or google.cloud.sql.connector.DefaultResolver for regular
74+
instance connection names ("my-project:my-region:my-instance").
6775
"""
6876
loop = asyncio.get_running_loop()
69-
connector = Connector(loop=loop, refresh_strategy=refresh_strategy)
77+
connector = Connector(
78+
loop=loop, refresh_strategy=refresh_strategy, resolver=resolver
79+
)
7080

7181
async def getconn() -> asyncpg.Connection:
7282
conn: asyncpg.Connection = await connector.connect_async(
@@ -183,6 +193,24 @@ async def test_lazy_sqlalchemy_connection_with_asyncpg() -> None:
183193
await connector.close_async()
184194

185195

196+
async def test_custom_SAN_with_dns_sqlalchemy_connection_with_asyncpg() -> None:
197+
"""Basic test to get time from database."""
198+
inst_conn_name = os.environ["POSTGRES_CUSTOMER_CAS_PASS_VALID_DOMAIN_NAME"]
199+
user = os.environ["POSTGRES_USER"]
200+
password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"]
201+
db = os.environ["POSTGRES_DB"]
202+
203+
pool, connector = await create_sqlalchemy_engine(
204+
inst_conn_name, user, password, db, resolver=DnsResolver
205+
)
206+
207+
async with pool.connect() as conn:
208+
res = (await conn.execute(sqlalchemy.text("SELECT 1"))).fetchone()
209+
assert res[0] == 1
210+
211+
await connector.close_async()
212+
213+
186214
async def test_connection_with_asyncpg() -> None:
187215
"""Basic test to get time from database."""
188216
inst_conn_name = os.environ["POSTGRES_CONNECTION_NAME"]

tests/system/test_pg8000_connection.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def create_sqlalchemy_engine(
3434
password: str,
3535
db: str,
3636
refresh_strategy: str = "background",
37-
resolver: Union[DefaultResolver, DnsResolver] = DefaultResolver,
37+
resolver: Union[type[DefaultResolver], type[DnsResolver]] = DefaultResolver,
3838
) -> tuple[sqlalchemy.engine.Engine, Connector]:
3939
"""Creates a connection pool for a Cloud SQL instance and returns the pool
4040
and the connector. Callers are responsible for closing the pool and the
@@ -69,11 +69,11 @@ def create_sqlalchemy_engine(
6969
Refresh strategy for the Cloud SQL Connector. Can be one of "lazy"
7070
or "background". For serverless environments use "lazy" to avoid
7171
errors resulting from CPU being throttled.
72-
resolver (Optional[google.cloud.sql.connector.DefaultResolver | google.cloud.sql.connector.DnsResolver])
73-
Resolver class for the Cloud SQL Connector. Can be one of
74-
DefaultResolver (default) or DnsResolver. The resolver tells the
75-
connector whether to resolve the 'instance_connection_name' as a
76-
Cloud SQL instance connection name or as a domain name.
72+
resolver (Optional[google.cloud.sql.connector.DefaultResolver]):
73+
Resolver class for resolving instance connection name. Use
74+
google.cloud.sql.connector.DnsResolver when resolving DNS domain
75+
names or google.cloud.sql.connector.DefaultResolver for regular
76+
instance connection names ("my-project:my-region:my-instance").
7777
"""
7878
connector = Connector(refresh_strategy=refresh_strategy, resolver=resolver)
7979

@@ -165,15 +165,15 @@ def test_customer_managed_CAS_pg8000_connection() -> None:
165165
connector.close()
166166

167167

168-
def test_domain_name_pg8000_connection() -> None:
169-
"""Basic test to get time from database using domain name to connect."""
170-
domain_name = os.environ["POSTGRES_CUSTOMER_CAS_DOMAIN_NAME"]
168+
def test_custom_SAN_with_dns_pg8000_connection() -> None:
169+
"""Basic test to get time from database."""
170+
inst_conn_name = os.environ["POSTGRES_CUSTOMER_CAS_DOMAIN_NAME"]
171171
user = os.environ["POSTGRES_USER"]
172172
password = os.environ["POSTGRES_CUSTOMER_CAS_PASS"]
173173
db = os.environ["POSTGRES_DB"]
174174

175175
engine, connector = create_sqlalchemy_engine(
176-
domain_name, user, password, db, "lazy", DnsResolver
176+
inst_conn_name, user, password, db, resolver=DnsResolver
177177
)
178178
with engine.connect() as conn:
179179
time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone()

tests/unit/test_connection_name.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from google.cloud.sql.connector.connection_name import (
1818
_parse_connection_name_with_domain_name,
1919
)
20+
from google.cloud.sql.connector.connection_name import _is_valid_domain
2021
from google.cloud.sql.connector.connection_name import _parse_connection_name
2122
from google.cloud.sql.connector.connection_name import ConnectionName
2223

@@ -100,3 +101,40 @@ def test_parse_connection_name_with_domain_name(
100101
assert expected == _parse_connection_name_with_domain_name(
101102
connection_name, domain_name
102103
)
104+
105+
106+
@pytest.mark.parametrize(
107+
"domain_name, expected",
108+
[
109+
(
110+
"prod-db.mycompany.example.com",
111+
True,
112+
),
113+
(
114+
"example.com.", # trailing dot
115+
True,
116+
),
117+
(
118+
"-example.com.", # leading hyphen
119+
False,
120+
),
121+
(
122+
"example", # missing TLD
123+
False,
124+
),
125+
(
126+
"127.0.0.1", # IPv4 address
127+
False,
128+
),
129+
(
130+
"0:0:0:0:0:0:0:1", # IPv6 address
131+
False,
132+
),
133+
],
134+
)
135+
def test_is_valid_domain(domain_name: str, expected: bool) -> None:
136+
"""
137+
Test that _is_valid_domain works correctly for
138+
parsing domain names.
139+
"""
140+
assert expected == _is_valid_domain(domain_name)

0 commit comments

Comments
 (0)