1+ import queue
2+ import threading
13from abc import ABC
24from datetime import datetime
5+ from functools import lru_cache
36from typing import TYPE_CHECKING , Any , Optional
47
58from lightning_sdk .lightning_cloud .login import Auth
69from lightning_sdk .utils .resolve import _resolve_teamspace
7- from lightning_utilities .core .rank_zero import rank_zero_only , rank_zero_warn
10+ from lightning_utilities import StrEnum
11+ from lightning_utilities .core .rank_zero import rank_zero_debug , rank_zero_only , rank_zero_warn
812
913from litmodels import upload_model
1014from litmodels .integrations .imports import _LIGHTNING_AVAILABLE , _PYTORCHLIGHTNING_AVAILABLE
2529 import pytorch_lightning as pl
2630
2731
32+ # Create a singleton upload manager
33+ @lru_cache (maxsize = None )
34+ def get_model_manager () -> "ModelManager" :
35+ """Get or create the singleton upload manager."""
36+ return ModelManager ()
37+
38+
39+ # enumerate the possible actions
40+ class Action (StrEnum ):
41+ """Enumeration of possible actions for the ModelManager."""
42+
43+ UPLOAD = "upload"
44+ REMOVE = "remove"
45+
46+
47+ class ModelManager :
48+ """Manages uploads and removals with a single queue but separate counters."""
49+
50+ task_queue : queue .Queue
51+
52+ def __init__ (self ) -> None :
53+ """Initialize the ModelManager with a task queue and counters."""
54+ self .task_queue = queue .Queue ()
55+ self .upload_count = 0
56+ self .remove_count = 0
57+ self ._worker = threading .Thread (target = self ._worker_loop , daemon = True )
58+ self ._worker .start ()
59+
60+ def __getstate__ (self ) -> dict :
61+ """Get the state of the ModelManager for pickling."""
62+ state = self .__dict__ .copy ()
63+ del state ["task_queue" ]
64+ del state ["_worker" ]
65+ return state
66+
67+ def __setstate__ (self , state : dict ) -> None :
68+ """Set the state of the ModelManager after unpickling."""
69+ self .__dict__ .update (state )
70+ import queue
71+ import threading
72+
73+ self .task_queue = queue .Queue ()
74+ self ._worker = threading .Thread (target = self ._worker_loop , daemon = True )
75+ self ._worker .start ()
76+
77+ def _worker_loop (self ) -> None :
78+ while True :
79+ task = self .task_queue .get ()
80+ if task is None :
81+ self .task_queue .task_done ()
82+ break
83+ action , detail = task
84+ if action == Action .UPLOAD :
85+ registry_name , filepath = detail
86+ try :
87+ upload_model (registry_name , filepath )
88+ rank_zero_debug (f"Finished uploading: { filepath } " )
89+ except Exception as ex :
90+ rank_zero_warn (f"Upload failed { filepath } : { ex } " )
91+ finally :
92+ self .upload_count -= 1
93+ elif action == Action .REMOVE :
94+ trainer , filepath = detail
95+ try :
96+ trainer .strategy .remove_checkpoint (filepath )
97+ rank_zero_debug (f"Removed file: { filepath } " )
98+ except Exception as ex :
99+ rank_zero_warn (f"Removal failed { filepath } : { ex } " )
100+ finally :
101+ self .remove_count -= 1
102+ else :
103+ rank_zero_warn (f"Unknown task: { task } " )
104+ self .task_queue .task_done ()
105+
106+ def queue_upload (self , registry_name : str , filepath : str ) -> None :
107+ """Queue an upload task."""
108+ self .upload_count += 1
109+ self .task_queue .put ((Action .UPLOAD , (registry_name , filepath )))
110+ rank_zero_debug (f"Queued upload: { filepath } (pending uploads: { self .upload_count } )" )
111+
112+ def queue_remove (self , trainer : "pl.Trainer" , filepath : str ) -> None :
113+ """Queue a removal task."""
114+ self .remove_count += 1
115+ self .task_queue .put ((Action .REMOVE , (trainer , filepath )))
116+ rank_zero_debug (f"Queued removal: { filepath } (pending removals: { self .remove_count } )" )
117+
118+ def shutdown (self ) -> None :
119+ """Shut down the manager and wait for all tasks to complete."""
120+ self .task_queue .put (None )
121+ self .task_queue .join ()
122+ rank_zero_debug ("Manager shut down." )
123+
124+
28125# Base class to be inherited
29126class LitModelCheckpointMixin (ABC ):
30127 """Mixin class for LitModel checkpoint functionality."""
31128
32129 _datetime_stamp : str
33130 model_registry : Optional [str ] = None
131+ _model_manager : ModelManager
34132
35133 def __init__ (self , model_name : Optional [str ]) -> None :
36134 """Initialize with model name."""
@@ -47,16 +145,23 @@ def __init__(self, model_name: Optional[str]) -> None:
47145 except Exception :
48146 raise ConnectionError ("Unable to authenticate with Lightning Cloud. Check your credentials." )
49147
148+ self ._model_manager = ModelManager ()
149+
50150 @rank_zero_only
51151 def _upload_model (self , filepath : str ) -> None :
52- # todo: uploading on background so training does nt stops
53152 # todo: use filename as version but need to validate that such version does not exists yet
54153 if not self .model_registry :
55154 raise RuntimeError (
56155 "Model name is not specified neither updated by `setup` method via Trainer."
57156 " Please set the model name before uploading or ensure that `setup` method is called."
58157 )
59- upload_model (name = self .model_registry , model = filepath )
158+ # Add to queue instead of uploading directly
159+ get_model_manager ().queue_upload (self .model_registry , filepath )
160+
161+ @rank_zero_only
162+ def _remove_model (self , trainer : "pl.Trainer" , filepath : str ) -> None :
163+ """Remove the local version of the model if requested."""
164+ get_model_manager ().queue_remove (trainer , filepath )
60165
61166 def default_model_name (self , pl_model : "pl.LightningModule" ) -> str :
62167 """Generate a default model name based on the class name and timestamp."""
@@ -115,15 +220,26 @@ def __init__(self, *args: Any, model_name: Optional[str] = None, **kwargs: Any)
115220
116221 def setup (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , stage : str ) -> None :
117222 """Setup the checkpoint callback."""
118- super () .setup (trainer , pl_module , stage )
223+ _LightningModelCheckpoint .setup (self , trainer , pl_module , stage )
119224 self ._update_model_name (pl_module )
120225
121226 def _save_checkpoint (self , trainer : "pl.Trainer" , filepath : str ) -> None :
122227 """Extend the save checkpoint method to upload the model."""
123- super () ._save_checkpoint (trainer , filepath )
228+ _LightningModelCheckpoint ._save_checkpoint (self , trainer , filepath )
124229 if trainer .is_global_zero : # Only upload from the main process
125230 self ._upload_model (filepath )
126231
232+ def on_fit_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
233+ """Extend the on_fit_end method to ensure all uploads are completed."""
234+ _LightningModelCheckpoint .on_fit_end (self , trainer , pl_module )
235+ # Wait for all uploads to finish
236+ get_model_manager ().shutdown ()
237+
238+ def _remove_checkpoint (self , trainer : "pl.Trainer" , filepath : str ) -> None :
239+ """Extend the remove checkpoint method to remove the model from the registry."""
240+ if trainer .is_global_zero : # Only remove from the main process
241+ self ._remove_model (trainer , filepath )
242+
127243
128244if _PYTORCHLIGHTNING_AVAILABLE :
129245
@@ -143,11 +259,22 @@ def __init__(self, *args: Any, model_name: Optional[str] = None, **kwargs: Any)
143259
144260 def setup (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , stage : str ) -> None :
145261 """Setup the checkpoint callback."""
146- super () .setup (trainer , pl_module , stage )
262+ _PytorchLightningModelCheckpoint .setup (self , trainer , pl_module , stage )
147263 self ._update_model_name (pl_module )
148264
149265 def _save_checkpoint (self , trainer : "pl.Trainer" , filepath : str ) -> None :
150266 """Extend the save checkpoint method to upload the model."""
151- super () ._save_checkpoint (trainer , filepath )
267+ _PytorchLightningModelCheckpoint ._save_checkpoint (self , trainer , filepath )
152268 if trainer .is_global_zero : # Only upload from the main process
153269 self ._upload_model (filepath )
270+
271+ def on_fit_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
272+ """Extend the on_fit_end method to ensure all uploads are completed."""
273+ _PytorchLightningModelCheckpoint .on_fit_end (self , trainer , pl_module )
274+ # Wait for all uploads to finish
275+ get_model_manager ().shutdown ()
276+
277+ def _remove_checkpoint (self , trainer : "pl.Trainer" , filepath : str ) -> None :
278+ """Extend the remove checkpoint method to remove the model from the registry."""
279+ if trainer .is_global_zero : # Only remove from the main process
280+ self ._remove_model (trainer , filepath )
0 commit comments