Skip to content

Commit 4325146

Browse files
chore: improve test coverage and mocks (#339)
1 parent 3b40c00 commit 4325146

File tree

9 files changed

+771
-154
lines changed

9 files changed

+771
-154
lines changed

requirements-test.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ twine==4.0.1
1616
PyMySQL==1.0.2
1717
pg8000==1.29.1
1818
python-tds==1.11.0
19+
aioresponses==0.7.3

tests/conftest.py

Lines changed: 121 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,21 @@
1414
limitations under the License.
1515
"""
1616
import os
17-
from typing import Any
17+
import socket
18+
import asyncio
19+
import pytest # noqa F401 Needed to run the tests
20+
21+
from threading import Thread
22+
from typing import Any, Generator, AsyncGenerator
1823
from google.auth.credentials import Credentials, with_scopes_if_required
1924
from google.oauth2 import service_account
25+
from aioresponses import aioresponses
26+
from mock import patch
2027

21-
import asyncio
22-
import pytest # noqa F401 Needed to run the tests
28+
from unit.mocks import FakeCSQLInstance # type: ignore
29+
from google.cloud.sql.connector import Connector
30+
from google.cloud.sql.connector.instance import Instance
31+
from google.cloud.sql.connector.utils import generate_keys
2332

2433
SCOPES = [
2534
"https://www.googleapis.com/auth/sqlservice.admin",
@@ -96,4 +105,113 @@ def fake_credentials() -> Credentials:
96105
fake_service_account
97106
)
98107
fake_credentials = with_scopes_if_required(fake_credentials, scopes=SCOPES)
108+
# stub refresh of credentials
109+
setattr(fake_credentials, "refresh", lambda *args: None)
99110
return fake_credentials
111+
112+
113+
def mock_server(server_sock: socket.socket) -> None:
114+
"""Create mock server listening on specified ip_address and port."""
115+
ip_address = "127.0.0.1"
116+
port = 3307
117+
server_sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
118+
server_sock.bind((ip_address, port))
119+
server_sock.listen(0)
120+
server_sock.accept()
121+
122+
123+
@pytest.fixture
124+
def server() -> Generator:
125+
"""Create thread with server listening on proper port"""
126+
server_sock = socket.socket()
127+
thread = Thread(target=mock_server, args=(server_sock,), daemon=True)
128+
thread.start()
129+
yield thread
130+
server_sock.close()
131+
thread.join()
132+
133+
134+
@pytest.fixture
135+
def kwargs() -> Any:
136+
"""Database connection keyword arguments."""
137+
kwargs = {"user": "test-user", "db": "test-db", "password": "test-password"}
138+
return kwargs
139+
140+
141+
@pytest.fixture(scope="module")
142+
def mock_instance() -> FakeCSQLInstance:
143+
mock_instance = FakeCSQLInstance("my-project", "my-region", "my-instance")
144+
return mock_instance
145+
146+
147+
@pytest.fixture
148+
async def instance(
149+
mock_instance: FakeCSQLInstance,
150+
fake_credentials: Credentials,
151+
event_loop: asyncio.AbstractEventLoop,
152+
) -> AsyncGenerator[Instance, None]:
153+
"""
154+
Instance with mocked API calls.
155+
"""
156+
# generate client key pair
157+
keys = asyncio.run_coroutine_threadsafe(generate_keys(), event_loop)
158+
key_task = asyncio.wrap_future(keys, loop=event_loop)
159+
_, client_key = await key_task
160+
with patch("google.auth.default") as mock_auth:
161+
mock_auth.return_value = fake_credentials, None
162+
# mock Cloud SQL Admin API calls
163+
with aioresponses() as mocked:
164+
mocked.get(
165+
"https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance/connectSettings",
166+
status=200,
167+
body=mock_instance.connect_settings(),
168+
repeat=True,
169+
)
170+
mocked.post(
171+
"https://sqladmin.googleapis.com/sql/v1beta4/projects/my-project/instances/my-instance:generateEphemeralCert",
172+
status=200,
173+
body=mock_instance.generate_ephemeral(client_key),
174+
repeat=True,
175+
)
176+
177+
instance = Instance(
178+
"my-project:my-region:my-instance", "pg8000", keys, event_loop
179+
)
180+
181+
yield instance
182+
await instance.close()
183+
184+
185+
@pytest.fixture
186+
async def connector(fake_credentials: Credentials) -> AsyncGenerator[Connector, None]:
187+
instance_connection_name = "my-project:my-region:my-instance"
188+
project, region, instance_name = instance_connection_name.split(":")
189+
# initialize connector
190+
connector = Connector()
191+
with patch("google.auth.default") as mock_auth:
192+
mock_auth.return_value = fake_credentials, None
193+
# mock Cloud SQL Admin API calls
194+
mock_instance = FakeCSQLInstance(project, region, instance_name)
195+
_, client_key = connector._keys.result()
196+
with aioresponses() as mocked:
197+
mocked.get(
198+
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project}/instances/{instance_name}/connectSettings",
199+
status=200,
200+
body=mock_instance.connect_settings(),
201+
repeat=True,
202+
)
203+
mocked.post(
204+
f"https://sqladmin.googleapis.com/sql/v1beta4/projects/{project}/instances/{instance_name}:generateEphemeralCert",
205+
status=200,
206+
body=mock_instance.generate_ephemeral(client_key),
207+
repeat=True,
208+
)
209+
# initialize Instance using mocked API calls
210+
instance = Instance(
211+
instance_connection_name, "pg8000", connector._keys, connector._loop
212+
)
213+
214+
connector._instances[instance_connection_name] = instance
215+
216+
yield connector
217+
connector.close()

tests/unit/mocks.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
""""
2+
Copyright 2022 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
# file containing all mocks used for Cloud SQL Python Connector unit tests
17+
18+
import json
19+
import ssl
20+
from tempfile import TemporaryDirectory
21+
from typing import Any, Dict, Tuple, Optional
22+
from google.cloud.sql.connector import IPTypes
23+
from google.cloud.sql.connector.instance import InstanceMetadata
24+
from google.cloud.sql.connector.utils import write_to_file, generate_keys
25+
import datetime
26+
from cryptography.hazmat.backends import default_backend
27+
from cryptography.hazmat.primitives import serialization, hashes
28+
from cryptography.hazmat.primitives.asymmetric import rsa
29+
from cryptography import x509
30+
from cryptography.x509.oid import NameOID
31+
32+
33+
class MockInstance:
34+
_enable_iam_auth: bool
35+
36+
def __init__(
37+
self,
38+
enable_iam_auth: bool = False,
39+
) -> None:
40+
self._enable_iam_auth = enable_iam_auth
41+
42+
# mock connect_info
43+
async def connect_info(
44+
self,
45+
driver: str,
46+
ip_type: IPTypes,
47+
**kwargs: Any,
48+
) -> Any:
49+
return True
50+
51+
52+
class BadRefresh(Exception):
53+
pass
54+
55+
56+
class MockMetadata(InstanceMetadata):
57+
"""Mock class for InstanceMetadata"""
58+
59+
def __init__(
60+
self, expiration: datetime.datetime, ip_addrs: Dict = {"PRIMARY": "0.0.0.0"}
61+
) -> None:
62+
self.expiration = expiration
63+
self.ip_addrs = ip_addrs
64+
65+
66+
async def instance_metadata_success(*args: Any, **kwargs: Any) -> MockMetadata:
67+
return MockMetadata(datetime.datetime.now() + datetime.timedelta(minutes=10))
68+
69+
70+
async def instance_metadata_expired(*args: Any, **kwargs: Any) -> MockMetadata:
71+
return MockMetadata(datetime.datetime.now() - datetime.timedelta(minutes=10))
72+
73+
74+
async def instance_metadata_error(*args: Any, **kwargs: Any) -> None:
75+
raise BadRefresh("something went wrong...")
76+
77+
78+
def generate_cert(
79+
project: str, name: str
80+
) -> Tuple[x509.CertificateBuilder, rsa.RSAPrivateKey]:
81+
"""
82+
Generate a private key and cert object to be used in testing.
83+
"""
84+
# generate private key
85+
key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
86+
common_name = f"{project}:{name}"
87+
# configure cert subject
88+
subject = issuer = x509.Name(
89+
[
90+
x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
91+
x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
92+
x509.NameAttribute(NameOID.LOCALITY_NAME, "Mountain View"),
93+
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Google Inc"),
94+
x509.NameAttribute(NameOID.COMMON_NAME, "{}".format(common_name)),
95+
]
96+
)
97+
# build cert
98+
cert = (
99+
x509.CertificateBuilder()
100+
.subject_name(subject)
101+
.issuer_name(issuer)
102+
.public_key(key.public_key())
103+
.serial_number(x509.random_serial_number())
104+
.not_valid_before(datetime.datetime.utcnow())
105+
.not_valid_after(
106+
# cert valid for 10 mins
107+
datetime.datetime.utcnow()
108+
+ datetime.timedelta(minutes=60)
109+
)
110+
)
111+
return cert, key
112+
113+
114+
def self_signed_cert(cert: x509.CertificateBuilder, key: rsa.RSAPrivateKey) -> str:
115+
"""
116+
Create a PEM encoded certificate that is self-signed.
117+
"""
118+
return (
119+
cert.sign(key, hashes.SHA256(), default_backend())
120+
.public_bytes(encoding=serialization.Encoding.PEM)
121+
.decode("UTF-8")
122+
)
123+
124+
125+
def client_key_signed_cert(
126+
cert: x509.CertificateBuilder,
127+
priv_key: rsa.RSAPrivateKey,
128+
client_key: rsa.RSAPublicKey,
129+
) -> str:
130+
"""
131+
Create a PEM encoded certificate that is signed by given public key.
132+
"""
133+
# configure cert subject
134+
subject = issuer = x509.Name(
135+
[
136+
x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
137+
x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Google Inc"),
138+
x509.NameAttribute(NameOID.COMMON_NAME, "Google Cloud SQL Client"),
139+
]
140+
)
141+
# build cert
142+
cert = (
143+
x509.CertificateBuilder()
144+
.subject_name(subject)
145+
.issuer_name(issuer)
146+
.public_key(client_key)
147+
.serial_number(x509.random_serial_number())
148+
.not_valid_before(datetime.datetime.utcnow())
149+
.not_valid_after(cert._not_valid_after) # type: ignore
150+
)
151+
return (
152+
cert.sign(priv_key, hashes.SHA256(), default_backend())
153+
.public_bytes(encoding=serialization.Encoding.PEM)
154+
.decode("UTF-8")
155+
)
156+
157+
158+
async def create_ssl_context() -> ssl.SSLContext:
159+
"""Helper method to build an ssl.SSLContext for tests"""
160+
# generate keys and certs for test
161+
cert, private_key = generate_cert("my-project", "my-instance")
162+
server_ca_cert = self_signed_cert(cert, private_key)
163+
client_private, client_bytes = await generate_keys()
164+
client_key: rsa.RSAPublicKey = serialization.load_pem_public_key(
165+
client_bytes.encode("UTF-8"), default_backend()
166+
) # type: ignore
167+
ephemeral_cert = client_key_signed_cert(cert, private_key, client_key)
168+
# build default ssl.SSLContext
169+
context = ssl.create_default_context()
170+
# load ssl.SSLContext with certs
171+
with TemporaryDirectory() as tmpdir:
172+
ca_filename, cert_filename, key_filename = write_to_file(
173+
tmpdir, server_ca_cert, ephemeral_cert, client_private
174+
)
175+
context.load_cert_chain(cert_filename, keyfile=key_filename)
176+
context.load_verify_locations(cafile=ca_filename)
177+
return context
178+
179+
180+
class FakeCSQLInstance:
181+
def __init__(self, project: str, region: str, name: str) -> None:
182+
self.project = project
183+
self.region = region
184+
self.name = name
185+
self.db_version = "POSTGRES_14" # arbitrary value
186+
self.ip_addrs = {"PRIMARY": "0.0.0.0", "PRIVATE": "1.1.1.1"}
187+
self.backend_type = "SECOND_GEN"
188+
189+
# generate server private key and cert
190+
cert, key = generate_cert(project, name)
191+
self.key = key
192+
self.cert = cert
193+
194+
def connect_settings(self, ip_addrs: Optional[Dict] = None) -> str:
195+
"""
196+
Mock data for the following API:
197+
https://sqladmin.googleapis.com/sql/v1beta4/projects/{project}/instances/{instance}/connectSettings
198+
"""
199+
server_ca_cert = self_signed_cert(self.cert, self.key)
200+
ip_addrs = ip_addrs if ip_addrs else self.ip_addrs
201+
ip_addresses = [
202+
{"type": key, "ipAddress": value} for key, value in ip_addrs.items()
203+
]
204+
return json.dumps(
205+
{
206+
"kind": "sql#connectSettings",
207+
"serverCaCert": {
208+
"cert": server_ca_cert,
209+
"instance": self.name,
210+
"expirationTime": str(
211+
datetime.datetime.utcnow() + datetime.timedelta(minutes=10)
212+
),
213+
},
214+
"ipAddresses": ip_addresses,
215+
"region": self.region,
216+
"databaseVersion": self.db_version,
217+
"backendType": self.backend_type,
218+
}
219+
)
220+
221+
def generate_ephemeral(self, client_bytes: str) -> str:
222+
"""
223+
Mock data for the following API:
224+
https://sqladmin.googleapis.com/sql/v1beta4/projects/{project}/instances/{instance}:generateEphemeralCert
225+
"""
226+
client_key: rsa.RSAPublicKey = serialization.load_pem_public_key(
227+
client_bytes.encode("UTF-8"), default_backend()
228+
) # type: ignore
229+
ephemeral_cert = client_key_signed_cert(self.cert, self.key, client_key)
230+
return json.dumps(
231+
{
232+
"ephemeralCert": {
233+
"kind": "sql#sslCert",
234+
"cert": ephemeral_cert,
235+
"expirationTime": str(
236+
datetime.datetime.utcnow() + datetime.timedelta(minutes=10)
237+
),
238+
}
239+
}
240+
)

0 commit comments

Comments
 (0)