Skip to content

Commit 966c957

Browse files
committed
duplicating HF model
1 parent 9c660e7 commit 966c957

File tree

2 files changed

+87
-0
lines changed

2 files changed

+87
-0
lines changed
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
import os
2+
import shutil
3+
import tempfile
4+
from pathlib import Path
5+
from typing import Optional
6+
7+
from litmodels import upload_model
8+
9+
10+
def duplicate_hf_model(
11+
hf_model: str, lit_model: Optional[str] = None, local_workdir: Optional[str] = None, verbose: int = 1
12+
) -> str:
13+
"""Downloads the model from Hugging Face and uploads it to Lightning Cloud.
14+
15+
Args:
16+
hf_model: The name of the Hugging Face model to duplicate.
17+
lit_model: The name of the Lightning Cloud model to create.
18+
local_workdir:
19+
The local working directory to use for the duplication process. If not set a temp folder will be created.
20+
verbose: Shot a progress bar for the upload.
21+
22+
Returns:
23+
The name of the duplicated model in Lightning Cloud.
24+
25+
>>> duplicate_hf_model("bert-base-uncased", lit_model="lightning-ai/jirka/bert-base-uncased", local_workdir=".")
26+
"""
27+
try:
28+
from huggingface_hub import snapshot_download
29+
except ModuleNotFoundError:
30+
raise ModuleNotFoundError(
31+
"Hugging Face Hub is not installed. Please install it with `pip install huggingface_hub`."
32+
)
33+
34+
if not local_workdir:
35+
local_workdir = tempfile.mkdtemp()
36+
local_workdir = Path(local_workdir)
37+
model_name = hf_model.replace("/", "_")
38+
39+
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
40+
# Download the model from Hugging Face
41+
snapshot_download(
42+
repo_id=hf_model,
43+
revision="main", # Branch/tag/commit
44+
repo_type="model", # Options: "dataset", "model", "space"
45+
local_dir=local_workdir / model_name, # Specify to save in custom location, default is cache
46+
local_dir_use_symlinks=True, # Use symlinks to save disk space
47+
ignore_patterns=[".cache*"], # Exclude certain files if needed
48+
max_workers=os.cpu_count(), # Number of parallel downloads
49+
)
50+
# prune cache in the downloaded model
51+
for path in local_workdir.rglob(".cache*"):
52+
shutil.rmtree(path)
53+
54+
# Upload the model to Lightning Cloud
55+
if not lit_model:
56+
lit_model = model_name
57+
model = upload_model(name=lit_model, model=local_workdir / model_name, verbose=verbose)
58+
return f"{model.teamspace}/{model.name}"
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import os
2+
3+
import pytest
4+
from lightning_sdk.lightning_cloud.rest_client import GridRestClient
5+
from lightning_sdk.utils.resolve import _resolve_teamspace
6+
from litmodels.integrations.duplicate import duplicate_hf_model
7+
8+
LIT_ORG = "lightning-ai"
9+
LIT_TEAMSPACE = "LitModels"
10+
11+
12+
@pytest.mark.cloud()
13+
def test_duplicate_hf_model(tmp_path):
14+
"""Verify that the HF model can be duplicated to the teamspace"""
15+
16+
# model name with random hash
17+
model_name = f"litmodels_hf_model+{os.urandom(8).hex()}"
18+
teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None)
19+
org_team = f"{teamspace.owner.name}/{teamspace.name}"
20+
21+
duplicate_hf_model(hf_model="google/t5-efficient-tiny", lit_model=f"{org_team}/{model_name}")
22+
23+
client = GridRestClient()
24+
model = client.models_store_get_model_by_name(
25+
project_owner_name=teamspace.owner.name,
26+
project_name=teamspace.name,
27+
model_name=model_name,
28+
)
29+
client.models_store_delete_model(project_id=teamspace.id, model_id=model.id)

0 commit comments

Comments
 (0)