diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..dfa596e2 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + added: + - new class CachingGoogleStorageClient for locally caching gs files to disk. diff --git a/policyengine/utils/data/__init__.py b/policyengine/utils/data/__init__.py new file mode 100644 index 00000000..74b27825 --- /dev/null +++ b/policyengine/utils/data/__init__.py @@ -0,0 +1,2 @@ +from .caching_google_storage_client import CachingGoogleStorageClient +from .simplified_google_storage_client import SimplifiedGoogleStorageClient diff --git a/policyengine/utils/data/caching_google_storage_client.py b/policyengine/utils/data/caching_google_storage_client.py new file mode 100644 index 00000000..e0f7db6e --- /dev/null +++ b/policyengine/utils/data/caching_google_storage_client.py @@ -0,0 +1,85 @@ +from contextlib import AbstractContextManager +import diskcache +from pathlib import Path +from policyengine_core.data.dataset import atomic_write +import logging +from .simplified_google_storage_client import SimplifiedGoogleStorageClient + +logger = logging.getLogger(__name__) + + +class CachingGoogleStorageClient(AbstractContextManager): + """ + Client for downloaded resources from a google storage bucket only when the CRC + of the blob changes. + """ + + def __init__(self): + self.client = SimplifiedGoogleStorageClient() + self.cache = diskcache.Cache() + + def _data_key(self, bucket: str, key: str) -> str: + return f"{bucket}.{key}.data" + + # To absolutely 100% avoid any possible issue with file corruption or thread contention + # always replace the current target file with whatever we have cached as an atomic write. + async def download(self, bucket: str, key: str, target: Path): + """ + Atomically write the latest version of the cloud storage blob to the target path. + """ + await self.sync(bucket, key) + data = self.cache.get(self._data_key(bucket, key)) + if type(data) is bytes: + logger.debug( + f"Copying downloaded data for {bucket}, {key} to {target}" + ) + atomic_write(target, data) + return + raise Exception("Expected data for blob to be cached as bytes") + + # If the crc has changed from what we downloaded last time download it again. + # then update the CRC to whatever we actually downloaded. + async def sync(self, bucket: str, key: str) -> None: + """ + Cache the resource if the CRC has changed. + """ + logger.info(f"Syncing {bucket}, {key} to cache") + datakey = f"{bucket}.{key}.data" + crckey = f"{bucket}.{key}.crc" + + crc = self.client.crc32c(bucket, key) + if crc is None: + raise Exception(f"Unable to find {key} in bucket {bucket}") + + prev_crc = self.cache.get(crckey, default=None) + logger.debug(f"Previous crc for {bucket}, {key} was {prev_crc}") + if prev_crc == crc: + logger.info( + f"Cache exists and crc is unchanged for {bucket}, {key}." + ) + return + + [content, downloaded_crc] = await self.client.download(bucket, key) + logger.info( + f"Downloaded new version of {bucket}, {key} with crc {downloaded_crc}" + ) + + # atomic transaction to update both the data and the metadata + # at the same time. + with self.cache as c: + logger.debug(f"Updating cache...") + self.cache.set(datakey, content) + # Whatever the CRC was before we downloaded, we set the cache CRC + # to the CRC reported by the download itself to avoid race conditions. + self.cache.set(crckey, downloaded_crc) + + def clear(self): + self.cache.clear() + + def __enter__(self) -> "CachingGoogleStorageClient": + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Raise any exception triggered within the runtime context.""" + self.clear() + return None diff --git a/policyengine/utils/data/simplified_google_storage_client.py b/policyengine/utils/data/simplified_google_storage_client.py new file mode 100644 index 00000000..55b715ab --- /dev/null +++ b/policyengine/utils/data/simplified_google_storage_client.py @@ -0,0 +1,46 @@ +import asyncio +from policyengine_core.data.dataset import atomic_write +import logging +from google.cloud.storage import Client + +logger = logging.getLogger(__name__) + + +class SimplifiedGoogleStorageClient: + """ + Class separating out just the interactions with google storage required to + cache downloaded files. + + Simplifies the dependent code and unit testing. + """ + + def __init__(self): + self.client = Client() + + def crc32c(self, bucket: str, key: str) -> str | None: + """ + get the current CRC of the specified blob. None if it doesn't exist. + """ + logger.debug(f"Getting crc for {bucket}, {key}") + blob = self.client.bucket(bucket).blob(key) + blob.reload() + logger.debug(f"Crc is {blob.crc32c}") + return blob.crc32c + + async def download(self, bucket: str, key: str) -> tuple[bytes, str]: + """ + get the blob content and associated CRC from google storage. + """ + logger.debug(f"Downloading {bucket}, {key}") + blob = self.client.bucket(bucket).blob(key) + + # async implmentation as per https://github.com/googleapis/python-storage/blob/main/samples/snippets/storage_async_download.py + def download(): + return blob.download_as_bytes() + + loop = asyncio.get_running_loop() + result = await loop.run_in_executor(None, download) + # According to documentation blob.crc32c is updated as a side effect of + # downloading the content. As a result this should now be the crc of the downloaded + # content (i.e. there is not a race condition where it's getting the CRC from the cloud) + return (result, blob.crc32c) diff --git a/pyproject.toml b/pyproject.toml index e913360c..b796076f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,10 +16,11 @@ dependencies = [ "policyengine_core>=3.10", "policyengine-uk", "policyengine-us>=1.213.1", + "diskcache (>=5.6.3,<6.0.0)", + "google-cloud-storage (>=3.1.0,<4.0.0)", "microdf_python", "getpass4", - "pydantic", - "google-cloud-storage", + "pydantic" ] [project.optional-dependencies] @@ -32,6 +33,7 @@ dev = [ "yaml-changelog>=0.1.7", "itables", "build", + "pytest-asyncio>=0.26.0", ] [tool.setuptools] diff --git a/tests/utils/data/__init__.py b/tests/utils/data/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/utils/data/conftest.py b/tests/utils/data/conftest.py new file mode 100644 index 00000000..7a692420 --- /dev/null +++ b/tests/utils/data/conftest.py @@ -0,0 +1,33 @@ +import pytest +from unittest.mock import patch + + +class MockedStorageSupport: + def __init__(self, mock_simple_storage_client): + self.mock_simple_storage_client = mock_simple_storage_client + + def given_stored_data(self, data: str, crc: str): + self.mock_simple_storage_client.crc32c.return_value = crc + self.mock_simple_storage_client.download.return_value = ( + data.encode(), + crc, + ) + + def given_crc_changes_on_download( + self, data: str, initial_crc: str, download_crc: str + ): + self.mock_simple_storage_client.crc32c.return_value = initial_crc + self.mock_simple_storage_client.download.return_value = ( + data.encode(), + download_crc, + ) + + +@pytest.fixture() +def mocked_storage(): + with patch( + "policyengine.utils.data.caching_google_storage_client.SimplifiedGoogleStorageClient", + autospec=True, + ) as mock_class: + mock_instance = mock_class.return_value + yield MockedStorageSupport(mock_instance) diff --git a/tests/utils/data/test_caching_google_storage_client.py b/tests/utils/data/test_caching_google_storage_client.py new file mode 100644 index 00000000..f4850542 --- /dev/null +++ b/tests/utils/data/test_caching_google_storage_client.py @@ -0,0 +1,107 @@ +from pathlib import Path +import pytest +import tempfile +from unittest.mock import MagicMock, create_autospec, patch +from policyengine.utils.data import CachingGoogleStorageClient +from tests.utils.data.conftest import MockedStorageSupport + + +class TestCachingGoogleStorageClient: + @pytest.mark.asyncio + async def test_when_cache_miss__then_download_file( + self, mocked_storage: MockedStorageSupport + ): + with CachingGoogleStorageClient() as caching_client: + with tempfile.TemporaryDirectory() as tmpdir: + mocked_storage.given_stored_data("TEST DATA", "TEST_CRC") + await caching_client.download( + "test_bucket", "blob/path", Path(tmpdir, "output.txt") + ) + assert ( + open(Path(tmpdir, "output.txt")).readline() == "TEST DATA" + ) + + @pytest.mark.asyncio + async def test_when_cache_hit__then_use_cached_value( + self, mocked_storage: MockedStorageSupport + ): + with CachingGoogleStorageClient() as caching_client: + with tempfile.TemporaryDirectory() as tmpdir: + mocked_storage.given_stored_data( + "INITIAL TEST DATA", "TEST_CRC" + ) + await caching_client.download( + "test_bucket", "blob/path", Path(tmpdir, "output.txt") + ) + assert ( + open(Path(tmpdir, "output.txt")).readline() + == "INITIAL TEST DATA" + ) + + mocked_storage.given_stored_data( + "CRC DID NOT CHANGE SO YOU SHOULD NOT SEE THIS", "TEST_CRC" + ) + await caching_client.download( + "test_bucket", "blob/path", Path(tmpdir, "output.txt") + ) + assert ( + open(Path(tmpdir, "output.txt")).readline() + == "INITIAL TEST DATA" + ) + + @pytest.mark.asyncio + async def test_when_crc_updated__then_redownload( + self, mocked_storage: MockedStorageSupport + ): + with CachingGoogleStorageClient() as caching_client: + with tempfile.TemporaryDirectory() as tmpdir: + mocked_storage.given_stored_data( + "INITIAL TEST DATA", "TEST_CRC" + ) + await caching_client.download( + "test_bucket", "blob/path", Path(tmpdir, "output.txt") + ) + assert ( + open(Path(tmpdir, "output.txt")).readline() + == "INITIAL TEST DATA" + ) + + mocked_storage.given_stored_data( + "UPDATED_TEST_DATA", "UPDATED_TEST_CRC" + ) + await caching_client.download( + "test_bucket", "blob/path", Path(tmpdir, "output.txt") + ) + assert ( + open(Path(tmpdir, "output.txt")).readline() + == "UPDATED_TEST_DATA" + ) + + @pytest.mark.asyncio + async def test_when_crc_updated_on_download__then_store_downloaded_crc( + self, mocked_storage: MockedStorageSupport + ): + with CachingGoogleStorageClient() as caching_client: + with tempfile.TemporaryDirectory() as tmpdir: + mocked_storage.given_crc_changes_on_download( + "FINAL CONTENT", "INITIAL_CRC", "DOWNLOADED_CRC" + ) + await caching_client.download( + "test_bucket", "blob/path", Path(tmpdir, "output.txt") + ) + assert ( + open(Path(tmpdir, "output.txt")).readline() + == "FINAL CONTENT" + ) + + mocked_storage.given_stored_data( + "YOU SHOULD NOT SEE THIS BECAUSE THE CRC IS UNCHANGED FROM DOWNLOADED", + "DOWNLOADED_CRC", + ) + await caching_client.download( + "test_bucket", "blob/path", Path(tmpdir, "output.txt") + ) + assert ( + open(Path(tmpdir, "output.txt")).readline() + == "FINAL CONTENT" + ) diff --git a/tests/utils/data/test_simplified_google_storage_client.py b/tests/utils/data/test_simplified_google_storage_client.py new file mode 100644 index 00000000..aa318b0c --- /dev/null +++ b/tests/utils/data/test_simplified_google_storage_client.py @@ -0,0 +1,43 @@ +from unittest.mock import patch +import pytest +from policyengine.utils.data import SimplifiedGoogleStorageClient + + +class TestSimplifiedGoogleStorageClient: + @patch( + "policyengine.utils.data.simplified_google_storage_client.Client", + autospec=True, + ) + def test_crc32c__gets_crc(self, mock_client_class): + mock_instance = mock_client_class.return_value + bucket = mock_instance.bucket.return_value + blob = bucket.blob.return_value + + blob.crc32c = "TEST_CRC" + + client = SimplifiedGoogleStorageClient() + assert client.crc32c("bucket_name", "content.txt") == "TEST_CRC" + mock_instance.bucket.assert_called_with("bucket_name") + bucket.blob.assert_called_with("content.txt") + blob.reload.assert_called() + + @pytest.mark.asyncio + @patch( + "policyengine.utils.data.simplified_google_storage_client.Client", + autospec=True, + ) + async def test_download__downloads_content(self, mock_client_class): + mock_instance = mock_client_class.return_value + bucket = mock_instance.bucket.return_value + blob = bucket.blob.return_value + + blob.download_as_bytes.return_value = "hello, world".encode() + blob.crc32c = "TEST_CRC" + + client = SimplifiedGoogleStorageClient() + [data, crc] = await client.download("bucket", "blob.txt") + assert data == "hello, world".encode() + assert crc == "TEST_CRC" + + mock_instance.bucket.assert_called_with("bucket") + bucket.blob.assert_called_with("blob.txt")