Skip to content

Commit 59ce177

Browse files
author
Michael Smit
committed
Add CachingGoogleStorageClient
Related to PolicyEngine/issues#350. This commit adds, but does not yet use, CachingGoogleStorageClient, a class that monitors a remote file in Google Storage for changes and caches the result locally to disk.
1 parent 03640e8 commit 59ce177

File tree

8 files changed

+323
-2
lines changed

8 files changed

+323
-2
lines changed

changelog_entry.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
- bump: patch
2+
changes:
3+
added:
4+
- new class CachingGoogleStorageClient for locally caching gs files to disk.
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .caching_google_storage_client import CachingGoogleStorageClient
2+
from .simplified_google_storage_client import SimplifiedGoogleStorageClient
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
from contextlib import AbstractContextManager
2+
import diskcache
3+
from pathlib import Path
4+
from policyengine_core.data.dataset import atomic_write
5+
import logging
6+
from .simplified_google_storage_client import SimplifiedGoogleStorageClient
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
class CachingGoogleStorageClient(AbstractContextManager):
    """
    Client that downloads resources from a google storage bucket only when the
    CRC of the blob changes, keeping the latest content in a local disk cache.

    When used as a context manager the cache is cleared on exit.
    """

    def __init__(self):
        self.client = SimplifiedGoogleStorageClient()
        self.cache = diskcache.Cache()

    def _data_key(self, bucket: str, key: str) -> str:
        """Return the cache key under which the blob's content is stored."""
        return f"{bucket}.{key}.data"

    def _crc_key(self, bucket: str, key: str) -> str:
        """Return the cache key under which the blob's CRC is stored."""
        return f"{bucket}.{key}.crc"

    # To absolutely 100% avoid any possible issue with file corruption or thread contention
    # always replace the current target file with whatever we have cached as an atomic write.
    async def download(self, bucket: str, key: str, target: Path) -> None:
        """
        Atomically write the latest version of the cloud storage blob to the target path.

        Args:
            bucket: name of the google storage bucket.
            key: path of the blob inside the bucket.
            target: local file path that receives the blob content.

        Raises:
            Exception: if the blob cannot be found, or if the cached entry is
                not bytes after syncing.
        """
        await self.sync(bucket, key)
        data = self.cache.get(self._data_key(bucket, key))
        if not isinstance(data, bytes):
            raise Exception("Expected data for blob to be cached as bytes")

        logger.debug(
            f"Copying downloaded data for {bucket}, {key} to {target}"
        )
        atomic_write(target, data)

    # If the crc has changed from what we downloaded last time download it again.
    # then update the CRC to whatever we actually downloaded.
    async def sync(self, bucket: str, key: str) -> None:
        """
        Cache the resource if the CRC has changed.

        Args:
            bucket: name of the google storage bucket.
            key: path of the blob inside the bucket.

        Raises:
            Exception: if the storage client reports no CRC for the blob.
        """
        logger.info(f"Syncing {bucket}, {key} to cache")
        datakey = self._data_key(bucket, key)
        crckey = self._crc_key(bucket, key)

        # NOTE(review): this is a synchronous call made from a coroutine, so it
        # blocks the event loop while the CRC is fetched.
        crc = self.client.crc32c(bucket, key)
        if crc is None:
            raise Exception(f"Unable to find {key} in bucket {bucket}")

        prev_crc = self.cache.get(crckey, default=None)
        logger.debug(f"Previous crc for {bucket}, {key} was {prev_crc}")
        # Only treat it as a hit when the data entry is still present too;
        # otherwise a stale CRC entry would leave download() with nothing to write.
        if prev_crc == crc and datakey in self.cache:
            logger.info(
                f"Cache exists and crc is unchanged for {bucket}, {key}."
            )
            return

        content, downloaded_crc = await self.client.download(bucket, key)
        logger.info(
            f"Downloaded new version of {bucket}, {key} with crc {downloaded_crc}"
        )

        # atomic transaction to update both the data and the metadata
        # at the same time. `with self.cache:` only opens/closes the cache
        # connection; transact() is what actually groups the writes.
        with self.cache.transact():
            logger.debug("Updating cache...")
            self.cache.set(datakey, content)
            # Whatever the CRC was before we downloaded, we set the cache CRC
            # to the CRC reported by the download itself to avoid race conditions.
            self.cache.set(crckey, downloaded_crc)

    def clear(self):
        """Remove every entry from the local cache."""
        self.cache.clear()

    def __enter__(self) -> "CachingGoogleStorageClient":
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Clear the cache on exit; returns None so exceptions propagate."""
        self.clear()
        return None
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import asyncio
2+
from policyengine_core.data.dataset import atomic_write
3+
import logging
4+
from google.cloud.storage import Client
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
class SimplifiedGoogleStorageClient:
    """
    Class separating out just the interactions with google storage required to
    cache downloaded files.

    Simplifies the dependent code and unit testing.
    """

    def __init__(self):
        self.client = Client()

    def crc32c(self, bucket: str, key: str) -> str | None:
        """
        Get the current CRC of the specified blob. None if it doesn't exist.

        Args:
            bucket: name of the google storage bucket.
            key: path of the blob inside the bucket.
        """
        # Imported here so the block stays self-contained; the storage
        # library raises this when a reload hits a missing resource.
        from google.cloud.exceptions import NotFound

        logger.debug(f"Getting crc for {bucket}, {key}")
        blob = self.client.bucket(bucket).blob(key)
        try:
            blob.reload()
        except NotFound:
            # Honour the documented contract instead of leaking the
            # library exception to callers that check for None.
            logger.debug(f"No blob found for {bucket}, {key}")
            return None
        logger.debug(f"Crc is {blob.crc32c}")
        return blob.crc32c

    async def download(self, bucket: str, key: str) -> tuple[bytes, str]:
        """
        Get the blob content and associated CRC from google storage.

        Args:
            bucket: name of the google storage bucket.
            key: path of the blob inside the bucket.

        Returns:
            A ``(content, crc32c)`` tuple for the downloaded blob.
        """
        logger.debug(f"Downloading {bucket}, {key}")
        blob = self.client.bucket(bucket).blob(key)

        # async implementation as per https://github.com/googleapis/python-storage/blob/main/samples/snippets/storage_async_download.py
        # The blocking download runs on the loop's default executor so the
        # event loop is not blocked while the bytes are fetched.
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(None, blob.download_as_bytes)
        # According to documentation blob.crc32c is updated as a side effect of
        # downloading the content. As a result this should now be the crc of the downloaded
        # content (i.e. there is not a race condition where it's getting the CRC from the cloud)
        return (result, blob.crc32c)

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@ dependencies = [
1616
"policyengine_core>=3.10",
1717
"policyengine-uk",
1818
"policyengine-us>=1.213.1",
19+
"diskcache (>=5.6.3,<6.0.0)",
20+
"google-cloud-storage (>=3.1.0,<4.0.0)",
1921
"microdf_python",
2022
"getpass4",
21-
"pydantic",
22-
"google-cloud-storage",
23+
"pydantic"
2324
]
2425

2526
[project.optional-dependencies]
@@ -32,6 +33,7 @@ dev = [
3233
"yaml-changelog>=0.1.7",
3334
"itables",
3435
"build",
36+
"pytest-asyncio>=0.26.0",
3537
]
3638

3739
[tool.setuptools]

tests/utils/data/__init__.py

Whitespace-only changes.
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
from pathlib import Path
2+
import pytest
3+
import tempfile
4+
from unittest.mock import MagicMock, create_autospec, patch
5+
from policyengine.utils.data import CachingGoogleStorageClient
6+
7+
8+
class MockedStorageSupport:
    """Test helper that scripts the responses of a mocked storage client."""

    def __init__(self, mock_simple_storage_client):
        self.mock_simple_storage_client = mock_simple_storage_client

    def given_stored_data(self, data: str, crc: str):
        """Make the mock report ``crc`` and return ``data`` (encoded) with the same CRC."""
        storage = self.mock_simple_storage_client
        payload = data.encode()
        storage.crc32c.return_value = crc
        storage.download.return_value = (payload, crc)

    def given_crc_changes_on_download(
        self, data: str, initial_crc: str, download_crc: str
    ):
        """Make the mock report ``initial_crc`` but return ``download_crc`` with the download."""
        storage = self.mock_simple_storage_client
        payload = data.encode()
        storage.crc32c.return_value = initial_crc
        storage.download.return_value = (payload, download_crc)
27+
28+
29+
@pytest.fixture()
def mocked_storage():
    """Patch the storage client used by the caching client and yield a helper around the mock instance."""
    patcher = patch(
        "policyengine.utils.data.caching_google_storage_client.SimplifiedGoogleStorageClient",
        autospec=True,
    )
    with patcher as storage_cls:
        # The caching client instantiates the patched class, so it receives
        # this return_value as its storage client.
        yield MockedStorageSupport(storage_cls.return_value)
37+
38+
39+
class TestCachingGoogleStorageClient:
    """Behavioural tests for CachingGoogleStorageClient against a mocked storage client."""

    # Output files are read with Path.read_text() so no file handle is left
    # open when the temporary directory is cleaned up.

    @pytest.mark.asyncio
    async def test_when_cache_miss__then_download_file(
        self, mocked_storage: MockedStorageSupport
    ):
        with CachingGoogleStorageClient() as caching_client:
            with tempfile.TemporaryDirectory() as tmpdir:
                target = Path(tmpdir, "output.txt")
                mocked_storage.given_stored_data("TEST DATA", "TEST_CRC")
                await caching_client.download(
                    "test_bucket", "blob/path", target
                )
                assert target.read_text() == "TEST DATA"

    @pytest.mark.asyncio
    async def test_when_cache_hit__then_use_cached_value(
        self, mocked_storage: MockedStorageSupport
    ):
        with CachingGoogleStorageClient() as caching_client:
            with tempfile.TemporaryDirectory() as tmpdir:
                target = Path(tmpdir, "output.txt")
                mocked_storage.given_stored_data(
                    "INITIAL TEST DATA", "TEST_CRC"
                )
                await caching_client.download(
                    "test_bucket", "blob/path", target
                )
                assert target.read_text() == "INITIAL TEST DATA"

                # Same CRC with different content: the cached copy must win.
                mocked_storage.given_stored_data(
                    "CRC DID NOT CHANGE SO YOU SHOULD NOT SEE THIS", "TEST_CRC"
                )
                await caching_client.download(
                    "test_bucket", "blob/path", target
                )
                assert target.read_text() == "INITIAL TEST DATA"

    @pytest.mark.asyncio
    async def test_when_crc_updated__then_redownload(
        self, mocked_storage: MockedStorageSupport
    ):
        with CachingGoogleStorageClient() as caching_client:
            with tempfile.TemporaryDirectory() as tmpdir:
                target = Path(tmpdir, "output.txt")
                mocked_storage.given_stored_data(
                    "INITIAL TEST DATA", "TEST_CRC"
                )
                await caching_client.download(
                    "test_bucket", "blob/path", target
                )
                assert target.read_text() == "INITIAL TEST DATA"

                # A new CRC must trigger a fresh download.
                mocked_storage.given_stored_data(
                    "UPDATED_TEST_DATA", "UPDATED_TEST_CRC"
                )
                await caching_client.download(
                    "test_bucket", "blob/path", target
                )
                assert target.read_text() == "UPDATED_TEST_DATA"

    @pytest.mark.asyncio
    async def test_when_crc_updated_on_download__then_store_downloaded_crc(
        self, mocked_storage: MockedStorageSupport
    ):
        with CachingGoogleStorageClient() as caching_client:
            with tempfile.TemporaryDirectory() as tmpdir:
                target = Path(tmpdir, "output.txt")
                mocked_storage.given_crc_changes_on_download(
                    "FINAL CONTENT", "INITIAL_CRC", "DOWNLOADED_CRC"
                )
                await caching_client.download(
                    "test_bucket", "blob/path", target
                )
                assert target.read_text() == "FINAL CONTENT"

                # The cache must have recorded the CRC returned by the
                # download, so reporting that CRC now counts as unchanged.
                mocked_storage.given_stored_data(
                    "YOU SHOULD NOT SEE THIS BECAUSE THE CRC IS UNCHANGED FROM DOWNLOADED",
                    "DOWNLOADED_CRC",
                )
                await caching_client.download(
                    "test_bucket", "blob/path", target
                )
                assert target.read_text() == "FINAL CONTENT"
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from unittest.mock import patch
2+
import pytest
3+
import tempfile
4+
from pathlib import Path
5+
from policyengine.utils.data import SimplifiedGoogleStorageClient
6+
7+
8+
class TestSimplifiedGoogleStorageClient:
    """Tests for SimplifiedGoogleStorageClient with the google storage Client patched out."""

    @patch(
        "policyengine.utils.data.simplified_google_storage_client.Client",
        autospec=True,
    )
    def test_crc32c__gets_crc(self, mock_client_class):
        # Wire up the chain client -> bucket -> blob on the patched Client.
        gcs = mock_client_class.return_value
        bucket_mock = gcs.bucket.return_value
        blob_mock = bucket_mock.blob.return_value
        blob_mock.crc32c = "TEST_CRC"

        storage = SimplifiedGoogleStorageClient()
        observed = storage.crc32c("bucket_name", "content.txt")

        assert observed == "TEST_CRC"
        gcs.bucket.assert_called_with("bucket_name")
        bucket_mock.blob.assert_called_with("content.txt")
        blob_mock.reload.assert_called()

    @pytest.mark.asyncio
    @patch(
        "policyengine.utils.data.simplified_google_storage_client.Client",
        autospec=True,
    )
    async def test_download__downloads_content(self, mock_client_class):
        gcs = mock_client_class.return_value
        bucket_mock = gcs.bucket.return_value
        blob_mock = bucket_mock.blob.return_value

        expected_bytes = "hello, world".encode()
        blob_mock.download_as_bytes.return_value = expected_bytes
        blob_mock.crc32c = "TEST_CRC"

        storage = SimplifiedGoogleStorageClient()
        data, crc = await storage.download("bucket", "blob.txt")

        assert data == "hello, world".encode()
        assert crc == "TEST_CRC"

        gcs.bucket.assert_called_with("bucket")
        bucket_mock.blob.assert_called_with("blob.txt")

0 commit comments

Comments
 (0)