|
30 | 30 |
|
31 | 31 | # Create a singleton model manager
32 | 32 | @lru_cache(maxsize=None) |
33 | | -def get_upload_manager(): |
| 33 | +def get_model_manager(): |
34 | 34 | """Get or create the singleton ModelManager."""
35 | | - return ModelUploadManager() |
| 35 | + return ModelManager() |
36 | 36 |
|
37 | 37 |
|
38 | | -class ModelUploadManager: |
39 | | - """Manages asynchronous model uploads in a background thread.""" |
| 38 | +class ModelManager: |
| 39 | + """Run uploads and checkpoint removals on one background worker; the shared FIFO queue guarantees a file is removed only after its upload, while separate counters track each task type."""
40 | 40 |
|
41 | 41 | def __init__(self): |
42 | | - """Initialize the upload manager with a queue and worker thread.""" |
43 | | - self.queue = queue.Queue() |
44 | | - self.pending_count = 0 |
45 | | - self._lock = threading.Lock() |
46 | | - self._worker = threading.Thread(target=self._upload_worker, daemon=True) |
| 42 | + """Initialize the ModelManager with a task queue and counters.""" |
| 43 | + self.task_queue = queue.Queue() |
| 44 | + self.upload_count = 0  # pending uploads; advisory, used only for debug logging
| 45 | + self.remove_count = 0  # pending removals; advisory, used only for debug logging
| 46 | + self._worker = threading.Thread(target=self._worker_loop, daemon=True) |
47 | 47 | self._worker.start() |
48 | 48 |
|
49 | | - def _upload_worker(self): |
50 | | - """Worker thread that processes uploads from the queue.""" |
| 49 | + def _worker_loop(self):  # drain queued tasks until the None sentinel arrives
51 | 50 | while True: |
52 | | - task = self.queue.get() |
| 51 | + task = self.task_queue.get() |
53 | 52 | if task is None: |
54 | | - break # Signal to exit |
55 | | - |
56 | | - registry_name, filepath = task |
57 | | - try: # Actual upload happens here |
58 | | - upload_model(name=registry_name, model=filepath) |
59 | | - rank_zero_debug(f"Successfully uploaded model: {registry_name}") |
60 | | - except Exception as ex: |
61 | | - rank_zero_warn(f"Failed to upload model {registry_name} with {filepath}:\n{ex}") |
62 | | - finally: |
63 | | - # Decrement the pending count and mark the task as done |
64 | | - with self._lock: |
65 | | - self.pending_count -= 1 |
66 | | - # Notify that the task is done |
67 | | - self.queue.task_done() |
| 53 | + self.task_queue.task_done() |
| 54 | + break  # shutdown sentinel received; exit the worker
| 55 | + action, detail = task |
| 56 | + if action == "upload": |
| 57 | + registry_name, filepath = detail |
| 58 | + try: |
| 59 | + upload_model(name=registry_name, model=filepath)
| 60 | + rank_zero_debug(f"Uploaded model {registry_name} from {filepath}")
| 61 | + except Exception as ex:
| 62 | + rank_zero_warn(f"Failed to upload model {registry_name} from {filepath}: {ex}")
| 63 | + finally: |
| 64 | + self.upload_count -= 1 |
| 65 | + elif action == "remove": |
| 66 | + trainer, filepath = detail |
| 67 | + try: |
| 68 | + trainer.strategy.remove_checkpoint(filepath) |
| 69 | + rank_zero_debug(f"Removed checkpoint: {filepath}")
| 70 | + except Exception as ex:
| 71 | + rank_zero_warn(f"Failed to remove checkpoint {filepath}: {ex}")
| 72 | + finally: |
| 73 | + self.remove_count -= 1 |
| 74 | + else: |
| 75 | + rank_zero_warn(f"Unknown task: {task}") |
| 76 | + self.task_queue.task_done() |
68 | 77 |
|
69 | 78 | def queue_upload(self, registry_name: str, filepath: str): |
70 | | - """Queue a model for background upload.""" |
71 | | - with self._lock: |
72 | | - self.pending_count += 1 |
73 | | - self.queue.put((registry_name, filepath)) |
74 | | - rank_zero_debug(f"Queued model {registry_name} for upload. Pending uploads: {self.pending_count}") |
| 79 | + """Queue an upload task.""" |
| 80 | + self.upload_count += 1 |
| 81 | + self.task_queue.put(("upload", (registry_name, filepath))) |
| 82 | + rank_zero_debug(f"Queued upload: {filepath} (pending uploads: {self.upload_count})") |
| 83 | + |
| 84 | + def queue_remove(self, trainer: "pl.Trainer", filepath: str): |
| 85 | + """Queue a removal task.""" |
| 86 | + self.remove_count += 1 |
| 87 | + self.task_queue.put(("remove", (trainer, filepath))) |
| 88 | + rank_zero_debug(f"Queued removal: {filepath} (pending removals: {self.remove_count})") |
| 89 | + |
| 90 | + def shutdown(self): |
| 91 | + """Shut down the worker and wait for all queued tasks to complete."""
| 92 | + self.task_queue.put(None)  # sentinel: the worker exits after draining earlier tasks
| 93 | + self.task_queue.join()  # blocks until every queued task is marked done
| 94 | + rank_zero_debug("Manager shut down.") |
75 | 95 |
|
76 | 96 |
|
77 | 97 | # Base class to be inherited |
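
Because uploads and removals share one FIFO queue and one worker, a checkpoint's removal can never run before its queued upload finishes, which is exactly what the old todo asked for. A minimal usage sketch, assuming the `ModelManager` above; the registry name, file path, and stub trainer are illustrative stand-ins, not values from this diff:

```python
import os
from types import SimpleNamespace

# Stub with the one attribute queue_remove touches; a real run passes pl.Trainer.
stub_trainer = SimpleNamespace(strategy=SimpleNamespace(remove_checkpoint=os.remove))

mgr = get_model_manager()                         # lru_cache -> one shared instance
mgr.queue_upload("org/my-model", "epoch=3.ckpt")  # processed first
mgr.queue_remove(stub_trainer, "epoch=3.ckpt")    # FIFO: runs only after the upload
mgr.shutdown()                                    # blocks until both tasks are done
```
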
@@ -105,7 +125,12 @@ def _upload_model(self, filepath: str) -> None: |
105 | 125 | " Please set the model name before uploading or ensure that `setup` method is called." |
106 | 126 | ) |
107 | 127 | # Add to queue instead of uploading directly |
108 | | - get_upload_manager().queue_upload(self.model_registry, filepath) |
| 128 | + get_model_manager().queue_upload(self.model_registry, filepath) |
| 129 | + |
| 130 | + @rank_zero_only |
| 131 | + def _remove_model(self, trainer: "pl.Trainer", filepath: str) -> None: |
| 132 | + """Remove the local version of the model if requested.""" |
| 133 | + get_model_manager().queue_remove(trainer, filepath) |
109 | 134 |
|
110 | 135 | def default_model_name(self, pl_model: "pl.LightningModule") -> str: |
111 | 136 | """Generate a default model name based on the class name and timestamp.""" |
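
Note that `_upload_model` and `_remove_model` are both decorated with `@rank_zero_only`, so on distributed runs only global rank zero ever enqueues work; the `trainer.is_global_zero` check added in `_remove_checkpoint` below is a second, redundant-but-harmless guard. A small illustration of the decorator's contract, assuming the `lightning_utilities` implementation these callbacks typically pull in:

```python
from lightning_utilities.core.rank_zero import rank_zero_only

@rank_zero_only
def enqueue(filepath: str) -> None:
    print(f"queued on rank 0: {filepath}")

rank_zero_only.rank = 1  # simulate a non-zero rank
enqueue("epoch=3.ckpt")  # no-op: the wrapper returns None without calling enqueue
rank_zero_only.rank = 0  # simulate the rank-zero process
enqueue("epoch=3.ckpt")  # prints
```
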
@@ -177,13 +202,12 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> |
177 | 202 | """Extend the on_fit_end method to ensure all queued uploads and removals complete."""
178 | 203 | super().on_fit_end(trainer, pl_module)
179 | 204 | # Wait for all queued tasks to finish and stop the worker
180 | | - get_upload_manager().queue.join()
| 205 | + get_model_manager().shutdown()
181 | 206 |
|
182 | 207 | def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: |
183 | 208 | """Extend the remove-checkpoint hook to defer local deletion until queued uploads of the file finish."""
184 | | - # super()._remove_checkpoint(trainer, filepath) |
185 | | - # todo: need to implement another queue for removing the model from the registry after upload has finished |
186 | | - pass |
| 209 | + if trainer.is_global_zero:  # only queue removal on the rank-zero process
| 210 | + self._remove_model(trainer, filepath) |
187 | 211 |
|
188 | 212 |
|
189 | 213 | if _PYTORCHLIGHTNING_AVAILABLE: |
@@ -217,10 +241,9 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> |
217 | 241 | """Extend the on_fit_end method to ensure all queued uploads and removals complete."""
218 | 242 | super().on_fit_end(trainer, pl_module)
219 | 243 | # Wait for all queued tasks to finish and stop the worker
220 | | - get_upload_manager().queue.join()
| 244 | + get_model_manager().shutdown()
221 | 245 |
|
222 | 246 | def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: |
223 | 247 | """Extend the remove-checkpoint hook to defer local deletion until queued uploads of the file finish."""
224 | | - # super()._remove_checkpoint(trainer, filepath) |
225 | | - # todo: need to implement another queue for removing the model from the registry after upload has finished |
226 | | - pass |
| 248 | + if trainer.is_global_zero:  # only queue removal on the rank-zero process
| 249 | + self._remove_model(trainer, filepath) |
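
One caveat worth flagging: `shutdown()` is one-shot. The sentinel stops the worker, but `lru_cache` keeps returning the same dead manager, so a later `fit()` would enqueue tasks that are never processed and a second `shutdown()` would block in `join()` forever. A hedged sketch of a restart-safe variant, reusing the names above (not part of this diff):

```python
def shutdown(self):
    """Shut down the worker, wait for all tasks, and allow a fresh manager later."""
    self.task_queue.put(None)        # sentinel: worker exits after draining the queue
    self.task_queue.join()           # wait for every queued task, including the sentinel
    self._worker.join()              # make sure the thread has actually exited
    get_model_manager.cache_clear()  # next get_model_manager() builds a new worker
```
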