diff --git a/changelog_entry.yaml b/changelog_entry.yaml index e69de29b..e4ec408d 100644 --- a/changelog_entry.yaml +++ b/changelog_entry.yaml @@ -0,0 +1,4 @@ +- bump: patch + changes: + fixed: + - Gracefully handle data download 429s. diff --git a/policyengine/utils/huggingface.py b/policyengine/utils/huggingface.py index 3c8f2e68..09684074 100644 --- a/policyengine/utils/huggingface.py +++ b/policyengine/utils/huggingface.py @@ -1,6 +1,7 @@ from huggingface_hub import hf_hub_download import os from getpass import getpass +import time def download( @@ -16,12 +17,33 @@ def download( ) # Optionally store in env for subsequent calls in same session os.environ["HUGGING_FACE_TOKEN"] = token - - return hf_hub_download( - repo_id=repo, - repo_type="model", - filename=repo_filename, - local_dir=local_folder, - revision=version, - token=token, - ) + try: + result = hf_hub_download( + repo_id=repo, + repo_type="model", + filename=repo_filename, + local_dir=local_folder, + revision=version, + token=token, + ) + except: + # In the case of a 429 Too Many Requests error, retry up to 5 times, waiting 30 seconds + # between attempts + for i in range(5): + try: + result = hf_hub_download( + repo_id=repo, + repo_type="model", + filename=repo_filename, + local_dir=local_folder, + revision=version, + token=token, + ) + break + except Exception as e: + if i == 4: + raise e + print(f"Error downloading {repo_filename} from {repo}: {e}") + print("Retrying in 30 seconds...") + time.sleep(30) + return result