55from abc import ABC
66from datetime import datetime
77from functools import lru_cache
8- from typing import TYPE_CHECKING , Any , Optional
8+ from pathlib import Path
9+ from typing import TYPE_CHECKING , Any , Optional , Union
910
1011from lightning_sdk .lightning_cloud .login import Auth
1112from lightning_sdk .utils .resolve import _resolve_teamspace
@@ -105,13 +106,13 @@ def _worker_loop(self) -> None:
105106 rank_zero_warn (f"Unknown task: { task } " )
106107 self .task_queue .task_done ()
107108
def queue_upload(self, registry_name: str, filepath: Union[str, Path], metadata: Optional[dict] = None) -> None:
    """Queue an upload task.

    Args:
        registry_name: Registry name passed along with the upload task.
        filepath: Location of the file to upload, as a string or ``Path``.
        metadata: Optional metadata dictionary passed along with the upload task.
    """
    # Count the upload as pending before it is put on the task queue.
    self.upload_count += 1
    self.task_queue.put((Action.UPLOAD, (registry_name, filepath, metadata)))
    rank_zero_debug(f"Queued upload: {filepath} (pending uploads: {self.upload_count})")
113114
114- def queue_remove (self , trainer : "pl.Trainer" , filepath : str ) -> None :
115+ def queue_remove (self , trainer : "pl.Trainer" , filepath : Union [ str , Path ] ) -> None :
115116 """Queue a removal task."""
116117 self .remove_count += 1
117118 self .task_queue .put ((Action .REMOVE , (trainer , filepath )))
@@ -132,15 +133,21 @@ class LitModelCheckpointMixin(ABC):
132133 model_registry : Optional [str ] = None
133134 _model_manager : ModelManager
134135
135- def __init__ (self , model_name : Optional [str ]) -> None :
136- """Initialize with model name."""
137- if not model_name :
136+ def __init__ (self , model_registry : Optional [str ], clear_all_local : bool = False ) -> None :
137+ """Initialize with model name.
138+
139+ Args:
140+ model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
141+ clear_all_local: Whether to clear local models after uploading to the cloud.
142+ """
143+ if not model_registry :
138144 rank_zero_warn (
139145 "The model is not defined so we will continue with LightningModule names and timestamp of now"
140146 )
141147 self ._datetime_stamp = datetime .now ().strftime ("%Y%m%d-%H%M" )
142148 # remove any / from beginning and end of the name
143- self .model_registry = model_name .strip ("/" ) if model_name else None
149+ self .model_registry = model_registry .strip ("/" ) if model_registry else None
150+ self ._clear_all_local = clear_all_local
144151
145152 try : # authenticate before anything else starts
146153 Auth ().authenticate ()
@@ -150,7 +157,7 @@ def __init__(self, model_name: Optional[str]) -> None:
150157 self ._model_manager = ModelManager ()
151158
152159 @rank_zero_only
153- def _upload_model (self , filepath : str , metadata : Optional [dict ] = None ) -> None :
160+ def _upload_model (self , trainer : "pl.Trainer" , filepath : Union [ str , Path ] , metadata : Optional [dict ] = None ) -> None :
154161 if not self .model_registry :
155162 raise RuntimeError (
156163 "Model name is not specified neither updated by `setup` method via Trainer."
@@ -170,11 +177,16 @@ def _upload_model(self, filepath: str, metadata: Optional[dict] = None) -> None:
170177 metadata .update ({"litModels_integration" : ckpt_class .__name__ })
171178 # Add to queue instead of uploading directly
172179 get_model_manager ().queue_upload (registry_name = model_registry , filepath = filepath , metadata = metadata )
180+ if self ._clear_all_local :
181+ get_model_manager ().queue_remove (trainer = trainer , filepath = filepath )
173182
@rank_zero_only
def _remove_model(self, trainer: "pl.Trainer", filepath: Union[str, Path]) -> None:
    """Remove the local version of the model if requested.

    Args:
        trainer: Trainer handed over to the queued removal task.
        filepath: Location of the local file to remove.
    """
    if self._clear_all_local:
        # Nothing to queue here: with `clear_all_local` enabled, `_upload_model`
        # already queues the removal right after it queues the upload.
        return
    get_model_manager().queue_remove(trainer=trainer, filepath=filepath)
178190
179191 def default_model_name (self , pl_model : "pl.LightningModule" ) -> str :
180192 """Generate a default model name based on the class name and timestamp."""
@@ -221,15 +233,30 @@ class LightningModelCheckpoint(LitModelCheckpointMixin, _LightningModelCheckpoin
221233 """Lightning ModelCheckpoint with LitModel support.
222234
223235 Args:
224- model_name: Name of the model to upload in format 'organization/teamspace/modelname'
225- args: Additional arguments to pass to the parent class.
226- kwargs: Additional keyword arguments to pass to the parent class.
236+ model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
237+ clear_all_local: Whether to clear local models after uploading to the cloud.
238+ *args: Additional arguments to pass to the parent class.
239+ **kwargs: Additional keyword arguments to pass to the parent class.
227240 """
228241
def __init__(
    self,
    *args: Any,
    model_name: Optional[str] = None,
    model_registry: Optional[str] = None,
    clear_all_local: bool = False,
    **kwargs: Any,
) -> None:
    """Initialize the checkpoint with the model registry name and other parameters.

    Args:
        *args: Additional arguments to pass to the parent class.
        model_name: Deprecated alias of ``model_registry``. Passing any value other
            than ``None`` emits a warning; the value is used only when
            ``model_registry`` is empty or ``None``.
        model_registry: Name of the model to upload in format
            'organization/teamspace/modelname'.
        clear_all_local: Whether to clear local models after uploading to the cloud.
        **kwargs: Additional keyword arguments to pass to the parent class.
    """
    # The parent checkpoint is initialized first, then the mixin.
    _LightningModelCheckpoint.__init__(self, *args, **kwargs)
    if model_name is not None:
        rank_zero_warn(
            "The 'model_name' argument is deprecated and will be removed in a future version."
            " Please use 'model_registry' instead."
        )
    # `model_registry` takes precedence; the deprecated `model_name` is only a fallback.
    LitModelCheckpointMixin.__init__(
        self, model_registry=model_registry or model_name, clear_all_local=clear_all_local
    )
233260
234261 def setup (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , stage : str ) -> None :
235262 """Setup the checkpoint callback."""
@@ -240,7 +267,7 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
240267 """Extend the save checkpoint method to upload the model."""
241268 _LightningModelCheckpoint ._save_checkpoint (self , trainer , filepath )
242269 if trainer .is_global_zero : # Only upload from the main process
243- self ._upload_model (filepath )
270+ self ._upload_model (trainer = trainer , filepath = filepath )
244271
245272 def on_fit_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
246273 """Extend the on_fit_end method to ensure all uploads are completed."""
@@ -251,7 +278,7 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") ->
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
    """Extend the remove checkpoint method to hand the local removal over to `_remove_model`.

    Args:
        trainer: Trainer whose ``is_global_zero`` flag decides whether this process acts.
        filepath: Location of the checkpoint file to remove.
    """
    if trainer.is_global_zero:  # Only remove from the main process
        self._remove_model(trainer=trainer, filepath=filepath)
255282
256283
257284if _PYTORCHLIGHTNING_AVAILABLE :
@@ -260,15 +287,30 @@ class PytorchLightningModelCheckpoint(LitModelCheckpointMixin, _PytorchLightning
260287 """PyTorch Lightning ModelCheckpoint with LitModel support.
261288
262289 Args:
263- model_name: Name of the model to upload in format 'organization/teamspace/modelname'
290+ model_registry: Name of the model to upload in format 'organization/teamspace/modelname'.
291+ clear_all_local: Whether to clear local models after uploading to the cloud.
264292 *args: Additional arguments to pass to the parent class.
265293 **kwargs: Additional keyword arguments to pass to the parent class.
266294 """
267295
def __init__(
    self,
    *args: Any,
    model_name: Optional[str] = None,
    model_registry: Optional[str] = None,
    clear_all_local: bool = False,
    **kwargs: Any,
) -> None:
    """Initialize the checkpoint with the model registry name and other parameters.

    Args:
        *args: Additional arguments to pass to the parent class.
        model_name: Deprecated alias of ``model_registry``. Passing any value other
            than ``None`` emits a warning; the value is used only when
            ``model_registry`` is empty or ``None``.
        model_registry: Name of the model to upload in format
            'organization/teamspace/modelname'.
        clear_all_local: Whether to clear local models after uploading to the cloud.
        **kwargs: Additional keyword arguments to pass to the parent class.
    """
    # The parent checkpoint is initialized first, then the mixin.
    _PytorchLightningModelCheckpoint.__init__(self, *args, **kwargs)
    if model_name is not None:
        rank_zero_warn(
            "The 'model_name' argument is deprecated and will be removed in a future version."
            " Please use 'model_registry' instead."
        )
    # `model_registry` takes precedence; the deprecated `model_name` is only a fallback.
    LitModelCheckpointMixin.__init__(
        self, model_registry=model_registry or model_name, clear_all_local=clear_all_local
    )
272314
273315 def setup (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , stage : str ) -> None :
274316 """Setup the checkpoint callback."""
@@ -279,7 +321,7 @@ def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
279321 """Extend the save checkpoint method to upload the model."""
280322 _PytorchLightningModelCheckpoint ._save_checkpoint (self , trainer , filepath )
281323 if trainer .is_global_zero : # Only upload from the main process
282- self ._upload_model (filepath )
324+ self ._upload_model (trainer = trainer , filepath = filepath )
283325
284326 def on_fit_end (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" ) -> None :
285327 """Extend the on_fit_end method to ensure all uploads are completed."""
@@ -290,4 +332,4 @@ def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") ->
def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
    """Extend the remove checkpoint method to hand the local removal over to `_remove_model`.

    Args:
        trainer: Trainer whose ``is_global_zero`` flag decides whether this process acts.
        filepath: Location of the checkpoint file to remove.
    """
    if trainer.is_global_zero:  # Only remove from the main process
        self._remove_model(trainer=trainer, filepath=filepath)
0 commit comments