Skip to content

Commit d8c23a2

Browse files
author
Uziel Silva
committed
fix(main) Increase code coverage to 94%
Changelog: - Add unit tests for proxy - Add test case to connector for drivers that require the local proxy - Make proper adjustments to code
1 parent 8c3ce21 commit d8c23a2

File tree

4 files changed

+118
-6
lines changed

4 files changed

+118
-6
lines changed

google/cloud/sql/connector/connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
from google.cloud.sql.connector.lazy import LazyRefreshCache
3838
from google.cloud.sql.connector.monitored_cache import MonitoredCache
3939
import google.cloud.sql.connector.pg8000 as pg8000
40-
from google.cloud.sql.connector.proxy import start_local_proxy
40+
import google.cloud.sql.connector.proxy as proxy
4141
import google.cloud.sql.connector.psycopg as psycopg
4242
import google.cloud.sql.connector.pymysql as pymysql
4343
import google.cloud.sql.connector.pytds as pytds
@@ -402,7 +402,7 @@ async def connect_async(
402402
if driver in LOCAL_PROXY_DRIVERS:
403403
local_socket_path = kwargs.pop("local_socket_path", "/tmp/connector-socket")
404404
host = local_socket_path
405-
self._proxy = start_local_proxy(
405+
self._proxy = proxy.start_local_proxy(
406406
sock,
407407
socket_path=f"{local_socket_path}/.s.PGSQL.{SERVER_PROXY_PORT}",
408408
loop=self._loop

google/cloud/sql/connector/proxy.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,9 @@ def start_local_proxy(
7474
async def local_communication(
7575
unix_socket, ssl_sock, socket_path, loop
7676
):
77+
client, _ = await loop.sock_accept(unix_socket)
78+
7779
try:
78-
client, _ = await loop.sock_accept(unix_socket)
79-
8080
while True:
8181
data = await loop.sock_recv(client, LOCAL_PROXY_MAX_MESSAGE_SIZE)
8282
if not data:
@@ -85,8 +85,6 @@ async def local_communication(
8585
ssl_sock.sendall(data)
8686
response = ssl_sock.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE)
8787
await loop.sock_sendall(client, response)
88-
except Exception:
89-
pass
9088
finally:
9189
client.close()
9290
os.remove(socket_path) # Clean up the socket file

tests/unit/test_connector.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
import asyncio
1818
import os
19+
import socket
20+
import ssl
1921
from typing import Union
2022

2123
from aiohttp import ClientResponseError
@@ -31,6 +33,7 @@
3133
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
3234
from google.cloud.sql.connector.exceptions import IncompatibleDriverError
3335
from google.cloud.sql.connector.instance import RefreshAheadCache
36+
from google.cloud.sql.connector.proxy import start_local_proxy
3437

3538

3639
@pytest.mark.asyncio
@@ -279,6 +282,42 @@ async def test_Connector_connect_async(
279282
# verify connector made connection call
280283
assert connection is True
281284

285+
@pytest.mark.usefixtures("proxy_server")
286+
@pytest.mark.asyncio
287+
async def test_Connector_connect_local_proxy(
288+
fake_credentials: Credentials, fake_client: CloudSQLClient, context: ssl.SSLContext
289+
) -> None:
290+
"""Test that Connector.connect can launch start_local_proxy."""
291+
async with Connector(
292+
credentials=fake_credentials, loop=asyncio.get_running_loop()
293+
) as connector:
294+
connector._client = fake_client
295+
socket_path = "/tmp/connector-socket/socket"
296+
ip_addr = "127.0.0.1"
297+
ssl_sock = context.wrap_socket(
298+
socket.create_connection((ip_addr, 3307)),
299+
server_hostname=ip_addr,
300+
)
301+
loop = asyncio.get_running_loop()
302+
task = start_local_proxy(ssl_sock, socket_path, loop)
303+
# patch db connection creation
304+
with patch("google.cloud.sql.connector.proxy.start_local_proxy") as mock_proxy:
305+
with patch("google.cloud.sql.connector.psycopg.connect") as mock_connect:
306+
mock_connect.return_value = True
307+
mock_proxy.return_value = task
308+
connection = await connector.connect_async(
309+
"test-project:test-region:test-instance",
310+
"psycopg",
311+
user="my-user",
312+
password="my-pass",
313+
db="my-db",
314+
local_socket_path=socket_path,
315+
)
316+
# verify connector called local proxy
317+
mock_connect.assert_called_once()
318+
mock_proxy.assert_called_once()
319+
assert connection is True
320+
282321

283322
@pytest.mark.asyncio
284323
async def test_create_async_connector(fake_credentials: Credentials) -> None:

tests/unit/test_proxy.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""
2+
Copyright 2025 Google LLC
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
https://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import asyncio
18+
import socket
19+
import ssl
20+
from typing import Any
21+
22+
from mock import Mock
23+
import pytest
24+
25+
from google.cloud.sql.connector import proxy
26+
27+
LOCAL_PROXY_MAX_MESSAGE_SIZE = 10485760
28+
29+
@pytest.mark.usefixtures("proxy_server")
30+
@pytest.mark.asyncio
31+
async def test_proxy_creates_folder(context: ssl.SSLContext, kwargs: Any) -> None:
32+
"""Test to verify that the proxy server is getting back the task."""
33+
ip_addr = "127.0.0.1"
34+
path = "/tmp/connector-socket/socket"
35+
sock = context.wrap_socket(
36+
socket.create_connection((ip_addr, 3307)),
37+
server_hostname=ip_addr,
38+
)
39+
loop = asyncio.get_running_loop()
40+
41+
task = proxy.start_local_proxy(sock, path, loop)
42+
assert (task is not None)
43+
44+
proxy_task = asyncio.gather(task)
45+
try:
46+
await asyncio.wait_for(proxy_task, timeout=0.1)
47+
except TimeoutError:
48+
pass # This task runs forever so it is expected to throw this exception
49+
50+
@pytest.mark.usefixtures("proxy_server")
51+
@pytest.mark.asyncio
52+
async def test_local_proxy_communication(context: ssl.SSLContext, kwargs: Any) -> None:
53+
"""Test to verify that the communication is getting through."""
54+
socket_path = "/tmp/connector-socket/socket"
55+
ssl_sock = Mock(spec=ssl.SSLSocket)
56+
loop = asyncio.get_running_loop()
57+
58+
with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as client:
59+
ssl_sock.recv.return_value = b"Received"
60+
61+
task = proxy.start_local_proxy(ssl_sock, socket_path, loop)
62+
63+
client.connect(socket_path)
64+
client.sendall(b"Test")
65+
await asyncio.sleep(1)
66+
67+
ssl_sock.sendall.assert_called_with(b"Test")
68+
response = client.recv(LOCAL_PROXY_MAX_MESSAGE_SIZE)
69+
assert (response == b"Received")
70+
71+
client.close()
72+
await asyncio.sleep(1)
73+
74+
proxy_task = asyncio.gather(task)
75+
await asyncio.wait_for(proxy_task, timeout=2)

0 commit comments

Comments
 (0)