Skip to content

Commit ba434e7

Browse files
feat: use non-blocking disk read/writes (#360)
Python's standard library read and writes which are blocking I/O. This PR switches to use aiofiles which is non-blocking approach to read/write to disk.
1 parent 318445f commit ba434e7

File tree

9 files changed

+38
-23
lines changed

9 files changed

+38
-23
lines changed

google/cloud/alloydb/connector/async_connector.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,9 @@ def get_authentication_token() -> str:
194194
if enable_iam_auth:
195195
kwargs["password"] = get_authentication_token
196196
try:
197-
return await connector(ip_address, conn_info.create_ssl_context(), **kwargs)
197+
return await connector(
198+
ip_address, await conn_info.create_ssl_context(), **kwargs
199+
)
198200
except Exception:
199201
# we attempt a force refresh, then throw the error
200202
await cache.force_refresh()

google/cloud/alloydb/connector/connection_info.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
from dataclasses import dataclass
1818
import logging
1919
import ssl
20-
from tempfile import TemporaryDirectory
2120
from typing import Dict, List, Optional, TYPE_CHECKING
2221

22+
from aiofiles.tempfile import TemporaryDirectory
23+
2324
from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
2425
from google.cloud.alloydb.connector.utils import _write_to_file
2526

@@ -45,7 +46,7 @@ class ConnectionInfo:
4546
expiration: datetime.datetime
4647
context: Optional[ssl.SSLContext] = None
4748

48-
def create_ssl_context(self) -> ssl.SSLContext:
49+
async def create_ssl_context(self) -> ssl.SSLContext:
4950
"""Constructs a SSL/TLS context for the given connection info.
5051
5152
Cache the SSL context to ensure we don't read from disk repeatedly when
@@ -66,8 +67,8 @@ def create_ssl_context(self) -> ssl.SSLContext:
6667
# tmpdir and its contents are automatically deleted after the CA cert
6768
# and cert chain are loaded into the SSLcontext. The values
6869
# need to be written to files in order to be loaded by the SSLContext
69-
with TemporaryDirectory() as tmpdir:
70-
ca_filename, cert_chain_filename, key_filename = _write_to_file(
70+
async with TemporaryDirectory() as tmpdir:
71+
ca_filename, cert_chain_filename, key_filename = await _write_to_file(
7172
tmpdir, self.ca_cert, self.cert_chain, self.key
7273
)
7374
context.load_cert_chain(cert_chain_filename, keyfile=key_filename)

google/cloud/alloydb/connector/connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
215215
metadata_partial = partial(
216216
self.metadata_exchange,
217217
ip_address,
218-
conn_info.create_ssl_context(),
218+
await conn_info.create_ssl_context(),
219219
enable_iam_auth,
220220
driver,
221221
)

google/cloud/alloydb/connector/utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,12 @@
1616

1717
from typing import List, Tuple
1818

19+
import aiofiles
1920
from cryptography.hazmat.primitives import serialization
2021
from cryptography.hazmat.primitives.asymmetric import rsa
2122

2223

23-
def _write_to_file(
24+
async def _write_to_file(
2425
dir_path: str, ca_cert: str, cert_chain: List[str], key: rsa.RSAPrivateKey
2526
) -> Tuple[str, str, str]:
2627
"""
@@ -37,12 +38,12 @@ def _write_to_file(
3738
encryption_algorithm=serialization.NoEncryption(),
3839
)
3940

40-
with open(ca_filename, "w+") as ca_out:
41-
ca_out.write(ca_cert)
42-
with open(cert_chain_filename, "w+") as chain_out:
43-
chain_out.write("".join(cert_chain))
44-
with open(key_filename, "wb") as priv_out:
45-
priv_out.write(key_bytes)
41+
async with aiofiles.open(ca_filename, "w+") as ca_out:
42+
await ca_out.write(ca_cert)
43+
async with aiofiles.open(cert_chain_filename, "w+") as chain_out:
44+
await chain_out.write("".join(cert_chain))
45+
async with aiofiles.open(key_filename, "wb") as priv_out:
46+
await priv_out.write(key_bytes)
4647

4748
return (ca_filename, cert_chain_filename, key_filename)
4849

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
aiofiles==24.1.0
12
aiohttp==3.9.5
23
cryptography==42.0.8
34
google-auth==2.32.0

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
release_status = "Development Status :: 5 - Production/Stable"
2323
dependencies = [
24+
"aiofiles",
2425
"aiohttp",
2526
"cryptography>=42.0.0",
2627
"requests",

tests/unit/conftest.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,13 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import asyncio
1516
import socket
1617
import ssl
17-
from tempfile import TemporaryDirectory
1818
from threading import Thread
1919
from typing import Generator
2020

21+
from aiofiles.tempfile import TemporaryDirectory
2122
from mocks import FakeAlloyDBClient
2223
from mocks import FakeCredentials
2324
from mocks import FakeInstance
@@ -42,7 +43,7 @@ def fake_client(fake_instance: FakeInstance) -> FakeAlloyDBClient:
4243
return FakeAlloyDBClient(fake_instance)
4344

4445

45-
def start_proxy_server(instance: FakeInstance) -> None:
46+
async def start_proxy_server(instance: FakeInstance) -> None:
4647
"""Run local proxy server capable of performing metadata exchange"""
4748
ip_address = "127.0.0.1"
4849
port = 5433
@@ -55,8 +56,8 @@ def start_proxy_server(instance: FakeInstance) -> None:
5556
# tmpdir and its contents are automatically deleted after the CA cert
5657
# and cert chain are loaded into the SSLcontext. The values
5758
# need to be written to files in order to be loaded by the SSLContext
58-
with TemporaryDirectory() as tmpdir:
59-
_, cert_chain_filename, key_filename = _write_to_file(
59+
async with TemporaryDirectory() as tmpdir:
60+
_, cert_chain_filename, key_filename = await _write_to_file(
6061
tmpdir, server, [server, root], instance.server_key
6162
)
6263
context.load_cert_chain(cert_chain_filename, key_filename)
@@ -76,7 +77,15 @@ def start_proxy_server(instance: FakeInstance) -> None:
7677
@pytest.fixture(scope="session")
7778
def proxy_server(fake_instance: FakeInstance) -> Generator:
7879
"""Run local proxy server capable of performing metadata exchange"""
79-
thread = Thread(target=start_proxy_server, args=(fake_instance,), daemon=True)
80+
thread = Thread(
81+
target=asyncio.run,
82+
args=(
83+
start_proxy_server(
84+
fake_instance,
85+
),
86+
),
87+
daemon=True,
88+
)
8089
thread.start()
8190
yield thread
8291
thread.join()

tests/unit/mocks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def get_preferred_ip(self, ip_type: Any) -> Tuple[str, Any]:
370370
f.set_result("10.0.0.1")
371371
return f
372372

373-
def create_ssl_context(self) -> None:
373+
async def create_ssl_context(self) -> None:
374374
return None
375375

376376
async def force_refresh(self) -> None:

tests/unit/test_connection_info.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError
3030

3131

32-
def test_ConnectionInfo_init_(fake_instance: FakeInstance) -> None:
32+
async def test_ConnectionInfo_init_(fake_instance: FakeInstance) -> None:
3333
"""
3434
Test to check whether the __init__ method of ConnectionInfo
3535
can correctly initialize TLS context.
@@ -58,19 +58,19 @@ def test_ConnectionInfo_init_(fake_instance: FakeInstance) -> None:
5858
fake_instance.ip_addrs,
5959
datetime.now(timezone.utc) + timedelta(minutes=10),
6060
)
61-
context = conn_info.create_ssl_context()
61+
context = await conn_info.create_ssl_context()
6262
# verify TLS requirements
6363
assert context.minimum_version == ssl.TLSVersion.TLSv1_3
6464

6565

66-
def test_ConnectionInfo_caches_sslcontext() -> None:
66+
async def test_ConnectionInfo_caches_sslcontext() -> None:
6767
info = ConnectionInfo(["cert"], "cert", "key".encode(), {}, datetime.now())
6868
# context should default to None
6969
assert info.context is None
7070
# cache a 'context'
7171
info.context = "context"
7272
# calling create_ssl_context should no-op with an existing 'context'
73-
info.create_ssl_context()
73+
await info.create_ssl_context()
7474
assert info.context == "context"
7575

7676

0 commit comments

Comments
 (0)