Skip to content

Commit d984ad6

Browse files
committed
with removal...
1 parent 1c1ef5e commit d984ad6

File tree

1 file changed

+63
-40
lines changed

1 file changed

+63
-40
lines changed

src/litmodels/integrations/checkpoints.py

Lines changed: 63 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -30,48 +30,68 @@
3030

3131
# Create a singleton upload manager
3232
@lru_cache(maxsize=None)
33-
def get_upload_manager():
33+
def get_model_manager():
3434
"""Get or create the singleton upload manager."""
35-
return ModelUploadManager()
35+
return ModelManager()
3636

3737

38-
class ModelUploadManager:
39-
"""Manages asynchronous model uploads in a background thread."""
38+
class ModelManager:
39+
"""Manages uploads and removals with a single queue but separate counters."""
4040

4141
def __init__(self):
42-
"""Initialize the upload manager with a queue and worker thread."""
43-
self.queue = queue.Queue()
44-
self.pending_count = 0
45-
self._lock = threading.Lock()
46-
self._worker = threading.Thread(target=self._upload_worker, daemon=True)
42+
"""Initialize the ModelManager with a task queue and counters."""
43+
self.task_queue = queue.Queue()
44+
self.upload_count = 0
45+
self.remove_count = 0
46+
self._worker = threading.Thread(target=self._worker_loop, daemon=True)
4747
self._worker.start()
4848

49-
def _upload_worker(self):
50-
"""Worker thread that processes uploads from the queue."""
49+
def _worker_loop(self):
5150
while True:
52-
task = self.queue.get()
51+
task = self.task_queue.get()
5352
if task is None:
54-
break # Signal to exit
55-
56-
registry_name, filepath = task
57-
try: # Actual upload happens here
58-
upload_model(name=registry_name, model=filepath)
59-
rank_zero_debug(f"Successfully uploaded model: {registry_name}")
60-
except Exception as ex:
61-
rank_zero_warn(f"Failed to upload model {registry_name} with {filepath}:\n{ex}")
62-
finally:
63-
# Decrement the pending count and mark the task as done
64-
with self._lock:
65-
self.pending_count -= 1
66-
# Notify that the task is done
67-
self.queue.task_done()
53+
self.task_queue.task_done()
54+
break
55+
action, detail = task
56+
if action == "upload":
57+
registry_name, filepath = detail
58+
try:
59+
upload_model(registry_name, filepath)
60+
rank_zero_debug(f"Finished uploading: {filepath}")
61+
except Exception as ex:
62+
rank_zero_warn(f"Upload failed {filepath}: {ex}")
63+
finally:
64+
self.upload_count -= 1
65+
elif action == "remove":
66+
trainer, filepath = detail
67+
try:
68+
trainer.strategy.remove_checkpoint(filepath)
69+
rank_zero_debug(f"Removed file: {filepath}")
70+
except Exception as ex:
71+
rank_zero_warn(f"Removal failed {filepath}: {ex}")
72+
finally:
73+
self.remove_count -= 1
74+
else:
75+
rank_zero_warn(f"Unknown task: {task}")
76+
self.task_queue.task_done()
6877

6978
def queue_upload(self, registry_name: str, filepath: str):
70-
"""Queue a model for background upload."""
71-
with self._lock:
72-
self.pending_count += 1
73-
self.queue.put((registry_name, filepath))
74-
rank_zero_debug(f"Queued model {registry_name} for upload. Pending uploads: {self.pending_count}")
79+
"""Queue an upload task."""
80+
self.upload_count += 1
81+
self.task_queue.put(("upload", (registry_name, filepath)))
82+
rank_zero_debug(f"Queued upload: {filepath} (pending uploads: {self.upload_count})")
83+
84+
def queue_remove(self, trainer: "pl.Trainer", filepath: str):
85+
"""Queue a removal task."""
86+
self.remove_count += 1
87+
self.task_queue.put(("remove", (trainer, filepath)))
88+
rank_zero_debug(f"Queued removal: {filepath} (pending removals: {self.remove_count})")
89+
90+
def shutdown(self):
91+
"""Shut down the manager and wait for all tasks to complete."""
92+
self.task_queue.put(None)
93+
self.task_queue.join()
94+
rank_zero_debug("Manager shut down.")
7595

7696

7797
# Base class to be inherited
@@ -105,7 +125,12 @@ def _upload_model(self, filepath: str) -> None:
105125
" Please set the model name before uploading or ensure that `setup` method is called."
106126
)
107127
# Add to queue instead of uploading directly
108-
get_upload_manager().queue_upload(self.model_registry, filepath)
128+
get_model_manager().queue_upload(self.model_registry, filepath)
129+
130+
@rank_zero_only
131+
def _remove_model(self, trainer: "pl.Trainer", filepath: str) -> None:
132+
"""Remove the local version of the model if requested."""
133+
get_model_manager().queue_remove(trainer, filepath)
109134

110135
def default_model_name(self, pl_model: "pl.LightningModule") -> str:
111136
"""Generate a default model name based on the class name and timestamp."""
@@ -177,13 +202,12 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") ->
177202
"""Extend the on_fit_end method to ensure all uploads are completed."""
178203
super().on_fit_end(trainer, pl_module)
179204
# Wait for all uploads to finish
180-
get_upload_manager().queue.join()
205+
get_model_manager().shutdown()
181206

182207
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
183208
"""Extend the remove checkpoint method to remove the model from the registry."""
184-
# super()._remove_checkpoint(trainer, filepath)
185-
# todo: need to implement another queue for removing the model from the registry after upload has finished
186-
pass
209+
if trainer.is_global_zero: # Only remove from the main process
210+
self._remove_model(trainer, filepath)
187211

188212

189213
if _PYTORCHLIGHTNING_AVAILABLE:
@@ -217,10 +241,9 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") ->
217241
"""Extend the on_fit_end method to ensure all uploads are completed."""
218242
super().on_fit_end(trainer, pl_module)
219243
# Wait for all uploads to finish
220-
get_upload_manager().queue.join()
244+
get_model_manager().shutdown()
221245

222246
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
223247
"""Extend the remove checkpoint method to remove the model from the registry."""
224-
# super()._remove_checkpoint(trainer, filepath)
225-
# todo: need to implement another queue for removing the model from the registry after upload has finished
226-
pass
248+
if trainer.is_global_zero: # Only remove from the main process
249+
self._remove_model(trainer, filepath)

0 commit comments

Comments
 (0)