diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..fe8ca84a 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + fixed: + - Model and data versions are always available. diff --git a/policyengine/simulation.py b/policyengine/simulation.py index 93fb5a94..670cff67 100644 --- a/policyengine/simulation.py +++ b/policyengine/simulation.py @@ -140,11 +140,13 @@ def _set_data(self): ) version = self.options.data_version - file_path = download( + file_path, version = download( filepath=filename, gcs_bucket=bucket, version=version, + return_version=True, ) + self.data_version = version filename = str(Path(file_path)) else: # If it's a local file, we can't infer the version. @@ -335,16 +337,16 @@ def check_model_version(self) -> None: """ Check the package versions of the simulation against the current package versions. """ + package = f"policyengine-{self.options.country}" + try: + installed_version = metadata.version(package) + self.model_version = installed_version + except metadata.PackageNotFoundError: + raise ValueError( + f"Package {package} not found. Try running `pip install {package}`." + ) if self.options.model_version is not None: target_version = self.options.model_version - package = f"policyengine-{self.options.country}" - try: - installed_version = metadata.version(package) - self.model_version = installed_version - except metadata.PackageNotFoundError: - raise ValueError( - f"Package {package} not found. Try running `pip install {package}`." - ) if installed_version != target_version: raise ValueError( f"Package {package} version {installed_version} does not match expected version {target_version}. Try running `pip install {package}=={target_version}`." diff --git a/policyengine/utils/data/caching_google_storage_client.py b/policyengine/utils/data/caching_google_storage_client.py index 10de203a..bfb1b1ff 100644 --- a/policyengine/utils/data/caching_google_storage_client.py +++ b/policyengine/utils/data/caching_google_storage_client.py @@ -32,6 +32,7 @@ def download( key: str, target: Path, version: Optional[str] = None, + return_version: bool = False, ): """ Atomically write the latest version of the cloud storage blob to the target path. @@ -49,6 +50,8 @@ def download( f"Copying downloaded data for {bucket}, {key} to {target}" ) atomic_write(target, data) + if return_version: + return version return raise Exception("Expected data for blob to be cached as bytes") diff --git a/policyengine/utils/data_download.py b/policyengine/utils/data_download.py index fd16adcf..2bf60387 100644 --- a/policyengine/utils/data_download.py +++ b/policyengine/utils/data_download.py @@ -11,12 +11,15 @@ def download( filepath: str, gcs_bucket: str, version: Optional[str] = None, -) -> str: + return_version: bool = False, +) -> Tuple[str, Optional[str]]: logging.info("Using Google Cloud Storage for download.") - download_file_from_gcs( + version = download_file_from_gcs( bucket_name=gcs_bucket, file_name=filepath, destination_path=filepath, version=version, ) + if return_version: + return filepath, version return filepath diff --git a/policyengine/utils/google_cloud_bucket.py b/policyengine/utils/google_cloud_bucket.py index 3516b231..19f2f5ac 100644 --- a/policyengine/utils/google_cloud_bucket.py +++ b/policyengine/utils/google_cloud_bucket.py @@ -25,7 +25,7 @@ def download_file_from_gcs( file_name: str, destination_path: str, version: str = None, -) -> None: +) -> str: """ Download a file from Google Cloud Storage to a local path. @@ -35,9 +35,13 @@ def download_file_from_gcs( destination_path (str): The local path where the file will be saved. Returns: - None + version (str): The version of the file that was downloaded, if available. """ return _get_client().download( - bucket_name, file_name, Path(destination_path), version=version + bucket_name, + file_name, + Path(destination_path), + version=version, + return_version=True, ) diff --git a/tests/utils/data/test_google_cloud_bucket.py b/tests/utils/data/test_google_cloud_bucket.py index e1e2abe9..71fd6ea4 100644 --- a/tests/utils/data/test_google_cloud_bucket.py +++ b/tests/utils/data/test_google_cloud_bucket.py @@ -19,13 +19,17 @@ def setUp(self): def test_download_uses_storage_client(self, client_class): client_instance = client_class.return_value download_file_from_gcs( - "TEST_BUCKET", "TEST/FILE/NAME.TXT", "TARGET/PATH", version=None + "TEST_BUCKET", + "TEST/FILE/NAME.TXT", + "TARGET/PATH", + version=None, ) client_instance.download.assert_called_with( "TEST_BUCKET", "TEST/FILE/NAME.TXT", Path("TARGET/PATH"), version=None, + return_version=True, ) @patch(