Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog_entry.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
- bump: patch
changes:
fixed:
- Model and data versions are always available.
20 changes: 11 additions & 9 deletions policyengine/simulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,13 @@ def _set_data(self):
)
version = self.options.data_version

file_path = download(
file_path, version = download(
filepath=filename,
gcs_bucket=bucket,
version=version,
return_version=True,
)
self.data_version = version
filename = str(Path(file_path))
else:
# If it's a local file, we can't infer the version.
Expand Down Expand Up @@ -335,16 +337,16 @@ def check_model_version(self) -> None:
"""
Check the package versions of the simulation against the current package versions.
"""
package = f"policyengine-{self.options.country}"
try:
installed_version = metadata.version(package)
self.model_version = installed_version
except metadata.PackageNotFoundError:
raise ValueError(
f"Package {package} not found. Try running `pip install {package}`."
)
if self.options.model_version is not None:
target_version = self.options.model_version
package = f"policyengine-{self.options.country}"
try:
installed_version = metadata.version(package)
self.model_version = installed_version
except metadata.PackageNotFoundError:
raise ValueError(
f"Package {package} not found. Try running `pip install {package}`."
)
if installed_version != target_version:
raise ValueError(
f"Package {package} version {installed_version} does not match expected version {target_version}. Try running `pip install {package}=={target_version}`."
Expand Down
3 changes: 3 additions & 0 deletions policyengine/utils/data/caching_google_storage_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def download(
key: str,
target: Path,
version: Optional[str] = None,
return_version: bool = False,
):
"""
Atomically write the latest version of the cloud storage blob to the target path.
Expand All @@ -49,6 +50,8 @@ def download(
f"Copying downloaded data for {bucket}, {key} to {target}"
)
atomic_write(target, data)
if return_version:
return version
return
raise Exception("Expected data for blob to be cached as bytes")

Expand Down
7 changes: 5 additions & 2 deletions policyengine/utils/data_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,15 @@ def download(
filepath: str,
gcs_bucket: str,
version: Optional[str] = None,
) -> str:
return_version: bool = False,
) -> Tuple[str, Optional[str]]:
logging.info("Using Google Cloud Storage for download.")
download_file_from_gcs(
version = download_file_from_gcs(
bucket_name=gcs_bucket,
file_name=filepath,
destination_path=filepath,
version=version,
)
if return_version:
return filepath, version
return filepath
10 changes: 7 additions & 3 deletions policyengine/utils/google_cloud_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def download_file_from_gcs(
file_name: str,
destination_path: str,
version: str = None,
) -> None:
) -> str:
"""
Download a file from Google Cloud Storage to a local path.

Expand All @@ -35,9 +35,13 @@ def download_file_from_gcs(
destination_path (str): The local path where the file will be saved.

Returns:
None
version (str): The version of the file that was downloaded, if available.
"""

return _get_client().download(
bucket_name, file_name, Path(destination_path), version=version
bucket_name,
file_name,
Path(destination_path),
version=version,
return_version=True,
)
6 changes: 5 additions & 1 deletion tests/utils/data/test_google_cloud_bucket.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,17 @@ def setUp(self):
def test_download_uses_storage_client(self, client_class):
client_instance = client_class.return_value
download_file_from_gcs(
"TEST_BUCKET", "TEST/FILE/NAME.TXT", "TARGET/PATH", version=None
"TEST_BUCKET",
"TEST/FILE/NAME.TXT",
"TARGET/PATH",
version=None,
)
client_instance.download.assert_called_with(
"TEST_BUCKET",
"TEST/FILE/NAME.TXT",
Path("TARGET/PATH"),
version=None,
return_version=True,
)

@patch(
Expand Down
Loading