Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _requirements/test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ pytest-mock

pytorch-lightning >=2.0
scikit-learn >=1.0
huggingface-hub >=0.29.0
56 changes: 56 additions & 0 deletions src/litmodels/integrations/duplicate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
import shutil
import tempfile
from pathlib import Path
from typing import Optional

from litmodels import upload_model


def _prune_cache_entries(root: Path) -> None:
    """Delete every ``.cache*`` file or folder found anywhere below ``root``."""
    # Collect the matches up front: removing entries while the ``rglob`` generator
    # is still walking the tree can make it trip over paths that no longer exist.
    for entry in list(root.rglob(".cache*")):
        if entry.is_symlink() or entry.is_file():
            entry.unlink()
        elif entry.is_dir():
            shutil.rmtree(entry)
        # Anything else has already disappeared together with a pruned parent folder.


def duplicate_hf_model(
    hf_model: str, lit_model: Optional[str] = None, local_workdir: Optional[str] = None, verbose: int = 1
) -> str:
    """Download a model from Hugging Face and upload it to Lightning Cloud.

    Args:
        hf_model: The name of the Hugging Face model to duplicate.
        lit_model: The name of the Lightning Cloud model to create.
            If not set, ``hf_model`` with ``/`` replaced by ``_`` is used.
        local_workdir:
            The local working directory to use for the duplication process. If not set, a temp folder
            is created and removed again once the upload finished (or failed).
        verbose: Show a progress bar for the upload; forwarded to ``upload_model``.

    Returns:
        The name of the duplicated model in Lightning Cloud.

    Raises:
        ModuleNotFoundError: If ``huggingface_hub`` is not installed.
    """
    try:
        from huggingface_hub import snapshot_download
    except ModuleNotFoundError as err:
        raise ModuleNotFoundError(
            "Hugging Face Hub is not installed. Please install it with `pip install huggingface_hub`."
        ) from err

    # Only a folder created here may be deleted afterwards; a caller-provided one is left in place.
    owns_workdir = not local_workdir
    workdir = Path(tempfile.mkdtemp() if owns_workdir else local_workdir)
    model_name = hf_model.replace("/", "_")
    model_dir = workdir / model_name

    # NOTE(review): huggingface_hub may read this flag when it is imported, in which case setting it
    # after the import above is too late for this process - confirm against the installed version.
    os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
    try:
        # Download the model from Hugging Face
        snapshot_download(
            repo_id=hf_model,
            revision="main",  # Branch/tag/commit
            repo_type="model",  # Options: "dataset", "model", "space"
            local_dir=model_dir,  # Specify to save in custom location, default is cache
            local_dir_use_symlinks=True,  # Use symlinks to save disk space
            ignore_patterns=[".cache*"],  # Exclude certain files if needed
            max_workers=os.cpu_count(),  # Number of parallel downloads
        )
        # Prune cache entries inside the downloaded model only, so that unrelated
        # content of a caller-provided working directory is never touched.
        _prune_cache_entries(model_dir)

        # Upload the model to Lightning Cloud
        model = upload_model(name=lit_model or model_name, model=model_dir, verbose=verbose)
    finally:
        if owns_workdir:
            shutil.rmtree(workdir, ignore_errors=True)
    return model.name
29 changes: 29 additions & 0 deletions tests/integrations/test_duplicate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import os

import pytest
from lightning_sdk.lightning_cloud.rest_client import GridRestClient
from lightning_sdk.utils.resolve import _resolve_teamspace
from litmodels.integrations.duplicate import duplicate_hf_model

LIT_ORG = "lightning-ai"
LIT_TEAMSPACE = "LitModels"


@pytest.mark.cloud()
def test_duplicate_hf_model(tmp_path):
    """Verify that the HF model can be duplicated to the teamspace."""
    # model name with random hash, so every run creates a distinct model
    model_name = f"litmodels_hf_model+{os.urandom(8).hex()}"
    teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None)
    org_team = f"{teamspace.owner.name}/{teamspace.name}"

    # Download into the pytest-managed folder instead of an unmanaged temp directory.
    duplicate_hf_model(
        hf_model="google/t5-efficient-tiny",
        lit_model=f"{org_team}/{model_name}",
        local_workdir=str(tmp_path),
    )

    # Look the model up by name (fails if it was not created), then delete it again.
    client = GridRestClient()
    model = client.models_store_get_model_by_name(
        project_owner_name=teamspace.owner.name,
        project_name=teamspace.name,
        model_name=model_name,
    )
    client.models_store_delete_model(project_id=teamspace.id, model_id=model.id)
Loading