Skip to content

Commit 5bee354

Browse files
authored
ckpt: use checkpoint name as version (#92)
1 parent 697b20f commit 5bee354

File tree

3 files changed

+18
-11
lines changed

3 files changed

+18
-11
lines changed

docs/source/conf.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -274,7 +274,7 @@ def _convert_markdown(path_in: str, path_out: str) -> None:
274 274
"numpy": ("https://numpy.org/doc/stable/", None),
275 275
}
276 276

277-
# -- Options for todo extension ----------------------------------------------
277+
# -- Options for to-do extension ----------------------------------------------
278 278

279 279
# If true, `todo` and `todoList` produce output, else they produce nothing.
280 280
todo_include_todos = True

src/litmodels/integrations/checkpoints.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
1 1
import inspect
2+
import os.path
2 3
import queue
3 4
import threading
4 5
from abc import ABC
@@ -150,12 +151,16 @@ def __init__(self, model_name: Optional[str]) -> None:
150 151

151 152
@rank_zero_only
152 153
def _upload_model(self, filepath: str, metadata: Optional[dict] = None) -> None:
153-
# todo: use filename as version but need to validate that such version does not exists yet
154 154
if not self.model_registry:
155 155
raise RuntimeError(
156 156
"Model name is not specified neither updated by `setup` method via Trainer."
157 157
" Please set the model name before uploading or ensure that `setup` method is called."
158 158
)
159+
model_registry = self.model_registry
160+
if os.path.isfile(filepath):
161+
# parse the file name as version
162+
version, _ = os.path.splitext(os.path.basename(filepath))
163+
model_registry += f":{version}"
159 164
if not metadata:
160 165
metadata = {}
161 166
# Add the integration name to the metadata
@@ -164,7 +169,7 @@ def _upload_model(self, filepath: str, metadata: Optional[dict] = None) -> None:
164 169
ckpt_class = mro[abc_index - 1]
165 170
metadata.update({"litModels_integration": ckpt_class.__name__})
166 171
# Add to queue instead of uploading directly
167-
get_model_manager().queue_upload(registry_name=self.model_registry, filepath=filepath, metadata=metadata)
172+
get_model_manager().queue_upload(registry_name=model_registry, filepath=filepath, metadata=metadata)
168 173

169 174
@rank_zero_only
170 175
def _remove_model(self, trainer: "pl.Trainer", filepath: str) -> None:

tests/integrations/test_checkpoints.py

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -78,14 +78,16 @@ def test_lightning_checkpoint_callback(mock_auth, mock_upload_model, monkeypatch
78 78
trainer.fit(BoringModel())
79 79

80 80
assert mock_auth.call_count == 1
81-
expected_call = mock.call(
82-
name=f"{expected_org}/{expected_teamspace}/{expected_model}",
83-
path=mock.ANY,
84-
progress_bar=True,
85-
cloud_account=None,
86-
metadata={"litModels_integration": LitModelCheckpoint.__name__, "litModels": litmodels.__version__},
87-
)
88-
assert mock_upload_model.call_args_list == [expected_call] * 2
81+
assert mock_upload_model.call_args_list == [
82+
mock.call(
83+
name=f"{expected_org}/{expected_teamspace}/{expected_model}:{v}",
84+
path=mock.ANY,
85+
progress_bar=True,
86+
cloud_account=None,
87+
metadata={"litModels_integration": LitModelCheckpoint.__name__, "litModels": litmodels.__version__},
88+
)
89+
for v in ("epoch=0-step=64", "epoch=1-step=128")
90+
]
89 91

90 92
# Verify paths match the expected pattern
91 93
for call_args in mock_upload_model.call_args_list:

0 commit comments

Comments
 (0)