Skip to content

Commit 0c3709e

Browse files
uploading with queue on background (#77)
* uploading with queue on background * with removal... --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6b7142f commit 0c3709e

File tree

3 files changed

+146
-7
lines changed

3 files changed

+146
-7
lines changed

src/litmodels/integrations/checkpoints.py

Lines changed: 134 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
1+
import queue
2+
import threading
13
from abc import ABC
24
from datetime import datetime
5+
from functools import lru_cache
36
from typing import TYPE_CHECKING, Any, Optional
47

58
from lightning_sdk.lightning_cloud.login import Auth
69
from lightning_sdk.utils.resolve import _resolve_teamspace
7-
from lightning_utilities.core.rank_zero import rank_zero_only, rank_zero_warn
10+
from lightning_utilities import StrEnum
11+
from lightning_utilities.core.rank_zero import rank_zero_debug, rank_zero_only, rank_zero_warn
812

913
from litmodels import upload_model
1014
from litmodels.integrations.imports import _LIGHTNING_AVAILABLE, _PYTORCHLIGHTNING_AVAILABLE
@@ -25,12 +29,106 @@
2529
import pytorch_lightning as pl
2630

2731

32+
# Process-wide accessor: the cache makes every call return the same manager instance.
@lru_cache(maxsize=None)
def get_model_manager() -> "ModelManager":
    """Return the shared ``ModelManager``, constructing it on the first call.

    The function takes no arguments, so ``lru_cache`` stores exactly one result;
    calling ``get_model_manager.cache_clear()`` makes the next call build a new manager.
    """
    manager = ModelManager()
    return manager
37+
38+
39+
# Task kinds that can be placed on the ModelManager queue.
class Action(StrEnum):
    """String-valued tags identifying what a queued ``ModelManager`` task should do."""

    # Upload a checkpoint file to the model registry.
    UPLOAD = "upload"
    # Remove a checkpoint file through the trainer strategy.
    REMOVE = "remove"
45+
46+
47+
class ModelManager:
    """Run checkpoint uploads and removals one at a time on a background thread.

    Tasks are pushed onto a single FIFO queue and consumed by one daemon worker,
    so a removal queued after an upload of the same file runs after that upload.
    ``upload_count`` and ``remove_count`` track queued-but-unfinished tasks of each kind.

    The manager is reusable: after ``shutdown()`` the worker is restarted on demand
    by the next ``queue_upload`` / ``queue_remove`` / ``shutdown`` call, so sharing one
    instance across several callbacks or several ``fit`` runs cannot dead-lock.
    """

    task_queue: queue.Queue

    def __init__(self) -> None:
        """Initialize the counters, the task queue and start the worker thread."""
        self.upload_count = 0
        self.remove_count = 0
        self._init_runtime()

    def _init_runtime(self) -> None:
        """Create the non-picklable members (queue, lock, worker thread)."""
        self.task_queue = queue.Queue()
        # Guards the counters (touched by caller and worker threads) and worker restarts.
        self._lock = threading.Lock()
        self._start_worker()

    def _start_worker(self) -> None:
        """Spawn a fresh daemon thread running ``_worker_loop``."""
        self._worker = threading.Thread(target=self._worker_loop, daemon=True)
        self._worker.start()

    def _ensure_worker(self) -> None:
        """Restart the worker if a previous ``shutdown()`` (or a crash) stopped it."""
        with self._lock:
            if not self._worker.is_alive():
                self._start_worker()

    def __getstate__(self) -> dict:
        """Get the state of the ModelManager for pickling.

        The queue, lock and thread cannot be pickled, so only the plain attributes are kept.
        """
        state = self.__dict__.copy()
        for transient in ("task_queue", "_worker", "_lock"):
            state.pop(transient, None)
        return state

    def __setstate__(self, state: dict) -> None:
        """Set the state of the ModelManager after unpickling.

        Counters are restored as pickled; the queue, lock and worker are created anew.
        """
        self.__dict__.update(state)
        self._init_runtime()

    def _worker_loop(self) -> None:
        """Consume tasks until the ``None`` sentinel is received."""
        while True:
            task = self.task_queue.get()
            try:
                if task is None:
                    break
                self._run_task(task)
            except Exception as ex:
                # A malformed task must not kill the worker, otherwise `join()` would hang.
                rank_zero_warn(f"Task failed {task}: {ex}")
            finally:
                # Always acknowledge the item (sentinel included) so `task_queue.join()` returns.
                self.task_queue.task_done()

    def _run_task(self, task: Any) -> None:
        """Execute a single ``(action, detail)`` task taken from the queue."""
        action, detail = task
        if action == Action.UPLOAD:
            registry_name, filepath = detail
            try:
                upload_model(registry_name, filepath)
                rank_zero_debug(f"Finished uploading: {filepath}")
            except Exception as ex:
                # Best effort: a failed upload is reported but does not stop the worker.
                rank_zero_warn(f"Upload failed {filepath}: {ex}")
            finally:
                with self._lock:
                    self.upload_count -= 1
        elif action == Action.REMOVE:
            trainer, filepath = detail
            try:
                trainer.strategy.remove_checkpoint(filepath)
                rank_zero_debug(f"Removed file: {filepath}")
            except Exception as ex:
                # Best effort: a failed removal is reported but does not stop the worker.
                rank_zero_warn(f"Removal failed {filepath}: {ex}")
            finally:
                with self._lock:
                    self.remove_count -= 1
        else:
            rank_zero_warn(f"Unknown task: {task}")

    def queue_upload(self, registry_name: str, filepath: str) -> None:
        """Queue an upload of ``filepath`` to the registry entry ``registry_name``."""
        with self._lock:
            self.upload_count += 1
            pending = self.upload_count
        self._ensure_worker()
        self.task_queue.put((Action.UPLOAD, (registry_name, filepath)))
        rank_zero_debug(f"Queued upload: {filepath} (pending uploads: {pending})")

    def queue_remove(self, trainer: "pl.Trainer", filepath: str) -> None:
        """Queue removal of ``filepath`` via ``trainer.strategy.remove_checkpoint``."""
        with self._lock:
            self.remove_count += 1
            pending = self.remove_count
        self._ensure_worker()
        self.task_queue.put((Action.REMOVE, (trainer, filepath)))
        rank_zero_debug(f"Queued removal: {filepath} (pending removals: {pending})")

    def shutdown(self) -> None:
        """Shut down the manager and wait for all tasks to complete.

        Safe to call repeatedly: if the worker already stopped it is restarted first,
        so the sentinel is always consumed and ``join()`` cannot block forever.
        """
        self._ensure_worker()
        self.task_queue.put(None)
        self.task_queue.join()
        # Wait for the thread itself so a later `_ensure_worker` reliably sees it as stopped.
        self._worker.join()
        rank_zero_debug("Manager shut down.")
123+
124+
28125
# Base class to be inherited
29126
class LitModelCheckpointMixin(ABC):
30127
"""Mixin class for LitModel checkpoint functionality."""
31128

32129
_datetime_stamp: str
33130
model_registry: Optional[str] = None
131+
_model_manager: ModelManager
34132

35133
def __init__(self, model_name: Optional[str]) -> None:
36134
"""Initialize with model name."""
@@ -47,16 +145,23 @@ def __init__(self, model_name: Optional[str]) -> None:
47145
except Exception:
48146
raise ConnectionError("Unable to authenticate with Lightning Cloud. Check your credentials.")
49147

148+
self._model_manager = ModelManager()
149+
50150
@rank_zero_only
51151
def _upload_model(self, filepath: str) -> None:
52-
# todo: uploading on background so training does nt stops
53152
# todo: use filename as version but need to validate that such version does not exists yet
54153
if not self.model_registry:
55154
raise RuntimeError(
56155
"Model name is not specified neither updated by `setup` method via Trainer."
57156
" Please set the model name before uploading or ensure that `setup` method is called."
58157
)
59-
upload_model(name=self.model_registry, model=filepath)
158+
# Add to queue instead of uploading directly
159+
get_model_manager().queue_upload(self.model_registry, filepath)
160+
161+
@rank_zero_only
def _remove_model(self, trainer: "pl.Trainer", filepath: str) -> None:
    """Schedule removal of the checkpoint at ``filepath`` on the shared background manager.

    The removal is only queued here; it is carried out later by the manager's worker.
    """
    manager = get_model_manager()
    manager.queue_remove(trainer, filepath)
60165

61166
def default_model_name(self, pl_model: "pl.LightningModule") -> str:
62167
"""Generate a default model name based on the class name and timestamp."""
@@ -115,15 +220,26 @@ def __init__(self, *args: Any, model_name: Optional[str] = None, **kwargs: Any)
115220

116221
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
117222
"""Setup the checkpoint callback."""
118-
super().setup(trainer, pl_module, stage)
223+
_LightningModelCheckpoint.setup(self, trainer, pl_module, stage)
119224
self._update_model_name(pl_module)
120225

121226
def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
122227
"""Extend the save checkpoint method to upload the model."""
123-
super()._save_checkpoint(trainer, filepath)
228+
_LightningModelCheckpoint._save_checkpoint(self, trainer, filepath)
124229
if trainer.is_global_zero: # Only upload from the main process
125230
self._upload_model(filepath)
126231

232+
def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Run the parent ``on_fit_end`` hook, then block until the background queue is drained."""
    _LightningModelCheckpoint.on_fit_end(self, trainer, pl_module)
    # `shutdown()` returns only after every queued task has been processed.
    manager = get_model_manager()
    manager.shutdown()
237+
238+
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
    """Defer checkpoint removal to the background queue instead of deleting immediately.

    Only the global-zero process schedules the removal; other ranks do nothing.
    """
    if not trainer.is_global_zero:
        return
    self._remove_model(trainer, filepath)
242+
127243

128244
if _PYTORCHLIGHTNING_AVAILABLE:
129245

@@ -143,11 +259,22 @@ def __init__(self, *args: Any, model_name: Optional[str] = None, **kwargs: Any)
143259

144260
def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
145261
"""Setup the checkpoint callback."""
146-
super().setup(trainer, pl_module, stage)
262+
_PytorchLightningModelCheckpoint.setup(self, trainer, pl_module, stage)
147263
self._update_model_name(pl_module)
148264

149265
def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
150266
"""Extend the save checkpoint method to upload the model."""
151-
super()._save_checkpoint(trainer, filepath)
267+
_PytorchLightningModelCheckpoint._save_checkpoint(self, trainer, filepath)
152268
if trainer.is_global_zero: # Only upload from the main process
153269
self._upload_model(filepath)
270+
271+
def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
    """Run the parent ``on_fit_end`` hook, then block until the background queue is drained."""
    _PytorchLightningModelCheckpoint.on_fit_end(self, trainer, pl_module)
    # `shutdown()` returns only after every queued task has been processed.
    manager = get_model_manager()
    manager.shutdown()
276+
277+
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
    """Defer checkpoint removal to the background queue instead of deleting immediately.

    Only the global-zero process schedules the removal; other ranks do nothing.
    """
    if not trainer.is_global_zero:
        return
    self._remove_model(trainer, filepath)

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Pytest configuration for integration tests."""
2+
3+
import pytest
4+
from litmodels.integrations.checkpoints import get_model_manager
5+
6+
7+
@pytest.fixture(autouse=True)
def reset_model_manager():
    """Give every test a fresh ``ModelManager`` by clearing the cached singleton."""
    get_model_manager.cache_clear()
    # Build the new instance right away so it is ready when the test body starts.
    fresh_manager = get_model_manager()
    return fresh_manager

tests/integrations/test_checkpoints.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,4 +104,5 @@ def test_lightning_checkpointing_pickleable(mock_auth, importing):
104104
from litmodels.integrations.checkpoints import PytorchLightningModelCheckpoint as LitModelCheckpoint
105105

106106
ckpt = LitModelCheckpoint(model_name="org-name/teamspace/model-name")
107+
assert mock_auth.call_count == 1
107108
pickle.dumps(ckpt)

0 commit comments

Comments
 (0)