Skip to content

Commit 9c479d7

Browse files
authored
feat: metadata for tracing integrations (#85)
* metadata for tracing integrations * tests + mocks
1 parent de79549 commit 9c479d7

File tree

11 files changed

+132
-41
lines changed

11 files changed

+132
-41
lines changed

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# NOTE: once we add more dependencies, consider update dependabot to check for updates
22

3-
lightning-sdk >=0.2.7
3+
lightning-sdk >=0.2.9
44
lightning-utilities
55
joblib

src/litmodels/integrations/checkpoints.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import queue
23
import threading
34
from abc import ABC
@@ -82,9 +83,9 @@ def _worker_loop(self) -> None:
8283
break
8384
action, detail = task
8485
if action == Action.UPLOAD:
85-
registry_name, filepath = detail
86+
registry_name, filepath, metadata = detail
8687
try:
87-
upload_model(registry_name, filepath)
88+
upload_model(name=registry_name, model=filepath, metadata=metadata)
8889
rank_zero_debug(f"Finished uploading: {filepath}")
8990
except Exception as ex:
9091
rank_zero_warn(f"Upload failed {filepath}: {ex}")
@@ -103,10 +104,10 @@ def _worker_loop(self) -> None:
103104
rank_zero_warn(f"Unknown task: {task}")
104105
self.task_queue.task_done()
105106

106-
def queue_upload(self, registry_name: str, filepath: str) -> None:
107+
def queue_upload(self, registry_name: str, filepath: str, metadata: Optional[dict] = None) -> None:
107108
"""Queue an upload task."""
108109
self.upload_count += 1
109-
self.task_queue.put((Action.UPLOAD, (registry_name, filepath)))
110+
self.task_queue.put((Action.UPLOAD, (registry_name, filepath, metadata)))
110111
rank_zero_debug(f"Queued upload: {filepath} (pending uploads: {self.upload_count})")
111112

112113
def queue_remove(self, trainer: "pl.Trainer", filepath: str) -> None:
@@ -148,15 +149,22 @@ def __init__(self, model_name: Optional[str]) -> None:
148149
self._model_manager = ModelManager()
149150

150151
@rank_zero_only
151-
def _upload_model(self, filepath: str) -> None:
152+
def _upload_model(self, filepath: str, metadata: Optional[dict] = None) -> None:
152153
# todo: use filename as version but need to validate that such version does not exists yet
153154
if not self.model_registry:
154155
raise RuntimeError(
155156
"Model name is not specified neither updated by `setup` method via Trainer."
156157
" Please set the model name before uploading or ensure that `setup` method is called."
157158
)
159+
if not metadata:
160+
metadata = {}
161+
# Add the integration name to the metadata
162+
mro = inspect.getmro(type(self))
163+
abc_index = mro.index(LitModelCheckpointMixin)
164+
ckpt_class = mro[abc_index - 1]
165+
metadata.update({"litModels_integration": ckpt_class.__name__})
158166
# Add to queue instead of uploading directly
159-
get_model_manager().queue_upload(self.model_registry, filepath)
167+
get_model_manager().queue_upload(registry_name=self.model_registry, filepath=filepath, metadata=metadata)
160168

161169
@rank_zero_only
162170
def _remove_model(self, trainer: "pl.Trainer", filepath: str) -> None:

src/litmodels/integrations/duplicate.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,22 @@
44
from pathlib import Path
55
from typing import Optional
66

7-
from litmodels import upload_model
7+
from lightning_utilities import module_available
8+
9+
from litmodels.io import upload_model_files
10+
11+
if module_available("huggingface_hub"):
12+
from huggingface_hub import snapshot_download
13+
else:
14+
snapshot_download = None
815

916

1017
def duplicate_hf_model(
11-
hf_model: str, lit_model: Optional[str] = None, local_workdir: Optional[str] = None, verbose: int = 1
18+
hf_model: str,
19+
lit_model: Optional[str] = None,
20+
local_workdir: Optional[str] = None,
21+
verbose: int = 1,
22+
metadata: Optional[dict] = None,
1223
) -> str:
1324
"""Downloads the model from Hugging Face and uploads it to Lightning Cloud.
1425
@@ -18,13 +29,12 @@ def duplicate_hf_model(
1829
local_workdir:
1930
The local working directory to use for the duplication process. If not set a temp folder will be created.
2031
verbose: Shot a progress bar for the upload.
32+
metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
2133
2234
Returns:
2335
The name of the duplicated model in Lightning Cloud.
2436
"""
25-
try:
26-
from huggingface_hub import snapshot_download
27-
except ModuleNotFoundError:
37+
if not snapshot_download:
2838
raise ModuleNotFoundError(
2939
"Hugging Face Hub is not installed. Please install it with `pip install huggingface_hub`."
3040
)
@@ -52,5 +62,8 @@ def duplicate_hf_model(
5262
# Upload the model to Lightning Cloud
5363
if not lit_model:
5464
lit_model = model_name
55-
model = upload_model(name=lit_model, model=local_workdir / model_name, verbose=verbose)
65+
if not metadata:
66+
metadata = {}
67+
metadata.update({"litModels_integration": "duplicate_hf_model", "hf_model": hf_model})
68+
model = upload_model_files(name=lit_model, path=local_workdir / model_name, verbose=verbose, metadata=metadata)
5669
return model.name

src/litmodels/integrations/mixins.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import warnings
66
from abc import ABC
77
from pathlib import Path
8-
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union
8+
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
99

1010
from lightning_utilities.core.rank_zero import rank_zero_warn
1111

@@ -55,27 +55,45 @@ def _setup(
5555
temp_folder = tempfile.mkdtemp()
5656
return name, model_name, temp_folder
5757

58+
def _upload_model_files(
59+
self, name: str, path: Union[str, Path, List[Union[str, Path]]], metadata: Optional[dict] = None
60+
) -> None:
61+
"""Upload the model files to the registry."""
62+
if not metadata:
63+
metadata = {}
64+
# Add the integration name to the metadata
65+
mro = inspect.getmro(type(self))
66+
abc_index = mro.index(ModelRegistryMixin)
67+
mixin_class = mro[abc_index - 1]
68+
metadata.update({"litModels_integration": mixin_class.__name__})
69+
upload_model_files(name=name, path=path, metadata=metadata)
70+
5871

5972
class PickleRegistryMixin(ModelRegistryMixin):
6073
"""Mixin for pickle registry integration."""
6174

6275
def upload_model(
63-
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
76+
self,
77+
name: Optional[str] = None,
78+
version: Optional[str] = None,
79+
temp_folder: Union[str, Path, None] = None,
80+
metadata: Optional[dict] = None,
6481
) -> None:
6582
"""Push the model to the registry.
6683
6784
Args:
6885
name: The name of the model. If not use the class name.
6986
version: The version of the model. If None, the latest version is used.
7087
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
88+
metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
7189
"""
7290
name, model_name, temp_folder = self._setup(name, temp_folder)
7391
pickle_path = Path(temp_folder) / f"{model_name}.pkl"
7492
with open(pickle_path, "wb") as fp:
7593
pickle.dump(self, fp, protocol=pickle.HIGHEST_PROTOCOL)
7694
if version:
7795
name = f"{name}:{version}"
78-
upload_model_files(name=name, path=pickle_path)
96+
self._upload_model_files(name=name, path=pickle_path, metadata=metadata)
7997

8098
@classmethod
8199
def download_model(
@@ -128,14 +146,19 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "torch.nn.Module":
128146
return instance
129147

130148
def upload_model(
131-
self, name: Optional[str] = None, version: Optional[str] = None, temp_folder: Union[str, Path, None] = None
149+
self,
150+
name: Optional[str] = None,
151+
version: Optional[str] = None,
152+
temp_folder: Union[str, Path, None] = None,
153+
metadata: Optional[dict] = None,
132154
) -> None:
133155
"""Push the model to the registry.
134156
135157
Args:
136158
name: The name of the model. If not use the class name.
137159
version: The version of the model. If None, the latest version is used.
138160
temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
161+
metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
139162
"""
140163
import torch
141164

@@ -145,17 +168,18 @@ def upload_model(
145168

146169
name, model_name, temp_folder = self._setup(name, temp_folder)
147170

171+
init_kwargs_path = None
148172
if self.__init_kwargs:
149173
try:
150174
# Save the model arguments to a JSON file
151175
init_kwargs_path = Path(temp_folder) / f"{model_name}__init_kwargs.json"
152176
with open(init_kwargs_path, "w") as fp:
153177
json.dump(self.__init_kwargs, fp)
154-
except Exception as e:
178+
except Exception as ex:
155179
raise RuntimeError(
156-
f"Failed to save model arguments: {e}."
180+
f"Failed to save model arguments: {ex}."
157181
" Ensure the model's arguments are JSON serializable or use `PickleRegistryMixin`."
158-
) from e
182+
) from ex
159183
elif not hasattr(self, "__init_kwargs"):
160184
rank_zero_warn(
161185
"The child class is missing `__init_kwargs`."
@@ -168,7 +192,10 @@ def upload_model(
168192
model_registry = f"{name}:{version}" if version else name
169193
# todo: consider creating another temp folder and copying these two files
170194
# todo: updating SDK to support uploading just specific files
171-
upload_model_files(name=model_registry, path=[torch_state_dict_path, init_kwargs_path])
195+
uploaded_files = [torch_state_dict_path]
196+
if init_kwargs_path:
197+
uploaded_files.append(init_kwargs_path)
198+
self._upload_model_files(name=model_registry, path=uploaded_files, metadata=metadata)
172199

173200
@classmethod
174201
def download_model(

src/litmodels/io/cloud.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
from lightning_sdk.models import download_model as sdk_download_model
1111
from lightning_sdk.models import upload_model as sdk_upload_model
1212

13+
import litmodels
14+
1315
if TYPE_CHECKING:
1416
from lightning_sdk.models import UploadedModelInfo
1517

@@ -46,6 +48,7 @@ def upload_model_files(
4648
progress_bar: bool = True,
4749
cloud_account: Optional[str] = None,
4850
verbose: Union[bool, int] = 1,
51+
metadata: Optional[Dict[str, str]] = None,
4952
) -> "UploadedModelInfo":
5053
"""Upload a local checkpoint file to the model store.
5154
@@ -57,13 +60,18 @@ def upload_model_files(
5760
cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
5861
automatically.
5962
verbose: Whether to print a link to the uploaded model. If set to 0, no link will be printed.
63+
metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
6064
6165
"""
66+
if not metadata:
67+
metadata = {}
68+
metadata.update({"litModels": litmodels.__version__})
6269
info = sdk_upload_model(
6370
name=name,
6471
path=path,
6572
progress_bar=progress_bar,
6673
cloud_account=cloud_account,
74+
metadata=metadata,
6775
)
6876
if verbose:
6977
_print_model_link(name, verbose)

src/litmodels/io/gateway.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import tempfile
33
from pathlib import Path
4-
from typing import TYPE_CHECKING, Any, List, Optional, Union
4+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
55

66
import joblib
77
from lightning_utilities import module_available
@@ -24,6 +24,7 @@ def upload_model(
2424
cloud_account: Optional[str] = None,
2525
staging_dir: Optional[str] = None,
2626
verbose: Union[bool, int] = 1,
27+
metadata: Optional[Dict[str, str]] = None,
2728
) -> "UploadedModelInfo":
2829
"""Upload a checkpoint to the model store.
2930
@@ -37,6 +38,7 @@ def upload_model(
3738
staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
3839
be created and used.
3940
verbose: Whether to print some additional information about the uploaded model.
41+
metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
4042
4143
"""
4244
if not staging_dir:
@@ -62,6 +64,7 @@ def upload_model(
6264
progress_bar=progress_bar,
6365
cloud_account=cloud_account,
6466
verbose=verbose,
67+
metadata=metadata,
6568
)
6669

6770

tests/integrations/test_checkpoints.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import re
33
from unittest import mock
44

5+
import litmodels
56
import pytest
67

78
from tests.integrations import _SKIP_IF_LIGHTNING_MISSING, _SKIP_IF_PYTORCHLIGHTNING_MISSING
@@ -80,6 +81,7 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, monkeypatch
8081
path=mock.ANY,
8182
progress_bar=True,
8283
cloud_account=None,
84+
metadata={"litModels_integration": LitModelCheckpoint.__name__, "litModels": litmodels.__version__},
8385
)
8486
assert mock_upload_model.call_args_list == [expected_call] * 2
8587

tests/integrations/test_cloud.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from lightning_sdk.lightning_cloud.rest_client import GridRestClient
1111
from lightning_sdk.utils.resolve import _resolve_teamspace
1212
from litmodels import download_model, upload_model
13+
from litmodels.integrations.duplicate import duplicate_hf_model
1314
from litmodels.integrations.mixins import PickleRegistryMixin, PyTorchRegistryMixin
1415
from litmodels.io.cloud import _list_available_teamspaces
1516

@@ -288,6 +289,26 @@ def test_pytorch_mixin_push_and_pull():
288289
_cleanup_model(teamspace, model_name, expected_num_versions=1)
289290

290291

292+
@pytest.mark.cloud()
293+
def test_duplicate_real_hf_model(tmp_path):
294+
"""Verify that the HF model can be duplicated to the teamspace"""
295+
296+
# model name with random hash
297+
model_name = f"litmodels_hf_model+{os.urandom(8).hex()}"
298+
teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None)
299+
org_team = f"{teamspace.owner.name}/{teamspace.name}"
300+
301+
duplicate_hf_model(hf_model="google/t5-efficient-tiny", lit_model=f"{org_team}/{model_name}")
302+
303+
client = GridRestClient()
304+
model = client.models_store_get_model_by_name(
305+
project_owner_name=teamspace.owner.name,
306+
project_name=teamspace.name,
307+
model_name=model_name,
308+
)
309+
client.models_store_delete_model(project_id=teamspace.id, model_id=model.id)
310+
311+
291312
@pytest.mark.cloud()
292313
def test_list_available_teamspaces():
293314
teams = _list_available_teamspaces()

tests/integrations/test_duplicate.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,31 @@
11
import os
2+
from unittest import mock
23

3-
import pytest
4-
from lightning_sdk.lightning_cloud.rest_client import GridRestClient
5-
from lightning_sdk.utils.resolve import _resolve_teamspace
64
from litmodels.integrations.duplicate import duplicate_hf_model
75

8-
from tests.integrations import LIT_ORG, LIT_TEAMSPACE
96

10-
11-
@pytest.mark.cloud()
12-
def test_duplicate_hf_model(tmp_path):
7+
@mock.patch("litmodels.integrations.duplicate.snapshot_download")
8+
@mock.patch("litmodels.integrations.duplicate.upload_model_files")
9+
def test_duplicate_hf_model(mock_upload_model, mock_snapshot_download, tmp_path):
1310
"""Verify that the HF model can be duplicated to the teamspace"""
1411

12+
hf_model = "google/t5-efficient-tiny"
1513
# model name with random hash
1614
model_name = f"litmodels_hf_model+{os.urandom(8).hex()}"
17-
teamspace = _resolve_teamspace(org=LIT_ORG, teamspace=LIT_TEAMSPACE, user=None)
18-
org_team = f"{teamspace.owner.name}/{teamspace.name}"
19-
20-
duplicate_hf_model(hf_model="google/t5-efficient-tiny", lit_model=f"{org_team}/{model_name}")
15+
duplicate_hf_model(hf_model=hf_model, lit_model=model_name, local_workdir=str(tmp_path))
2116

22-
client = GridRestClient()
23-
model = client.models_store_get_model_by_name(
24-
project_owner_name=teamspace.owner.name,
25-
project_name=teamspace.name,
26-
model_name=model_name,
17+
mock_snapshot_download.assert_called_with(
18+
repo_id=hf_model,
19+
revision="main",
20+
repo_type="model",
21+
local_dir=tmp_path / hf_model.replace("/", "_"),
22+
local_dir_use_symlinks=True,
23+
ignore_patterns=[".cache*"],
24+
max_workers=os.cpu_count(),
25+
)
26+
mock_upload_model.assert_called_with(
27+
name=f"{model_name}",
28+
path=tmp_path / hf_model.replace("/", "_"),
29+
metadata={"hf_model": hf_model, "litModels_integration": "duplicate_hf_model"},
30+
verbose=1,
2731
)
28-
client.models_store_delete_model(project_id=teamspace.id, model_id=model.id)

0 commit comments

Comments
 (0)