diff --git a/README.md b/README.md
index c5ac4d3..52021e8 100644
--- a/README.md
+++ b/README.md
@@ -16,13 +16,21 @@ We wrote a tutorial on how to use 🤗 Hub and Stable-Baselines3 [here](https://
 If you use **Colab or a Virtual/Screenless Machine**, you can check Case 3 and Case 4.
 
 ### Case 1: I want to download a model from the Hub
+
+You will need to set the `TRUST_REMOTE_CODE` environment variable to `True` to allow the use of `pickle.load()`:
+
 ```python
+import os
 import gymnasium as gym
 
 from huggingface_sb3 import load_from_hub
 from stable_baselines3 import PPO
 from stable_baselines3.common.evaluation import evaluate_policy
 
+# Allow the use of `pickle.load()` when downloading a model from the hub
+# Please make sure that the organization from which you download can be trusted
+os.environ["TRUST_REMOTE_CODE"] = "True"
+
 # Retrieve the model from the hub
 ## repo_id = id of the model repository from the Hugging Face Hub (repo_id = {organization}/{repo_name})
 ## filename = name of the model zip file from the repository
diff --git a/huggingface_sb3/load_from_hub.py b/huggingface_sb3/load_from_hub.py
index cc426b1..483aa37 100644
--- a/huggingface_sb3/load_from_hub.py
+++ b/huggingface_sb3/load_from_hub.py
@@ -1,3 +1,22 @@
+import os
+
+
+# Vendored from distutils.util
+def strtobool(val: str) -> bool:
+    """Convert a string representation of truth to ``True`` or ``False``.
+
+    True values are 'y', 'yes', 't', 'true', 'on', and '1';
+    False values are 'n', 'no', 'f', 'false', 'off', and '0'.
+    Raises ValueError if 'val' is anything else.
+    """
+    val = val.lower()
+    if val in {"y", "yes", "t", "true", "on", "1"}:
+        return True
+    if val in {"n", "no", "f", "false", "off", "0"}:
+        return False
+    raise ValueError(f"Invalid truth value {val!r}")
+
+
 def load_from_hub(repo_id: str, filename: str) -> str:
     """
     Download a model from Hugging Face Hub.
@@ -12,6 +31,17 @@ def load_from_hub(repo_id: str, filename: str) -> str:
             "See https://pypi.org/project/huggingface-hub/ for installation."
         )
 
+    # Copied from https://github.com/huggingface/transformers/pull/27776
+    if not strtobool(os.environ.get("TRUST_REMOTE_CODE", "False")):
+        raise ValueError(
+            "You are about to download a model from the HF hub that will be loaded using `pickle.load`. "
+            "`pickle.load` is insecure and will execute arbitrary code that is "
+            "potentially malicious. It's recommended to never unpickle data that could have come from an "
+            "untrusted source, or that could have been tampered with. If you trust the pickle "
+            "data and decided to use it, you can set the environment variable "
+            "`TRUST_REMOTE_CODE` to `True` to allow it."
+        )
+
     # Get the model from the Hub, download and cache the model on your local disk
     downloaded_model_file = hf_hub_download(
         repo_id=repo_id,
diff --git a/tests/test_load_from_hub.py b/tests/test_load_from_hub.py
index 1b70af0..0e15e83 100644
--- a/tests/test_load_from_hub.py
+++ b/tests/test_load_from_hub.py
@@ -1,3 +1,4 @@
+import os
 import sys
 
 import gymnasium as gym
@@ -6,6 +7,9 @@
 
 from huggingface_sb3 import EnvironmentName, ModelName, ModelRepoId, load_from_hub
 
+# Test models from sb3 organization can be trusted
+os.environ["TRUST_REMOTE_CODE"] = "True"
+
 
 def test_load_from_hub_with_naming_scheme_utils():
     # Retrieve the model from the hub