Skip to content

Commit b3f80c9

Browse files
authored
feat: allow specifying ip address type (#79)
* feat: allow specifying ip address type as public and/or private when connecting * make IPTypes enum public * address review comments * set default IP type to public * add tests for specifying IP types and flag to skip private IP * Add info about specifying IP types to README * linting fixes * allow only 1 IP type to be passed into connection args * update README * hard code default IP type
1 parent 41b11db commit b3f80c9

File tree

7 files changed

+182
-25
lines changed

7 files changed

+182
-25
lines changed

README.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,19 @@ connector.connect(
5656
```
5757
**Note for SQL Server users**: If your SQL Server instance requires SSL, you need to download the CA certificate for your instance and include `cafile={path to downloaded certificate}` and `validate_host=False`. This is a workaround for a [known issue](https://issuetracker.google.com/184867147).
5858

59+
### Specifying Public or Private IP
60+
The Cloud SQL Connector for Python can be used to connect to Cloud SQL instances using both public and private IP addresses. To specify which IP address to use to connect, set the `ip_type` keyword argument Possible values are `IPTypes.PUBLIC` and `IPTypes.PRIVATE`.
61+
Example:
62+
```
63+
connector.connect(
64+
"your:connection:string:",
65+
"pymysql",
66+
ip_types=IPTypes.PRIVATE # Prefer private IP
67+
... insert other kwargs ...
68+
)
69+
```
70+
71+
Note: If specifying Private IP, your application must already be in the same VPC network as your Cloud SQL Instance.
5972
### Setup for development
6073

6174
Tests can be run with `nox`. Change directory into the `cloud-sql-python-connector` and just run `nox` to run the tests.

google/cloud/sql/connector/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,10 @@
1717
from typing import Iterable
1818

1919
from .connector import connect
20-
from .instance_connection_manager import CloudSQLConnectionError
20+
from .instance_connection_manager import CloudSQLConnectionError, IPTypes
2121

2222

23-
__ALL__ = [connect, CloudSQLConnectionError]
23+
__ALL__ = [connect, CloudSQLConnectionError, IPTypes]
2424

2525
try:
2626
import pkg_resources

google/cloud/sql/connector/connector.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,13 @@
1717
import concurrent
1818
from google.cloud.sql.connector.instance_connection_manager import (
1919
InstanceConnectionManager,
20+
IPTypes,
2021
)
2122
from google.cloud.sql.connector.utils import generate_keys
2223

2324
from threading import Thread
2425
from typing import Any, Dict, Optional
2526

26-
2727
# This thread is used to background processing
2828
_thread: Optional[Thread] = None
2929
_loop: Optional[asyncio.AbstractEventLoop] = None
@@ -48,7 +48,12 @@ def _get_keys(loop: asyncio.AbstractEventLoop) -> concurrent.futures.Future:
4848
return _keys
4949

5050

51-
def connect(instance_connection_string: str, driver: str, **kwargs: Any) -> Any:
51+
def connect(
52+
instance_connection_string: str,
53+
driver: str,
54+
ip_types: IPTypes = IPTypes.PUBLIC,
55+
**kwargs: Any
56+
) -> Any:
5257
"""Prepares and returns a database connection object and starts a
5358
background thread to refresh the certificates and metadata.
5459
@@ -59,6 +64,15 @@ def connect(instance_connection_string: str, driver: str, **kwargs: Any) -> Any:
5964
6065
Example: example-proj:example-region-us6:example-instance
6166
67+
:type driver: str
68+
:param: driver:
69+
A string representing the driver to connect with. Supported drivers are
70+
pymysql, pg8000, and pytds.
71+
72+
:type ip_types: IPTypes
73+
The IP type (public or private) used to connect. IP types
74+
can be either IPTypes.PUBLIC or IPTypes.PRIVATE.
75+
6276
:param kwargs:
6377
Pass in any driver-specific arguments needed to connect to the Cloud
6478
SQL instance.
@@ -86,10 +100,10 @@ def connect(instance_connection_string: str, driver: str, **kwargs: Any) -> Any:
86100
_instances[instance_connection_string] = icm
87101

88102
if "timeout" in kwargs:
89-
return icm.connect(driver, **kwargs)
103+
return icm.connect(driver, ip_types, **kwargs)
90104
elif "connect_timeout" in kwargs:
91105
timeout = kwargs["connect_timeout"]
92106
else:
93107
timeout = 30 # 30s
94108

95-
return icm.connect(driver, timeout, **kwargs)
109+
return icm.connect(driver, ip_types, timeout, **kwargs)

google/cloud/sql/connector/instance_connection_manager.py

Lines changed: 56 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import asyncio
2424
import aiohttp
2525
import concurrent
26+
from enum import Enum
2627
import google.auth
2728
from google.auth.credentials import Credentials
2829
import google.auth.transport.requests
@@ -33,6 +34,7 @@
3334
Any,
3435
Awaitable,
3536
Coroutine,
37+
Dict,
3638
Optional,
3739
TYPE_CHECKING,
3840
Union,
@@ -55,6 +57,11 @@
5557
_delay: int = 55 * 60
5658

5759

60+
class IPTypes(Enum):
61+
PUBLIC: str = "PRIMARY"
62+
PRIVATE: str = "PRIVATE"
63+
64+
5865
class ConnectionSSLContext(ssl.SSLContext):
5966
"""Subclass of ssl.SSLContext with added request_ssl attribute. This is
6067
required for compatibility with pg8000 driver.
@@ -65,18 +72,37 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
6572
super(ConnectionSSLContext, self).__init__(*args, **kwargs)
6673

6774

75+
class CloudSQLConnectionError(Exception):
76+
"""
77+
Raised when the provided connection string is not formatted
78+
correctly.
79+
"""
80+
81+
def __init__(self, *args: Any) -> None:
82+
super(CloudSQLConnectionError, self).__init__(self, *args)
83+
84+
85+
class CloudSQLIPTypeError(Exception):
86+
"""
87+
Raised when IP address for the preferred IP type is not found.
88+
"""
89+
90+
def __init__(self, *args: Any) -> None:
91+
super(CloudSQLIPTypeError, self).__init__(self, *args)
92+
93+
6894
class InstanceMetadata:
69-
ip_address: str
95+
ip_addrs: Dict[str, Any]
7096
context: ssl.SSLContext
7197

7298
def __init__(
7399
self,
74100
ephemeral_cert: str,
75-
ip_address: str,
101+
ip_addrs: Dict[str, Any],
76102
private_key: bytes,
77103
server_ca_cert: str,
78104
) -> None:
79-
self.ip_address = ip_address
105+
self.ip_addrs = ip_addrs
80106
self.context = ConnectionSSLContext()
81107

82108
# tmpdir and its contents are automatically deleted after the CA cert
@@ -89,15 +115,16 @@ def __init__(
89115
self.context.load_cert_chain(cert_filename, keyfile=key_filename)
90116
self.context.load_verify_locations(cafile=ca_filename)
91117

92-
93-
class CloudSQLConnectionError(Exception):
94-
"""
95-
Raised when the provided connection string is not formatted
96-
correctly.
97-
"""
98-
99-
def __init__(self, *args: Any) -> None:
100-
super(CloudSQLConnectionError, self).__init__(self, *args)
118+
def get_preferred_ip(self, ip_type: IPTypes) -> str:
119+
"""Returns the first IP address for the instance, according to the preference
120+
supplied by ip_type. If no IP addressess with the given preference are found,
121+
an error is raised."""
122+
if ip_type.value in self.ip_addrs:
123+
return self.ip_addrs[ip_type.value]
124+
raise CloudSQLIPTypeError(
125+
"Cloud SQL instance does not have any IP addresses matching "
126+
f"preference: {ip_type.value})"
127+
)
101128

102129

103130
class InstanceConnectionManager:
@@ -241,7 +268,7 @@ async def _get_instance_data(self) -> InstanceMetadata:
241268

242269
return InstanceMetadata(
243270
ephemeral_cert,
244-
metadata["ip_addresses"]["PRIMARY"],
271+
metadata["ip_addresses"],
245272
priv_key,
246273
metadata["server_ca_cert"],
247274
)
@@ -296,7 +323,13 @@ async def _schedule_refresh(self, delay: int) -> asyncio.Task:
296323

297324
return await self._perform_refresh()
298325

299-
def connect(self, driver: str, timeout: int, **kwargs: Any) -> Any:
326+
def connect(
327+
self,
328+
driver: str,
329+
ip_type: IPTypes,
330+
timeout: int,
331+
**kwargs: Any,
332+
) -> Any:
300333
"""A method that returns a DB-API connection to the database.
301334
302335
:type driver: str
@@ -310,7 +343,7 @@ def connect(self, driver: str, timeout: int, **kwargs: Any) -> Any:
310343
"""
311344

312345
connect_future: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(
313-
self._connect(driver, **kwargs), self._loop
346+
self._connect(driver, ip_type, **kwargs), self._loop
314347
)
315348

316349
try:
@@ -321,7 +354,12 @@ def connect(self, driver: str, timeout: int, **kwargs: Any) -> Any:
321354
else:
322355
return connection
323356

324-
async def _connect(self, driver: str, **kwargs: Any) -> Any:
357+
async def _connect(
358+
self,
359+
driver: str,
360+
ip_type: IPTypes,
361+
**kwargs: Any,
362+
) -> Any:
325363
"""A method that returns a DB-API connection to the database.
326364
327365
:type driver: str
@@ -344,14 +382,15 @@ async def _connect(self, driver: str, **kwargs: Any) -> Any:
344382
}
345383

346384
instance_data: InstanceMetadata = await self._current
385+
ip_address: str = instance_data.get_preferred_ip(ip_type)
347386

348387
try:
349388
connector = connect_func[driver]
350389
except KeyError:
351390
raise KeyError("Driver {} is not supported.".format(driver))
352391

353392
connect_partial = partial(
354-
connector, instance_data.ip_address, instance_data.context, **kwargs
393+
connector, ip_address, instance_data.context, **kwargs
355394
)
356395

357396
return await self._loop.run_in_executor(None, connect_partial)
Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,36 @@
1515
"""
1616
import os
1717
import threading
18-
from typing import Generator
18+
from typing import Any, Generator
1919

2020
import asyncio
2121
import pytest # noqa F401 Needed to run the tests
2222

2323

24+
def pytest_addoption(parser: Any) -> None:
25+
parser.addoption(
26+
"--run_private_ip",
27+
action="store_true",
28+
default=False,
29+
help="run tests that need to be running in VPC network",
30+
)
31+
32+
33+
def pytest_configure(config: Any) -> None:
34+
config.addinivalue_line(
35+
"markers", "private_ip: mark test as requiring private IP access"
36+
)
37+
38+
39+
def pytest_collection_modifyitems(config: Any, items: Any) -> None:
40+
if config.getoption("--run_private_ip"):
41+
return
42+
skip_private_ip = pytest.mark.skip(reason="need --run_private_ip option to run")
43+
for item in items:
44+
if "private_ip" in item.keywords:
45+
item.add_marker(skip_private_ip)
46+
47+
2448
@pytest.fixture
2549
def async_loop() -> Generator:
2650
"""

tests/system/test_ip_types.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
""""
2+
Copyright 2021 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
import logging
17+
import os
18+
import uuid
19+
20+
import pymysql
21+
import pytest
22+
import sqlalchemy
23+
from google.cloud.sql.connector import connector, IPTypes
24+
25+
table_name = f"books_{uuid.uuid4().hex}"
26+
27+
28+
def init_connection_engine(ip_types: IPTypes) -> sqlalchemy.engine.Engine:
29+
def getconn() -> pymysql.connections.Connection:
30+
conn: pymysql.connections.Connection = connector.connect(
31+
os.environ["MYSQL_CONNECTION_NAME"],
32+
"pymysql",
33+
ip_types=ip_types,
34+
user=os.environ["MYSQL_USER"],
35+
password=os.environ["MYSQL_PASS"],
36+
db=os.environ["MYSQL_DB"],
37+
)
38+
return conn
39+
40+
engine = sqlalchemy.create_engine(
41+
"mysql+pymysql://",
42+
creator=getconn,
43+
)
44+
return engine
45+
46+
47+
def test_public_ip() -> None:
48+
try:
49+
pool = init_connection_engine(IPTypes.PUBLIC)
50+
except Exception as e:
51+
logging.exception("Failed to initialize pool with public IP", e)
52+
with pool.connect() as conn:
53+
conn.execute("SELECT 1")
54+
55+
56+
@pytest.mark.private_ip
57+
def test_private_ip() -> None:
58+
try:
59+
pool = init_connection_engine(IPTypes.PRIVATE)
60+
except Exception as e:
61+
logging.exception("Failed to initialize pool with private IP", e)
62+
with pool.connect() as conn:
63+
conn.execute("SELECT 1")

tests/unit/test_connector.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,9 @@ async def timeout_stub(*args: Any, **kwargs: Any) -> None:
4545
mock_instances[connect_string] = icm
4646
with patch.dict(connector._instances, mock_instances):
4747
pytest.raises(
48-
TimeoutError, connector.connect, connect_string, "pymysql", timeout=timeout
48+
TimeoutError,
49+
connector.connect,
50+
connect_string,
51+
"pymysql",
52+
timeout=timeout,
4953
)

0 commit comments

Comments
 (0)