diff --git a/_requirements/test.txt b/_requirements/test.txt
index 0a77149..d69794c 100644
--- a/_requirements/test.txt
+++ b/_requirements/test.txt
@@ -5,3 +5,4 @@
 pytest-mock
 pytorch-lightning >=2.0
 scikit-learn >=1.0
+huggingface-hub >=0.29.0
diff --git a/src/litmodels/integrations/duplicate.py b/src/litmodels/integrations/duplicate.py
new file mode 100644
index 0000000..79f6dee
--- /dev/null
+++ b/src/litmodels/integrations/duplicate.py
@@ -0,0 +1,56 @@
+import os
+import shutil
+import tempfile
+from pathlib import Path
+from typing import Optional
+
+from litmodels import upload_model
+
+
+def duplicate_hf_model(
+    hf_model: str, lit_model: Optional[str] = None, local_workdir: Optional[str] = None, verbose: int = 1
+) -> str:
+    """Downloads the model from Hugging Face and uploads it to Lightning Cloud.
+
+    Args:
+        hf_model: The name of the Hugging Face model to duplicate.
+        lit_model: The name of the Lightning Cloud model to create.
+        local_workdir:
+            The local working directory to use for the duplication process. If not set, a temp folder will be created.
+        verbose: Show a progress bar for the upload.
+
+    Returns:
+        The name of the duplicated model in Lightning Cloud.
+    """
+    try:
+        from huggingface_hub import snapshot_download
+    except ModuleNotFoundError:
+        raise ModuleNotFoundError(
+            "Hugging Face Hub is not installed. Please install it with `pip install huggingface_hub`."
+        )
+
+    if not local_workdir:
+        local_workdir = tempfile.mkdtemp()
+    local_workdir = Path(local_workdir)
+    model_name = hf_model.replace("/", "_")
+
+    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
+    # Download the model from Hugging Face
+    snapshot_download(
+        repo_id=hf_model,
+        revision="main",  # Branch/tag/commit
+        repo_type="model",  # Options: "dataset", "model", "space"
+        local_dir=local_workdir / model_name,  # Specify to save in custom location, default is cache
+        local_dir_use_symlinks=True,  # Use symlinks to save disk space
+        ignore_patterns=[".cache*"],  # Exclude certain files if needed
+        max_workers=os.cpu_count(),  # Number of parallel downloads
+    )
+    # prune cache in the downloaded model
+    for path in local_workdir.rglob(".cache*"):
+        shutil.rmtree(path)
+
+    # Upload the model to Lightning Cloud
+    if not lit_model:
+        lit_model = model_name
+    model = upload_model(name=lit_model, model=local_workdir / model_name, verbose=verbose)
+    return model.name
diff --git a/tests/integrations/test_duplicate.py b/tests/integrations/test_duplicate.py
new file mode 100644
index 0000000..e37028c
--- /dev/null
+++ b/tests/integrations/test_duplicate.py
@@ -0,0 +1,29 @@
+import os
+
+import pytest
+from lightning_sdk.lightning_cloud.rest_client import GridRestClient
+from lightning_sdk.utils.resolve import _resolve_teamspace
+from litmodels.integrations.duplicate import duplicate_hf_model
+
+LIT_ORG = "lightning-ai"
+LIT_TEAMSPACE = "LitModels"
+
+
+@pytest.mark.cloud()
+def test_duplicate_hf_model(tmp_path):
+    """Verify that the HF model can be duplicated to the teamspace."""
+
+    # model name with random hash
+    model_name = f"litmodels_hf_model+{os.urandom(8).hex()}"
+    teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None)
+    org_team = f"{teamspace.owner.name}/{teamspace.name}"
+
+    duplicate_hf_model(hf_model="google/t5-efficient-tiny", lit_model=f"{org_team}/{model_name}")
+
+    client = GridRestClient()
+    model = client.models_store_get_model_by_name(
+        project_owner_name=teamspace.owner.name,
+        project_name=teamspace.name,
+        model_name=model_name,
+    )
+    client.models_store_delete_model(project_id=teamspace.id, model_id=model.id)
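
For reference, a minimal usage sketch of the new helper introduced by this diff; the target teamspace path below is an illustrative placeholder, not a real teamspace:

```python
from litmodels.integrations.duplicate import duplicate_hf_model

# Duplicate a small Hugging Face model into a Lightning Cloud teamspace.
# "my-org/my-teamspace" is a hypothetical placeholder for an existing teamspace.
lit_name = duplicate_hf_model(
    hf_model="google/t5-efficient-tiny",
    lit_model="my-org/my-teamspace/t5-efficient-tiny",
)
print(lit_name)  # name of the duplicated model in Lightning Cloud
```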