diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..82de8d3f 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + fixed: + - google storage caching is now fully sync, not async and reenabled. diff --git a/policyengine/utils/data/caching_google_storage_client.py b/policyengine/utils/data/caching_google_storage_client.py index e0f7db6e..f00d8c08 100644 --- a/policyengine/utils/data/caching_google_storage_client.py +++ b/policyengine/utils/data/caching_google_storage_client.py @@ -23,14 +23,14 @@ def _data_key(self, bucket: str, key: str) -> str: # To absolutely 100% avoid any possible issue with file corruption or thread contention # always replace the current target file with whatever we have cached as an atomic write. - async def download(self, bucket: str, key: str, target: Path): + def download(self, bucket: str, key: str, target: Path): """ Atomically write the latest version of the cloud storage blob to the target path. """ - await self.sync(bucket, key) + self.sync(bucket, key) data = self.cache.get(self._data_key(bucket, key)) if type(data) is bytes: - logger.debug( + logger.info( f"Copying downloaded data for {bucket}, {key} to {target}" ) atomic_write(target, data) @@ -39,7 +39,7 @@ async def download(self, bucket: str, key: str, target: Path): # If the crc has changed from what we downloaded last time download it again. # then update the CRC to whatever we actually downloaded. - async def sync(self, bucket: str, key: str) -> None: + def sync(self, bucket: str, key: str) -> None: """ Cache the resource if the CRC has changed. """ @@ -59,7 +59,7 @@ async def sync(self, bucket: str, key: str) -> None: ) return - [content, downloaded_crc] = await self.client.download(bucket, key) + [content, downloaded_crc] = self.client.download(bucket, key) logger.info( f"Downloaded new version of {bucket}, {key} with crc {downloaded_crc}" ) diff --git a/policyengine/utils/data/simplified_google_storage_client.py b/policyengine/utils/data/simplified_google_storage_client.py index 55b715ab..b7c2e895 100644 --- a/policyengine/utils/data/simplified_google_storage_client.py +++ b/policyengine/utils/data/simplified_google_storage_client.py @@ -27,19 +27,14 @@ def crc32c(self, bucket: str, key: str) -> str | None: logger.debug(f"Crc is {blob.crc32c}") return blob.crc32c - async def download(self, bucket: str, key: str) -> tuple[bytes, str]: + def download(self, bucket: str, key: str) -> tuple[bytes, str]: """ get the blob content and associated CRC from google storage. """ logger.debug(f"Downloading {bucket}, {key}") blob = self.client.bucket(bucket).blob(key) - # async implmentation as per https://github.com/googleapis/python-storage/blob/main/samples/snippets/storage_async_download.py - def download(): - return blob.download_as_bytes() - - loop = asyncio.get_running_loop() - result = await loop.run_in_executor(None, download) + result = blob.download_as_bytes() # According to documentation blob.crc32c is updated as a side effect of # downloading the content. As a result this should now be the crc of the downloaded # content (i.e. there is not a race condition where it's getting the CRC from the cloud) diff --git a/policyengine/utils/data_download.py b/policyengine/utils/data_download.py index 59e4b22b..c7722173 100644 --- a/policyengine/utils/data_download.py +++ b/policyengine/utils/data_download.py @@ -27,10 +27,7 @@ def download( ) logging.info = print - if Path(filepath).exists(): - logging.info(f"File {filepath} already exists. Skipping download.") - return filepath - + # NOTE: tests will break on build if you don't default to huggingface. if data_file.huggingface_repo is not None: logging.info("Using Hugging Face for download.") try: @@ -43,6 +40,10 @@ def download( except: logging.info("Failed to download from Hugging Face.") + if Path(filepath).exists(): + logging.info(f"File {filepath} already exists. Skipping download.") + return filepath + if data_file.gcs_bucket is not None: logging.info("Using Google Cloud Storage for download.") download_file_from_gcs( diff --git a/policyengine/utils/google_cloud_bucket.py b/policyengine/utils/google_cloud_bucket.py index 14dc9031..f080c21b 100644 --- a/policyengine/utils/google_cloud_bucket.py +++ b/policyengine/utils/google_cloud_bucket.py @@ -1,3 +1,23 @@ +from .data.caching_google_storage_client import CachingGoogleStorageClient +import asyncio +from pathlib import Path + +_caching_client: CachingGoogleStorageClient | None = None + + +def _get_client(): + global _caching_client + if _caching_client is not None: + return _caching_client + _caching_client = CachingGoogleStorageClient() + return _caching_client + + +def _clear_client(): + global _caching_client + _caching_client = None + + def download_file_from_gcs( bucket_name: str, file_name: str, destination_path: str ) -> None: @@ -12,18 +32,4 @@ def download_file_from_gcs( Returns: None """ - from google.cloud import storage - - # Initialize a client - client = storage.Client() - - # Get the bucket - bucket = client.bucket(bucket_name) - - # Create a blob object from the file name - blob = bucket.blob(file_name) - - # Download the file to a local path - blob.download_to_filename(destination_path) - - return destination_path + _get_client().download(bucket_name, file_name, Path(destination_path)) diff --git a/tests/utils/data/test_caching_google_storage_client.py b/tests/utils/data/test_caching_google_storage_client.py index f4850542..99b8a687 100644 --- a/tests/utils/data/test_caching_google_storage_client.py +++ b/tests/utils/data/test_caching_google_storage_client.py @@ -7,22 +7,20 @@ class TestCachingGoogleStorageClient: - @pytest.mark.asyncio - async def test_when_cache_miss__then_download_file( + def test_when_cache_miss__then_download_file( self, mocked_storage: MockedStorageSupport ): with CachingGoogleStorageClient() as caching_client: with tempfile.TemporaryDirectory() as tmpdir: mocked_storage.given_stored_data("TEST DATA", "TEST_CRC") - await caching_client.download( + caching_client.download( "test_bucket", "blob/path", Path(tmpdir, "output.txt") ) assert ( open(Path(tmpdir, "output.txt")).readline() == "TEST DATA" ) - @pytest.mark.asyncio - async def test_when_cache_hit__then_use_cached_value( + def test_when_cache_hit__then_use_cached_value( self, mocked_storage: MockedStorageSupport ): with CachingGoogleStorageClient() as caching_client: @@ -30,7 +28,7 @@ async def test_when_cache_hit__then_use_cached_value( mocked_storage.given_stored_data( "INITIAL TEST DATA", "TEST_CRC" ) - await caching_client.download( + caching_client.download( "test_bucket", "blob/path", Path(tmpdir, "output.txt") ) assert ( @@ -41,7 +39,7 @@ async def test_when_cache_hit__then_use_cached_value( mocked_storage.given_stored_data( "CRC DID NOT CHANGE SO YOU SHOULD NOT SEE THIS", "TEST_CRC" ) - await caching_client.download( + caching_client.download( "test_bucket", "blob/path", Path(tmpdir, "output.txt") ) assert ( @@ -49,8 +47,7 @@ async def test_when_cache_hit__then_use_cached_value( == "INITIAL TEST DATA" ) - @pytest.mark.asyncio - async def test_when_crc_updated__then_redownload( + def test_when_crc_updated__then_redownload( self, mocked_storage: MockedStorageSupport ): with CachingGoogleStorageClient() as caching_client: @@ -58,7 +55,7 @@ async def test_when_crc_updated__then_redownload( mocked_storage.given_stored_data( "INITIAL TEST DATA", "TEST_CRC" ) - await caching_client.download( + caching_client.download( "test_bucket", "blob/path", Path(tmpdir, "output.txt") ) assert ( @@ -69,7 +66,7 @@ async def test_when_crc_updated__then_redownload( mocked_storage.given_stored_data( "UPDATED_TEST_DATA", "UPDATED_TEST_CRC" ) - await caching_client.download( + caching_client.download( "test_bucket", "blob/path", Path(tmpdir, "output.txt") ) assert ( @@ -77,8 +74,7 @@ async def test_when_crc_updated__then_redownload( == "UPDATED_TEST_DATA" ) - @pytest.mark.asyncio - async def test_when_crc_updated_on_download__then_store_downloaded_crc( + def test_when_crc_updated_on_download__then_store_downloaded_crc( self, mocked_storage: MockedStorageSupport ): with CachingGoogleStorageClient() as caching_client: @@ -86,7 +82,7 @@ async def test_when_crc_updated_on_download__then_store_downloaded_crc( mocked_storage.given_crc_changes_on_download( "FINAL CONTENT", "INITIAL_CRC", "DOWNLOADED_CRC" ) - await caching_client.download( + caching_client.download( "test_bucket", "blob/path", Path(tmpdir, "output.txt") ) assert ( @@ -98,7 +94,7 @@ async def test_when_crc_updated_on_download__then_store_downloaded_crc( "YOU SHOULD NOT SEE THIS BECAUSE THE CRC IS UNCHANGED FROM DOWNLOADED", "DOWNLOADED_CRC", ) - await caching_client.download( + caching_client.download( "test_bucket", "blob/path", Path(tmpdir, "output.txt") ) assert ( diff --git a/tests/utils/data/test_google_cloud_bucket.py b/tests/utils/data/test_google_cloud_bucket.py new file mode 100644 index 00000000..c141c0f9 --- /dev/null +++ b/tests/utils/data/test_google_cloud_bucket.py @@ -0,0 +1,39 @@ +from unittest import TestCase +from unittest.mock import patch +import pytest +from pathlib import Path +from policyengine.utils.google_cloud_bucket import ( + download_file_from_gcs, + _clear_client, +) + + +class TestGoogleCloudBucket(TestCase): + def setUp(self): + _clear_client() + + @patch( + "policyengine.utils.google_cloud_bucket.CachingGoogleStorageClient", + autospec=True, + ) + def test_download_uses_storage_client(self, client_class): + client_instance = client_class.return_value + download_file_from_gcs( + "TEST_BUCKET", "TEST/FILE/NAME.TXT", "TARGET/PATH" + ) + client_instance.download.assert_called_with( + "TEST_BUCKET", "TEST/FILE/NAME.TXT", Path("TARGET/PATH") + ) + + @patch( + "policyengine.utils.google_cloud_bucket.CachingGoogleStorageClient", + autospec=True, + ) + def test_download_only_creates_client_once(self, client_class): + download_file_from_gcs( + "TEST_BUCKET", "TEST/FILE/NAME.TXT", "TARGET/PATH" + ) + download_file_from_gcs( + "TEST_BUCKET", "TEST/FILE/NAME.TXT", "ANOTHER/PATH" + ) + client_class.assert_called_once() diff --git a/tests/utils/data/test_simplified_google_storage_client.py b/tests/utils/data/test_simplified_google_storage_client.py index aa318b0c..fc692f02 100644 --- a/tests/utils/data/test_simplified_google_storage_client.py +++ b/tests/utils/data/test_simplified_google_storage_client.py @@ -21,12 +21,11 @@ def test_crc32c__gets_crc(self, mock_client_class): bucket.blob.assert_called_with("content.txt") blob.reload.assert_called() - @pytest.mark.asyncio @patch( "policyengine.utils.data.simplified_google_storage_client.Client", autospec=True, ) - async def test_download__downloads_content(self, mock_client_class): + def test_download__downloads_content(self, mock_client_class): mock_instance = mock_client_class.return_value bucket = mock_instance.bucket.return_value blob = bucket.blob.return_value @@ -35,7 +34,7 @@ async def test_download__downloads_content(self, mock_client_class): blob.crc32c = "TEST_CRC" client = SimplifiedGoogleStorageClient() - [data, crc] = await client.download("bucket", "blob.txt") + [data, crc] = client.download("bucket", "blob.txt") assert data == "hello, world".encode() assert crc == "TEST_CRC"