Skip to content

Commit d2fd465

Browse files
feat: use non-blocking disk read/writes (#1142)
Python's standard library read and writes are blocking IO. This PR switches to use aiofiles which is asyncio native approach to read/write to disk.
1 parent f4ba6bb commit d2fd465

File tree

7 files changed

+24
-19
lines changed

7 files changed

+24
-19
lines changed

google/cloud/sql/connector/connection_info.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,10 @@
1717
from dataclasses import dataclass
1818
import logging
1919
import ssl
20-
from tempfile import TemporaryDirectory
2120
from typing import Any, Dict, Optional, TYPE_CHECKING
2221

22+
from aiofiles.tempfile import TemporaryDirectory
23+
2324
from google.cloud.sql.connector.exceptions import CloudSQLIPTypeError
2425
from google.cloud.sql.connector.exceptions import TLSVersionError
2526
from google.cloud.sql.connector.utils import write_to_file
@@ -45,7 +46,7 @@ class ConnectionInfo:
4546
expiration: datetime.datetime
4647
context: Optional[ssl.SSLContext] = None
4748

48-
def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext:
49+
async def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext:
4950
"""Constructs a SSL/TLS context for the given connection info.
5051
5152
Cache the SSL context to ensure we don't read from disk repeatedly when
@@ -83,8 +84,8 @@ def create_ssl_context(self, enable_iam_auth: bool = False) -> ssl.SSLContext:
8384
# tmpdir and its contents are automatically deleted after the CA cert
8485
# and ephemeral cert are loaded into the SSLcontext. The values
8586
# need to be written to files in order to be loaded by the SSLContext
86-
with TemporaryDirectory() as tmpdir:
87-
ca_filename, cert_filename, key_filename = write_to_file(
87+
async with TemporaryDirectory() as tmpdir:
88+
ca_filename, cert_filename, key_filename = await write_to_file(
8889
tmpdir, self.server_ca_cert, self.client_cert, self.private_key
8990
)
9091
context.load_cert_chain(cert_filename, keyfile=key_filename)

google/cloud/sql/connector/connector.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,14 +365,14 @@ async def connect_async(
365365
if driver in ASYNC_DRIVERS:
366366
return await connector(
367367
ip_address,
368-
conn_info.create_ssl_context(enable_iam_auth),
368+
await conn_info.create_ssl_context(enable_iam_auth),
369369
**kwargs,
370370
)
371371
# synchronous drivers are blocking and run using executor
372372
connect_partial = partial(
373373
connector,
374374
ip_address,
375-
conn_info.create_ssl_context(enable_iam_auth),
375+
await conn_info.create_ssl_context(enable_iam_auth),
376376
**kwargs,
377377
)
378378
return await self._loop.run_in_executor(None, connect_partial)

google/cloud/sql/connector/utils.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,10 @@
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
"""
16+
1617
from typing import Tuple
1718

19+
import aiofiles
1820
from cryptography.hazmat.backends import default_backend
1921
from cryptography.hazmat.primitives import serialization
2022
from cryptography.hazmat.primitives.asymmetric import rsa
@@ -57,7 +59,7 @@ async def generate_keys() -> Tuple[bytes, str]:
5759
return priv_key, pub_key
5860

5961

60-
def write_to_file(
62+
async def write_to_file(
6163
dir_path: str, serverCaCert: str, ephemeralCert: str, priv_key: bytes
6264
) -> Tuple[str, str, str]:
6365
"""
@@ -68,12 +70,12 @@ def write_to_file(
6870
cert_filename = f"{dir_path}/cert.pem"
6971
key_filename = f"{dir_path}/priv.pem"
7072

71-
with open(ca_filename, "w+") as ca_out:
72-
ca_out.write(serverCaCert)
73-
with open(cert_filename, "w+") as ephemeral_out:
74-
ephemeral_out.write(ephemeralCert)
75-
with open(key_filename, "wb") as priv_out:
76-
priv_out.write(priv_key)
73+
async with aiofiles.open(ca_filename, "w+") as ca_out:
74+
await ca_out.write(serverCaCert)
75+
async with aiofiles.open(cert_filename, "w+") as ephemeral_out:
76+
await ephemeral_out.write(ephemeralCert)
77+
async with aiofiles.open(key_filename, "wb") as priv_out:
78+
await priv_out.write(priv_key)
7779

7880
return (ca_filename, cert_filename, key_filename)
7981

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
aiofiles==24.1.0
12
aiohttp==3.9.5
23
cryptography==42.0.8
34
Requests==2.32.3

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
)
2626
release_status = "Development Status :: 5 - Production/Stable"
2727
dependencies = [
28+
"aiofiles",
2829
"aiohttp",
2930
"cryptography>=42.0.0",
3031
"Requests",

tests/unit/mocks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
import datetime
2020
import json
2121
import ssl
22-
from tempfile import TemporaryDirectory
2322
from typing import Any, Callable, Dict, Literal, Optional, Tuple
2423

24+
from aiofiles.tempfile import TemporaryDirectory
2525
from aiohttp import web
2626
from cryptography import x509
2727
from cryptography.hazmat.backends import default_backend
@@ -203,8 +203,8 @@ async def create_ssl_context() -> ssl.SSLContext:
203203
# build default ssl.SSLContext
204204
context = ssl.create_default_context()
205205
# load ssl.SSLContext with certs
206-
with TemporaryDirectory() as tmpdir:
207-
ca_filename, cert_filename, key_filename = write_to_file(
206+
async with TemporaryDirectory() as tmpdir:
207+
ca_filename, cert_filename, key_filename = await write_to_file(
208208
tmpdir, server_ca_cert, ephemeral_cert, client_private
209209
)
210210
context.load_cert_chain(cert_filename, keyfile=key_filename)

tests/unit/test_instance.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -367,14 +367,14 @@ async def test_AutoIAMAuthNotSupportedError(fake_client: CloudSQLClient) -> None
367367
await cache._current
368368

369369

370-
def test_ConnectionInfo_caches_sslcontext() -> None:
370+
async def test_ConnectionInfo_caches_sslcontext() -> None:
371371
info = ConnectionInfo(
372372
"cert", "cert", "key".encode(), {}, "POSTGRES", datetime.datetime.now()
373373
)
374374
# context should default to None
375375
assert info.context is None
376376
# cache a 'context'
377377
info.context = "context"
378-
# caling create_ssl_context should no-op with an existing 'context'
379-
info.create_ssl_context()
378+
# calling create_ssl_context should no-op with an existing 'context'
379+
await info.create_ssl_context()
380380
assert info.context == "context"

0 commit comments

Comments
 (0)