Skip to content

Commit e8702a2

Browse files
chore: move socket initialization to Connector level
1 parent 6f6d5e4 commit e8702a2

File tree

4 files changed

+75
-26
lines changed

4 files changed

+75
-26
lines changed

google/cloud/sql/connector/connection_info.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import abc
1718
from dataclasses import dataclass
1819
import logging
1920
import ssl
@@ -34,6 +35,27 @@
3435
logger = logging.getLogger(name=__name__)
3536

3637

38+
class ConnectionInfoCache(abc.ABC):
39+
"""Abstract class for Connector connection info caches."""
40+
41+
@abc.abstractmethod
42+
async def connect_info(self) -> ConnectionInfo:
43+
pass
44+
45+
@abc.abstractmethod
46+
async def force_refresh(self) -> None:
47+
pass
48+
49+
@abc.abstractmethod
50+
async def close(self) -> None:
51+
pass
52+
53+
@property
54+
@abc.abstractmethod
55+
def closed(self) -> bool:
56+
pass
57+
58+
3759
@dataclass
3860
class ConnectionInfo:
3961
"""Contains all necessary information to connect securely to the

google/cloud/sql/connector/connector.py

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from functools import partial
2121
import logging
2222
import os
23+
import socket
2324
from threading import Thread
2425
from types import TracebackType
2526
from typing import Any, Optional, Union
@@ -47,6 +48,7 @@
4748
logger = logging.getLogger(name=__name__)
4849

4950
ASYNC_DRIVERS = ["asyncpg"]
51+
SERVER_PROXY_PORT = 3307
5052
_DEFAULT_SCHEME = "https://"
5153
_DEFAULT_UNIVERSE_DOMAIN = "googleapis.com"
5254
_SQLADMIN_HOST_TEMPLATE = "sqladmin.{universe_domain}"
@@ -291,10 +293,11 @@ async def connect_async(
291293
driver=driver,
292294
)
293295
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
294-
if (instance_connection_string, enable_iam_auth) in self._cache:
295-
monitored_cache = self._cache[(instance_connection_string, enable_iam_auth)]
296+
297+
conn_name = await self._resolver.resolve(instance_connection_string)
298+
if (str(conn_name), enable_iam_auth) in self._cache:
299+
monitored_cache = self._cache[(str(conn_name), enable_iam_auth)]
296300
else:
297-
conn_name = await self._resolver.resolve(instance_connection_string)
298301
if self._refresh_strategy == RefreshStrategy.LAZY:
299302
logger.debug(
300303
f"['{conn_name}']: Refresh strategy is set to lazy refresh"
@@ -322,7 +325,7 @@ async def connect_async(
322325
self._resolver,
323326
)
324327
logger.debug(f"['{conn_name}']: Connection info added to cache")
325-
self._cache[(instance_connection_string, enable_iam_auth)] = monitored_cache
328+
self._cache[(str(conn_name), enable_iam_auth)] = monitored_cache
326329

327330
connect_func = {
328331
"pymysql": pymysql.connect,
@@ -358,7 +361,7 @@ async def connect_async(
358361
except Exception:
359362
# with an error from Cloud SQL Admin API call or IP type, invalidate
360363
# the cache and re-raise the error
361-
await self._remove_cached(instance_connection_string, enable_iam_auth)
364+
await self._remove_cached(str(conn_name), enable_iam_auth)
362365
raise
363366
logger.debug(f"['{conn_info.conn_name}']: Connecting to {ip_address}:3307")
364367
# format `user` param for automatic IAM database authn
@@ -379,11 +382,21 @@ async def connect_async(
379382
await conn_info.create_ssl_context(enable_iam_auth),
380383
**kwargs,
381384
)
382-
# synchronous drivers are blocking and run using executor
385+
# Create socket with SSLContext for sync drivers
386+
ctx = await conn_info.create_ssl_context(enable_iam_auth)
387+
sock = ctx.wrap_socket(
388+
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
389+
server_hostname=ip_address,
390+
)
391+
# If this connection was opened using a domain name, then store it
392+
# for later in case we need to forcibly close it on failover.
393+
if conn_info.conn_name.domain_name:
394+
monitored_cache.sockets.append(sock)
395+
# Synchronous drivers are blocking and run using executor
383396
connect_partial = partial(
384397
connector,
385398
ip_address,
386-
await conn_info.create_ssl_context(enable_iam_auth),
399+
sock,
387400
**kwargs,
388401
)
389402
return await self._loop.run_in_executor(None, connect_partial)
@@ -468,6 +481,7 @@ async def create_async_connector(
468481
universe_domain: Optional[str] = None,
469482
refresh_strategy: str | RefreshStrategy = RefreshStrategy.BACKGROUND,
470483
resolver: type[DefaultResolver] | type[DnsResolver] = DefaultResolver,
484+
failover_period: int = 30,
471485
) -> Connector:
472486
"""Helper function to create Connector object for asyncio connections.
473487
@@ -519,6 +533,11 @@ async def create_async_connector(
519533
DnsResolver.
520534
Default: DefaultResolver
521535
536+
failover_period (int): The time interval in seconds between each
537+
attempt to check if a failover has occured for a given instance.
538+
Must be used with `resolver=DnsResolver` to have any effect.
539+
Default: 30
540+
522541
Returns:
523542
A Connector instance configured with running event loop.
524543
"""
@@ -537,4 +556,5 @@ async def create_async_connector(
537556
universe_domain=universe_domain,
538557
refresh_strategy=refresh_strategy,
539558
resolver=resolver,
559+
failover_period=failover_period,
540560
)

google/cloud/sql/connector/monitored_cache.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import asyncio
1616
import logging
17+
import ssl
1718
from typing import Any, Callable, Optional, Union
1819

1920
from google.cloud.sql.connector.connection_info import ConnectionInfo
@@ -36,7 +37,7 @@ def __init__(
3637
self.resolver = resolver
3738
self.cache = cache
3839
self.domain_name_ticker: Optional[asyncio.Task] = None
39-
self.open_conns: int = 0
40+
self.sockets: list[ssl.SSLSocket] = []
4041

4142
if self.cache.conn_name.domain_name:
4243
self.domain_name_ticker = asyncio.create_task(
@@ -51,6 +52,15 @@ def __init__(
5152
def closed(self) -> bool:
5253
return self.cache.closed
5354

55+
async def _purge_closed_sockets(self) -> None:
56+
open_sockets = []
57+
for socket in self.sockets:
58+
# Check fileno as method to check if socket is closed. Will return
59+
# -1 on failure, which will be used to signal socket closed.
60+
if socket.fileno() != -1:
61+
open_sockets.append(socket)
62+
self.sockets = open_sockets
63+
5464
async def _check_domain_name(self) -> None:
5565
try:
5666
# Resolve domain name and see if Cloud SQL instance connection name
@@ -66,12 +76,6 @@ async def _check_domain_name(self) -> None:
6676
"connections!"
6777
)
6878
await self.close()
69-
conn_info = await self.connect_info()
70-
if conn_info.sock:
71-
logger.debug(f"Socket type: {type(conn_info.sock)}")
72-
conn_info.sock.close()
73-
else:
74-
logger.debug("Domain name mapping has not changed!")
7579

7680
except Exception as e:
7781
# Domain name checks should not be fatal, log error and continue.
@@ -97,10 +101,20 @@ async def close(self) -> None:
97101
logger.debug(
98102
f"['{self.cache.conn_name}']: Cancelled domain name polling task."
99103
)
100-
104+
finally:
105+
self.domain_name_ticker = None
101106
# If cache is already closed, no further work.
102-
if self.cache.closed:
107+
if self.closed:
103108
return
109+
110+
# Close any still open sockets
111+
for socket in self.sockets:
112+
# Check fileno as method to check if socket is closed. Will return
113+
# -1 on failure, which will be used to signal socket closed.
114+
if socket.fileno() != -1:
115+
socket.close()
116+
117+
# Close underyling ConnectionInfoCache
104118
await self.cache.close()
105119

106120

google/cloud/sql/connector/pg8000.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
limitations under the License.
1515
"""
1616

17-
import socket
1817
import ssl
1918
from typing import Any, TYPE_CHECKING
2019

@@ -25,16 +24,16 @@
2524

2625

2726
def connect(
28-
ip_address: str, ctx: ssl.SSLContext, **kwargs: Any
27+
ip_address: str, sock: ssl.SSLSocket, **kwargs: Any
2928
) -> "pg8000.dbapi.Connection":
3029
"""Helper function to create a pg8000 DB-API connection object.
3130
3231
:type ip_address: str
3332
:param ip_address: A string containing an IP address for the Cloud SQL
3433
instance.
3534
36-
:type ctx: ssl.SSLContext
37-
:param ctx: An SSLContext object created from the Cloud SQL server CA
35+
:type sock: ssl.SSLSocket
36+
:param sock: An SSLSocket object created from the Cloud SQL server CA
3837
cert and ephemeral cert.
3938
4039
@@ -48,12 +47,6 @@ def connect(
4847
'Unable to import module "pg8000." Please install and try again.'
4948
)
5049

51-
# Create socket and wrap with context.
52-
sock = ctx.wrap_socket(
53-
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
54-
server_hostname=ip_address,
55-
)
56-
5750
user = kwargs.pop("user")
5851
db = kwargs.pop("db")
5952
passwd = kwargs.pop("password", None)

0 commit comments

Comments
 (0)