Skip to content

Commit df9d647

Browse files
authored
Merge pull request #133 from PolicyEngine/350_use_caching_google_storage_client
download_file_from_gcs now uses cache.
2 parents 9ca329c + 566a067 commit df9d647

File tree

3 files changed

+66
-15
lines changed

3 files changed

+66
-15
lines changed

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: patch
2+
changes:
3+
fixed:
4+
- downloads from google storage should now be properly cached.
Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,23 @@
1+
from .data.caching_google_storage_client import CachingGoogleStorageClient
2+
import asyncio
3+
from pathlib import Path
4+
5+
_caching_client: CachingGoogleStorageClient | None = None
6+
7+
8+
# Drop the module-level cached client by resetting it to None, so that the
# next _get_client() call constructs a new CachingGoogleStorageClient.
# The test suite calls this from setUp to isolate tests from each other.
def _clear_client():
9+
global _caching_client
10+
_caching_client = None
11+
12+
13+
# Return the shared CachingGoogleStorageClient, creating it on first use.
# The instance is stored in the module-level _caching_client and returned
# unchanged on later calls until _clear_client() resets it.
def _get_client():
14+
global _caching_client
15+
# Reuse the instance created by an earlier call, if there is one.
if _caching_client is not None:
16+
return _caching_client
17+
# First call (or first call after _clear_client()): build and remember it.
_caching_client = CachingGoogleStorageClient()
18+
return _caching_client
19+
20+
121
def download_file_from_gcs(
222
bucket_name: str, file_name: str, destination_path: str
323
) -> None:
@@ -12,18 +32,6 @@ def download_file_from_gcs(
1232
Returns:
1333
None
1434
"""
15-
from google.cloud import storage
16-
17-
# Initialize a client
18-
client = storage.Client()
19-
20-
# Get the bucket
21-
bucket = client.bucket(bucket_name)
22-
23-
# Create a blob object from the file name
24-
blob = bucket.blob(file_name)
25-
26-
# Download the file to a local path
27-
blob.download_to_filename(destination_path)
28-
29-
return destination_path
35+
asyncio.run(
36+
_get_client().download(bucket_name, file_name, Path(destination_path))
37+
)
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from unittest import TestCase
2+
from unittest.mock import patch
3+
import pytest
4+
from pathlib import Path
5+
from policyengine.utils.google_cloud_bucket import (
6+
download_file_from_gcs,
7+
_clear_client,
8+
)
9+
10+
11+
# Tests for download_file_from_gcs with CachingGoogleStorageClient patched
# out, so the client class and its instances are mocks.
class TestGoogleCloudBucket(TestCase):
12+
# Reset the module-level cached client before every test so that each test
# observes client construction from a clean state.
def setUp(self):
13+
_clear_client()
14+
15+
@patch(
16+
"policyengine.utils.google_cloud_bucket.CachingGoogleStorageClient",
17+
autospec=True,
18+
)
19+
# The download must be delegated to the client instance with the bucket name,
# the file name, and the destination converted from str to Path.
# NOTE(review): download_file_from_gcs hands the result of download() to
# asyncio.run, so this relies on the autospec'd download returning an
# awaitable -- confirm against CachingGoogleStorageClient.download.
def test_download_uses_storage_client(self, client_class):
20+
client_instance = client_class.return_value
21+
download_file_from_gcs(
22+
"TEST_BUCKET", "TEST/FILE/NAME.TXT", "TARGET/PATH"
23+
)
24+
client_instance.download.assert_called_with(
25+
"TEST_BUCKET", "TEST/FILE/NAME.TXT", Path("TARGET/PATH")
26+
)
27+
28+
@patch(
29+
"policyengine.utils.google_cloud_bucket.CachingGoogleStorageClient",
30+
autospec=True,
31+
)
32+
# Two consecutive downloads must construct the client class exactly once,
# i.e. the second call reuses the instance created by the first.
def test_download_only_creates_client_once(self, client_class):
33+
download_file_from_gcs(
34+
"TEST_BUCKET", "TEST/FILE/NAME.TXT", "TARGET/PATH"
35+
)
36+
download_file_from_gcs(
37+
"TEST_BUCKET", "TEST/FILE/NAME.TXT", "ANOTHER/PATH"
38+
)
39+
client_class.assert_called_once()

0 commit comments

Comments
 (0)