1515
1616from litmodels import upload_model
1717from litmodels .integrations .imports import _LIGHTNING_AVAILABLE , _PYTORCHLIGHTNING_AVAILABLE
18- from litmodels .io .cloud import _list_available_teamspaces
18+ from litmodels .io .cloud import _list_available_teamspaces , delete_model_version
1919
2020if _LIGHTNING_AVAILABLE :
2121 from lightning .pytorch .callbacks import ModelCheckpoint as _LightningModelCheckpoint
@@ -47,6 +47,13 @@ class Action(StrEnum):
4747 REMOVE = "remove"
4848
4949
50+ class RemoveType (StrEnum ):
51+ """Enumeration of possible remove types for the ModelManager."""
52+
53+ LOCAL = "local"
54+ CLOUD = "cloud"
55+
56+
5057class ModelManager :
5158 """Manages uploads and removals with a single queue but separate counters."""
5259
@@ -94,10 +101,16 @@ def _worker_loop(self) -> None:
94101 finally :
95102 self .upload_count -= 1
96103 elif action == Action .REMOVE :
97- trainer , filepath = detail
104+ filepath , trainer , registry_name = detail
98105 try :
99- trainer .strategy .remove_checkpoint (filepath )
100- rank_zero_debug (f"Removed file: { filepath } " )
106+ if registry_name :
107+ rank_zero_debug (f"Removing from cloud: { filepath } " )
108+ # Remove from the cloud
109+ version = os .path .splitext (os .path .basename (filepath ))[0 ]
110+ delete_model_version (name = registry_name , version = version )
111+ if trainer :
112+ rank_zero_debug (f"Removed local file: { filepath } " )
113+ trainer .strategy .remove_checkpoint (filepath )
101114 except Exception as ex :
102115 rank_zero_warn (f"Removal failed { filepath } : { ex } " )
103116 finally :
@@ -112,10 +125,12 @@ def queue_upload(self, registry_name: str, filepath: Union[str, Path], metadata:
112125 self .task_queue .put ((Action .UPLOAD , (registry_name , filepath , metadata )))
113126 rank_zero_debug (f"Queued upload: { filepath } (pending uploads: { self .upload_count } )" )
114127
115- def queue_remove (self , trainer : "pl.Trainer" , filepath : Union [str , Path ]) -> None :
128+ def queue_remove (
129+ self , filepath : Union [str , Path ], trainer : Optional ["pl.Trainer" ] = None , registry_name : Optional [str ] = None
130+ ) -> None :
116131 """Queue a removal task."""
117132 self .remove_count += 1
118- self .task_queue .put ((Action .REMOVE , (trainer , filepath )))
133+ self .task_queue .put ((Action .REMOVE , (filepath , trainer , registry_name )))
119134 rank_zero_debug (f"Queued removal: { filepath } (pending removals: { self .remove_count } )" )
120135
121136 def shutdown (self ) -> None :
@@ -133,11 +148,14 @@ class LitModelCheckpointMixin(ABC):
133148 model_registry : Optional [str ] = None
134149 _model_manager : ModelManager
135150
136- def __init__ (self , model_registry : Optional [str ], clear_all_local : bool = False ) -> None :
151+ def __init__ (
152+ self , model_registry : Optional [str ], keep_all_uploaded : bool = False , clear_all_local : bool = False
153+ ) -> None :
137154 """Initialize with model name.
138155
139156 Args:
140157 model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
158+ keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so.
141159 clear_all_local: Whether to clear local models after uploading to the cloud.
142160 """
143161 if not model_registry :
@@ -147,6 +165,7 @@ def __init__(self, model_registry: Optional[str], clear_all_local: bool = False)
147165 self ._datetime_stamp = datetime .now ().strftime ("%Y%m%d-%H%M" )
148166 # remove any / from beginning and end of the name
149167 self .model_registry = model_registry .strip ("/" ) if model_registry else None
168+ self ._keep_all_uploaded = keep_all_uploaded
150169 self ._clear_all_local = clear_all_local
151170
152171 try : # authenticate before anything else starts
@@ -178,15 +197,18 @@ def _upload_model(self, trainer: "pl.Trainer", filepath: Union[str, Path], metad
178197 # Add to queue instead of uploading directly
179198 get_model_manager ().queue_upload (registry_name = model_registry , filepath = filepath , metadata = metadata )
180199 if self ._clear_all_local :
181- get_model_manager ().queue_remove (trainer = trainer , filepath = filepath )
200+ get_model_manager ().queue_remove (filepath = filepath , trainer = trainer )
182201
183202 @rank_zero_only
184203 def _remove_model (self , trainer : "pl.Trainer" , filepath : Union [str , Path ]) -> None :
185204 """Remove the local version of the model if requested."""
186- if self ._clear_all_local :
205+ get_model_manager ().queue_remove (
206+ filepath = filepath ,
187207 # skip the local removal we put it in the queue right after the upload
188- return
189- get_model_manager ().queue_remove (trainer = trainer , filepath = filepath )
208+ trainer = None if self ._clear_all_local else trainer ,
209+ # skip the cloud removal if we keep all uploaded models
210+ registry_name = None if self ._keep_all_uploaded else self .model_registry ,
211+ )
190212
191213 def default_model_name (self , pl_model : "pl.LightningModule" ) -> str :
192214 """Generate a default model name based on the class name and timestamp."""
@@ -234,6 +256,7 @@ class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoin
234256
235257 Args:
236258 model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
259+ keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so.
237260 clear_all_local: Whether to clear local models after uploading to the cloud.
238261 *args: Additional arguments to pass to the parent class.
239262 **kwargs: Additional keyword arguments to pass to the parent class.
@@ -244,6 +267,7 @@ def __init__(
244267 * args : Any ,
245268 model_name : Optional [str ] = None ,
246269 model_registry : Optional [str ] = None ,
270+ keep_all_uploaded : bool = False ,
247271 clear_all_local : bool = False ,
248272 ** kwargs : Any ,
249273 ) -> None :
@@ -255,7 +279,10 @@ def __init__(
255279 " Please use 'model_registry' instead."
256280 )
257281 LitModelCheckpointMixin .__init__ (
258- self , model_registry = model_registry or model_name , clear_all_local = clear_all_local
282+ self ,
283+ model_registry = model_registry or model_name ,
284+ keep_all_uploaded = keep_all_uploaded ,
285+ clear_all_local = clear_all_local ,
259286 )
260287
261288 def setup (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , stage : str ) -> None :
@@ -288,6 +315,7 @@ class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightning
288315
289316 Args:
290317 model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
318+ keep_all_uploaded: Whether prevent deleting models from cloud if the checkpointing logic asks to do so.
291319 clear_all_local: Whether to clear local models after uploading to the cloud.
292320 args: Additional arguments to pass to the parent class.
293321 kwargs: Additional keyword arguments to pass to the parent class.
@@ -298,6 +326,7 @@ def __init__(
298326 * args : Any ,
299327 model_name : Optional [str ] = None ,
300328 model_registry : Optional [str ] = None ,
329+ keep_all_uploaded : bool = False ,
301330 clear_all_local : bool = False ,
302331 ** kwargs : Any ,
303332 ) -> None :
@@ -309,7 +338,10 @@ def __init__(
309338 " Please use 'model_registry' instead."
310339 )
311340 LitModelCheckpointMixin .__init__ (
312- self , model_registry = model_registry or model_name , clear_all_local = clear_all_local
341+ self ,
342+ model_registry = model_registry or model_name ,
343+ keep_all_uploaded = keep_all_uploaded ,
344+ clear_all_local = clear_all_local ,
313345 )
314346
315347 def setup (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , stage : str ) -> None :
0 commit comments