Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
fixed:
- Google Cloud Storage caching is now fully synchronous (no longer async) and has been re-enabled.
10 changes: 5 additions & 5 deletions policyengine/utils/data/caching_google_storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ def _data_key(self, bucket: str, key: str) -> str:

# To rule out file corruption or thread contention, always replace the current
# target file with whatever we have cached, using an atomic write.
async def download(self, bucket: str, key: str, target: Path):
def download(self, bucket: str, key: str, target: Path):
"""
Atomically write the latest version of the cloud storage blob to the target path.
"""
await self.sync(bucket, key)
self.sync(bucket, key)
data = self.cache.get(self._data_key(bucket, key))
if type(data) is bytes:
logger.debug(
logger.info(
f"Copying downloaded data for {bucket}, {key} to {target}"
)
atomic_write(target, data)
Expand All @@ -39,7 +39,7 @@ async def download(self, bucket: str, key: str, target: Path):

# If the CRC has changed from what we downloaded last time, download the blob again,
# then update the stored CRC to whatever we actually downloaded.
async def sync(self, bucket: str, key: str) -> None:
def sync(self, bucket: str, key: str) -> None:
"""
Cache the resource if the CRC has changed.
"""
Expand All @@ -59,7 +59,7 @@ async def sync(self, bucket: str, key: str) -> None:
)
return

[content, downloaded_crc] = await self.client.download(bucket, key)
[content, downloaded_crc] = self.client.download(bucket, key)
logger.info(
f"Downloaded new version of {bucket}, {key} with crc {downloaded_crc}"
)
Expand Down
9 changes: 2 additions & 7 deletions policyengine/utils/data/simplified_google_storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,14 @@ def crc32c(self, bucket: str, key: str) -> str | None:
logger.debug(f"Crc is {blob.crc32c}")
return blob.crc32c

async def download(self, bucket: str, key: str) -> tuple[bytes, str]:
def download(self, bucket: str, key: str) -> tuple[bytes, str]:
"""
get the blob content and associated CRC from google storage.
"""
logger.debug(f"Downloading {bucket}, {key}")
blob = self.client.bucket(bucket).blob(key)

# async implementation as per https://github.com/googleapis/python-storage/blob/main/samples/snippets/storage_async_download.py
def download():
return blob.download_as_bytes()

loop = asyncio.get_running_loop()
result = await loop.run_in_executor(None, download)
result = blob.download_as_bytes()
# According to documentation blob.crc32c is updated as a side effect of
# downloading the content. As a result this should now be the crc of the downloaded
# content (i.e. there is not a race condition where it's getting the CRC from the cloud)
Expand Down
9 changes: 5 additions & 4 deletions policyengine/utils/data_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ def download(
)

logging.info = print
if Path(filepath).exists():
logging.info(f"File {filepath} already exists. Skipping download.")
return filepath

# NOTE: tests will break on build if you don't default to huggingface.
if data_file.huggingface_repo is not None:
logging.info("Using Hugging Face for download.")
try:
Expand All @@ -43,6 +40,10 @@ def download(
except:
logging.info("Failed to download from Hugging Face.")

if Path(filepath).exists():
logging.info(f"File {filepath} already exists. Skipping download.")
return filepath

if data_file.gcs_bucket is not None:
logging.info("Using Google Cloud Storage for download.")
download_file_from_gcs(
Expand Down
36 changes: 21 additions & 15 deletions policyengine/utils/google_cloud_bucket.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,23 @@
from .data.caching_google_storage_client import CachingGoogleStorageClient
import asyncio
from pathlib import Path

_caching_client: CachingGoogleStorageClient | None = None


def _get_client():
    """
    Return the module-wide CachingGoogleStorageClient, building it on first use.

    The instance is stored in the module global ``_caching_client`` so that
    later calls reuse it instead of constructing a new client.
    """
    global _caching_client
    if _caching_client is None:
        # First call (or first call after _clear_client): create and remember it.
        _caching_client = CachingGoogleStorageClient()
    return _caching_client


def _clear_client():
    """
    Drop the module-wide client reference.

    After this call ``_caching_client`` is ``None``, so the next
    ``_get_client()`` constructs a fresh CachingGoogleStorageClient.
    """
    global _caching_client
    _caching_client = None


def download_file_from_gcs(
bucket_name: str, file_name: str, destination_path: str
) -> None:
Expand All @@ -12,18 +32,4 @@ def download_file_from_gcs(
Returns:
None
"""
from google.cloud import storage

# Initialize a client
client = storage.Client()

# Get the bucket
bucket = client.bucket(bucket_name)

# Create a blob object from the file name
blob = bucket.blob(file_name)

# Download the file to a local path
blob.download_to_filename(destination_path)

return destination_path
_get_client().download(bucket_name, file_name, Path(destination_path))
26 changes: 11 additions & 15 deletions tests/utils/data/test_caching_google_storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,28 @@


class TestCachingGoogleStorageClient:
@pytest.mark.asyncio
async def test_when_cache_miss__then_download_file(
def test_when_cache_miss__then_download_file(
self, mocked_storage: MockedStorageSupport
):
with CachingGoogleStorageClient() as caching_client:
with tempfile.TemporaryDirectory() as tmpdir:
mocked_storage.given_stored_data("TEST DATA", "TEST_CRC")
await caching_client.download(
caching_client.download(
"test_bucket", "blob/path", Path(tmpdir, "output.txt")
)
assert (
open(Path(tmpdir, "output.txt")).readline() == "TEST DATA"
)

@pytest.mark.asyncio
async def test_when_cache_hit__then_use_cached_value(
def test_when_cache_hit__then_use_cached_value(
self, mocked_storage: MockedStorageSupport
):
with CachingGoogleStorageClient() as caching_client:
with tempfile.TemporaryDirectory() as tmpdir:
mocked_storage.given_stored_data(
"INITIAL TEST DATA", "TEST_CRC"
)
await caching_client.download(
caching_client.download(
"test_bucket", "blob/path", Path(tmpdir, "output.txt")
)
assert (
Expand All @@ -41,24 +39,23 @@ async def test_when_cache_hit__then_use_cached_value(
mocked_storage.given_stored_data(
"CRC DID NOT CHANGE SO YOU SHOULD NOT SEE THIS", "TEST_CRC"
)
await caching_client.download(
caching_client.download(
"test_bucket", "blob/path", Path(tmpdir, "output.txt")
)
assert (
open(Path(tmpdir, "output.txt")).readline()
== "INITIAL TEST DATA"
)

@pytest.mark.asyncio
async def test_when_crc_updated__then_redownload(
def test_when_crc_updated__then_redownload(
self, mocked_storage: MockedStorageSupport
):
with CachingGoogleStorageClient() as caching_client:
with tempfile.TemporaryDirectory() as tmpdir:
mocked_storage.given_stored_data(
"INITIAL TEST DATA", "TEST_CRC"
)
await caching_client.download(
caching_client.download(
"test_bucket", "blob/path", Path(tmpdir, "output.txt")
)
assert (
Expand All @@ -69,24 +66,23 @@ async def test_when_crc_updated__then_redownload(
mocked_storage.given_stored_data(
"UPDATED_TEST_DATA", "UPDATED_TEST_CRC"
)
await caching_client.download(
caching_client.download(
"test_bucket", "blob/path", Path(tmpdir, "output.txt")
)
assert (
open(Path(tmpdir, "output.txt")).readline()
== "UPDATED_TEST_DATA"
)

@pytest.mark.asyncio
async def test_when_crc_updated_on_download__then_store_downloaded_crc(
def test_when_crc_updated_on_download__then_store_downloaded_crc(
self, mocked_storage: MockedStorageSupport
):
with CachingGoogleStorageClient() as caching_client:
with tempfile.TemporaryDirectory() as tmpdir:
mocked_storage.given_crc_changes_on_download(
"FINAL CONTENT", "INITIAL_CRC", "DOWNLOADED_CRC"
)
await caching_client.download(
caching_client.download(
"test_bucket", "blob/path", Path(tmpdir, "output.txt")
)
assert (
Expand All @@ -98,7 +94,7 @@ async def test_when_crc_updated_on_download__then_store_downloaded_crc(
"YOU SHOULD NOT SEE THIS BECAUSE THE CRC IS UNCHANGED FROM DOWNLOADED",
"DOWNLOADED_CRC",
)
await caching_client.download(
caching_client.download(
"test_bucket", "blob/path", Path(tmpdir, "output.txt")
)
assert (
Expand Down
39 changes: 39 additions & 0 deletions tests/utils/data/test_google_cloud_bucket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from unittest import TestCase
from unittest.mock import patch
import pytest
from pathlib import Path
from policyengine.utils.google_cloud_bucket import (
download_file_from_gcs,
_clear_client,
)


class TestGoogleCloudBucket(TestCase):
    """Tests for ``download_file_from_gcs`` and the client it caches at module level."""

    def setUp(self):
        # download_file_from_gcs keeps its client in a module global. Reset it
        # before each test so the patched class is the one that gets
        # instantiated, and reset it again afterwards so the mock instance
        # created here is not left behind for tests that run later.
        _clear_client()
        self.addCleanup(_clear_client)

    @patch(
        "policyengine.utils.google_cloud_bucket.CachingGoogleStorageClient",
        autospec=True,
    )
    def test_download_uses_storage_client(self, client_class):
        """The helper delegates to the client and passes the destination as a Path."""
        client_instance = client_class.return_value
        download_file_from_gcs(
            "TEST_BUCKET", "TEST/FILE/NAME.TXT", "TARGET/PATH"
        )
        # Exactly one download is expected for a single helper call.
        client_instance.download.assert_called_once_with(
            "TEST_BUCKET", "TEST/FILE/NAME.TXT", Path("TARGET/PATH")
        )

    @patch(
        "policyengine.utils.google_cloud_bucket.CachingGoogleStorageClient",
        autospec=True,
    )
    def test_download_only_creates_client_once(self, client_class):
        """Two downloads in a row construct the client class only once."""
        download_file_from_gcs(
            "TEST_BUCKET", "TEST/FILE/NAME.TXT", "TARGET/PATH"
        )
        download_file_from_gcs(
            "TEST_BUCKET", "TEST/FILE/NAME.TXT", "ANOTHER/PATH"
        )
        client_class.assert_called_once()
5 changes: 2 additions & 3 deletions tests/utils/data/test_simplified_google_storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,11 @@ def test_crc32c__gets_crc(self, mock_client_class):
bucket.blob.assert_called_with("content.txt")
blob.reload.assert_called()

@pytest.mark.asyncio
@patch(
"policyengine.utils.data.simplified_google_storage_client.Client",
autospec=True,
)
async def test_download__downloads_content(self, mock_client_class):
def test_download__downloads_content(self, mock_client_class):
mock_instance = mock_client_class.return_value
bucket = mock_instance.bucket.return_value
blob = bucket.blob.return_value
Expand All @@ -35,7 +34,7 @@ async def test_download__downloads_content(self, mock_client_class):
blob.crc32c = "TEST_CRC"

client = SimplifiedGoogleStorageClient()
[data, crc] = await client.download("bucket", "blob.txt")
[data, crc] = client.download("bucket", "blob.txt")
assert data == "hello, world".encode()
assert crc == "TEST_CRC"

Expand Down
Loading