Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
added:
- new class CachingGoogleStorageClient for locally caching gs files to disk.
2 changes: 2 additions & 0 deletions policyengine/utils/data/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .caching_google_storage_client import CachingGoogleStorageClient
from .simplified_google_storage_client import SimplifiedGoogleStorageClient
85 changes: 85 additions & 0 deletions policyengine/utils/data/caching_google_storage_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from contextlib import AbstractContextManager
import diskcache
from pathlib import Path
from policyengine_core.data.dataset import atomic_write
import logging
from .simplified_google_storage_client import SimplifiedGoogleStorageClient

logger = logging.getLogger(__name__)


class CachingGoogleStorageClient(AbstractContextManager):
    """
    Client for downloading resources from a google storage bucket only when
    the CRC of the blob changes.

    Blob content and the CRC it was downloaded with are stored in a local
    ``diskcache.Cache``. When used as a context manager the cache is cleared
    on exit.
    """

    def __init__(self):
        self.client = SimplifiedGoogleStorageClient()
        self.cache = diskcache.Cache()

    def _data_key(self, bucket: str, key: str) -> str:
        """Return the cache key under which the blob's content is stored."""
        return f"{bucket}.{key}.data"

    def _crc_key(self, bucket: str, key: str) -> str:
        """Return the cache key under which the blob's CRC is stored."""
        return f"{bucket}.{key}.crc"

    # To absolutely 100% avoid any possible issue with file corruption or thread contention
    # always replace the current target file with whatever we have cached as an atomic write.
    async def download(self, bucket: str, key: str, target: Path):
        """
        Atomically write the latest version of the cloud storage blob to the target path.

        The cache is synced first, then the cached bytes are written to
        ``target`` with ``atomic_write``.

        Raises:
            Exception: if the blob cannot be found, or if the cached value
                for the blob is not ``bytes``.
        """
        await self.sync(bucket, key)
        data = self.cache.get(self._data_key(bucket, key))
        if not isinstance(data, bytes):
            raise Exception("Expected data for blob to be cached as bytes")

        logger.debug(
            f"Copying downloaded data for {bucket}, {key} to {target}"
        )
        atomic_write(target, data)

    # If the crc has changed from what we downloaded last time download it again.
    # then update the CRC to whatever we actually downloaded.
    async def sync(self, bucket: str, key: str) -> None:
        """
        Cache the resource if the CRC has changed.

        Raises:
            Exception: if the storage client reports no CRC for the blob.
        """
        logger.info(f"Syncing {bucket}, {key} to cache")
        datakey = self._data_key(bucket, key)
        crckey = self._crc_key(bucket, key)

        crc = self.client.crc32c(bucket, key)
        if crc is None:
            raise Exception(f"Unable to find {key} in bucket {bucket}")

        prev_crc = self.cache.get(crckey, default=None)
        logger.debug(f"Previous crc for {bucket}, {key} was {prev_crc}")
        if prev_crc == crc:
            logger.info(
                f"Cache exists and crc is unchanged for {bucket}, {key}."
            )
            return

        content, downloaded_crc = await self.client.download(bucket, key)
        logger.info(
            f"Downloaded new version of {bucket}, {key} with crc {downloaded_crc}"
        )

        logger.debug("Updating cache...")
        # Update both the data and the metadata in a single transaction.
        # NOTE: ``with self.cache:`` only opens/closes the cache connection;
        # ``transact()`` is what actually groups the writes atomically.
        with self.cache.transact():
            self.cache.set(datakey, content)
            # Whatever the CRC was before we downloaded, we set the cache CRC
            # to the CRC reported by the download itself to avoid race conditions.
            self.cache.set(crckey, downloaded_crc)

    def clear(self):
        """Remove every entry from the local cache."""
        self.cache.clear()

    def __enter__(self) -> "CachingGoogleStorageClient":
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """
        Clear the cache and close its connection when leaving the context.

        Returns None so any exception raised inside the context propagates.
        """
        self.clear()
        # A closed diskcache.Cache re-opens on next access, so closing here
        # only releases the connection; it does not make the client unusable.
        self.cache.close()
        return None
46 changes: 46 additions & 0 deletions policyengine/utils/data/simplified_google_storage_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import asyncio
from policyengine_core.data.dataset import atomic_write
import logging
from google.cloud.storage import Client

logger = logging.getLogger(__name__)


class SimplifiedGoogleStorageClient:
    """
    Class separating out just the interactions with google storage required to
    cache downloaded files.

    Simplifies the dependent code and unit testing.
    """

    def __init__(self):
        self.client = Client()

    def crc32c(self, bucket: str, key: str) -> str | None:
        """
        Get the current CRC of the specified blob. None if it doesn't exist.
        """
        # Imported locally so this block stays self-contained; the module is
        # part of the google-cloud packages google.cloud.storage depends on.
        from google.cloud.exceptions import NotFound

        logger.debug(f"Getting crc for {bucket}, {key}")
        blob = self.client.bucket(bucket).blob(key)
        try:
            # reload() fetches the blob metadata and raises NotFound when the
            # object is missing, so translate that into the documented None.
            blob.reload()
        except NotFound:
            logger.debug(f"No blob found for {bucket}, {key}")
            return None
        logger.debug(f"Crc is {blob.crc32c}")
        return blob.crc32c

    async def download(self, bucket: str, key: str) -> tuple[bytes, str]:
        """
        Get the blob content and associated CRC from google storage.

        Returns:
            A ``(content, crc32c)`` tuple for the downloaded blob.
        """
        logger.debug(f"Downloading {bucket}, {key}")
        blob = self.client.bucket(bucket).blob(key)

        # async implementation as per https://github.com/googleapis/python-storage/blob/main/samples/snippets/storage_async_download.py
        # The blocking download runs in the loop's default executor.
        loop = asyncio.get_running_loop()
        content = await loop.run_in_executor(None, blob.download_as_bytes)
        # According to documentation blob.crc32c is updated as a side effect of
        # downloading the content. As a result this should now be the crc of the downloaded
        # content (i.e. there is not a race condition where it's getting the CRC from the cloud)
        return (content, blob.crc32c)
6 changes: 4 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ dependencies = [
"policyengine_core>=3.10",
"policyengine-uk",
"policyengine-us>=1.213.1",
"diskcache (>=5.6.3,<6.0.0)",
"google-cloud-storage (>=3.1.0,<4.0.0)",
"microdf_python",
"getpass4",
"pydantic",
"google-cloud-storage",
"pydantic"
]

[project.optional-dependencies]
Expand All @@ -32,6 +33,7 @@ dev = [
"yaml-changelog>=0.1.7",
"itables",
"build",
"pytest-asyncio>=0.26.0",
]

[tool.setuptools]
Expand Down
Empty file added tests/utils/data/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions tests/utils/data/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pytest
from unittest.mock import patch


class MockedStorageSupport:
    """Test helper that scripts the responses of a mocked storage client."""

    def __init__(self, mock_simple_storage_client):
        self.mock_simple_storage_client = mock_simple_storage_client

    def given_stored_data(self, data: str, crc: str):
        """Make the mock report ``crc`` and return ``data`` with that same crc."""
        # Stored data is just the case where the crc does not change
        # between the crc lookup and the download.
        self.given_crc_changes_on_download(data, crc, crc)

    def given_crc_changes_on_download(
        self, data: str, initial_crc: str, download_crc: str
    ):
        """
        Make the mock report ``initial_crc`` from ``crc32c`` but return
        ``download_crc`` alongside the encoded ``data`` from ``download``.
        """
        storage = self.mock_simple_storage_client
        payload = (data.encode(), download_crc)
        storage.crc32c.return_value = initial_crc
        storage.download.return_value = payload


@pytest.fixture()
def mocked_storage():
    """
    Patch the SimplifiedGoogleStorageClient used by the caching client and
    yield a MockedStorageSupport wrapping the mocked instance.
    """
    patcher = patch(
        "policyengine.utils.data.caching_google_storage_client.SimplifiedGoogleStorageClient",
        autospec=True,
    )
    patched_class = patcher.start()
    try:
        # The caching client instantiates the patched class, so the instance
        # it receives is the class mock's return_value.
        yield MockedStorageSupport(patched_class.return_value)
    finally:
        patcher.stop()
107 changes: 107 additions & 0 deletions tests/utils/data/test_caching_google_storage_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
from pathlib import Path
import pytest
import tempfile
from unittest.mock import MagicMock, create_autospec, patch
from policyengine.utils.data import CachingGoogleStorageClient
from tests.utils.data.conftest import MockedStorageSupport


class TestCachingGoogleStorageClient:
    """
    Behavioural tests for CachingGoogleStorageClient against a mocked
    storage client.

    Output files are read with ``Path.read_text`` so no file handle is left
    open when the temporary directory is removed.
    """

    @pytest.mark.asyncio
    async def test_when_cache_miss__then_download_file(
        self, mocked_storage: MockedStorageSupport
    ):
        with CachingGoogleStorageClient() as caching_client:
            with tempfile.TemporaryDirectory() as tmpdir:
                target = Path(tmpdir, "output.txt")
                mocked_storage.given_stored_data("TEST DATA", "TEST_CRC")
                await caching_client.download(
                    "test_bucket", "blob/path", target
                )
                assert target.read_text() == "TEST DATA"

    @pytest.mark.asyncio
    async def test_when_cache_hit__then_use_cached_value(
        self, mocked_storage: MockedStorageSupport
    ):
        with CachingGoogleStorageClient() as caching_client:
            with tempfile.TemporaryDirectory() as tmpdir:
                target = Path(tmpdir, "output.txt")
                mocked_storage.given_stored_data(
                    "INITIAL TEST DATA", "TEST_CRC"
                )
                await caching_client.download(
                    "test_bucket", "blob/path", target
                )
                assert target.read_text() == "INITIAL TEST DATA"

                # Same CRC as before: the cached content must be reused.
                mocked_storage.given_stored_data(
                    "CRC DID NOT CHANGE SO YOU SHOULD NOT SEE THIS", "TEST_CRC"
                )
                await caching_client.download(
                    "test_bucket", "blob/path", target
                )
                assert target.read_text() == "INITIAL TEST DATA"

    @pytest.mark.asyncio
    async def test_when_crc_updated__then_redownload(
        self, mocked_storage: MockedStorageSupport
    ):
        with CachingGoogleStorageClient() as caching_client:
            with tempfile.TemporaryDirectory() as tmpdir:
                target = Path(tmpdir, "output.txt")
                mocked_storage.given_stored_data(
                    "INITIAL TEST DATA", "TEST_CRC"
                )
                await caching_client.download(
                    "test_bucket", "blob/path", target
                )
                assert target.read_text() == "INITIAL TEST DATA"

                # A different CRC must trigger a fresh download.
                mocked_storage.given_stored_data(
                    "UPDATED_TEST_DATA", "UPDATED_TEST_CRC"
                )
                await caching_client.download(
                    "test_bucket", "blob/path", target
                )
                assert target.read_text() == "UPDATED_TEST_DATA"

    @pytest.mark.asyncio
    async def test_when_crc_updated_on_download__then_store_downloaded_crc(
        self, mocked_storage: MockedStorageSupport
    ):
        with CachingGoogleStorageClient() as caching_client:
            with tempfile.TemporaryDirectory() as tmpdir:
                target = Path(tmpdir, "output.txt")
                mocked_storage.given_crc_changes_on_download(
                    "FINAL CONTENT", "INITIAL_CRC", "DOWNLOADED_CRC"
                )
                await caching_client.download(
                    "test_bucket", "blob/path", target
                )
                assert target.read_text() == "FINAL CONTENT"

                # The cache should hold the CRC reported by the download, so
                # a blob reporting that CRC counts as unchanged.
                mocked_storage.given_stored_data(
                    "YOU SHOULD NOT SEE THIS BECAUSE THE CRC IS UNCHANGED FROM DOWNLOADED",
                    "DOWNLOADED_CRC",
                )
                await caching_client.download(
                    "test_bucket", "blob/path", target
                )
                assert target.read_text() == "FINAL CONTENT"
43 changes: 43 additions & 0 deletions tests/utils/data/test_simplified_google_storage_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from unittest.mock import patch
import pytest
from policyengine.utils.data import SimplifiedGoogleStorageClient


class TestSimplifiedGoogleStorageClient:
    """Unit tests for SimplifiedGoogleStorageClient with the google Client patched out."""

    @patch(
        "policyengine.utils.data.simplified_google_storage_client.Client",
        autospec=True,
    )
    def test_crc32c__gets_crc(self, mock_client_class):
        # Walk the mock chain client -> bucket -> blob that the code under test follows.
        gcs_client = mock_client_class.return_value
        mock_bucket = gcs_client.bucket.return_value
        mock_blob = mock_bucket.blob.return_value
        mock_blob.crc32c = "TEST_CRC"

        storage = SimplifiedGoogleStorageClient()
        reported_crc = storage.crc32c("bucket_name", "content.txt")

        assert reported_crc == "TEST_CRC"
        gcs_client.bucket.assert_called_with("bucket_name")
        mock_bucket.blob.assert_called_with("content.txt")
        mock_blob.reload.assert_called()

    @pytest.mark.asyncio
    @patch(
        "policyengine.utils.data.simplified_google_storage_client.Client",
        autospec=True,
    )
    async def test_download__downloads_content(self, mock_client_class):
        gcs_client = mock_client_class.return_value
        mock_bucket = gcs_client.bucket.return_value
        mock_blob = mock_bucket.blob.return_value
        expected_content = b"hello, world"
        mock_blob.download_as_bytes.return_value = expected_content
        mock_blob.crc32c = "TEST_CRC"

        storage = SimplifiedGoogleStorageClient()
        data, crc = await storage.download("bucket", "blob.txt")

        assert data == expected_content
        assert crc == "TEST_CRC"
        gcs_client.bucket.assert_called_with("bucket")
        mock_bucket.blob.assert_called_with("blob.txt")
Loading