@@ -216,38 +216,6 @@ def _parse_ckpt_path(
216216 )
217217 return ckpt_path
218218
219- def _download_model_registry (self , model_name : str , model_version : str ) -> str :
220- model_registry = model_name
221- model_registry += f":{ model_version } " if model_version else ""
222- # download the latest checkpoint from the model registry
223- local_model_dir = os .path .join (self .trainer .default_root_dir , model_registry .replace ("/" , "_" ))
224-
225- if self .trainer .local_rank == 0 :
226- from lightning_sdk .lightning_cloud .login import Auth
227- from litmodels import download_model
228-
229- try : # authenticate before anything else starts
230- auth = Auth ()
231- auth .authenticate ()
232- except Exception :
233- raise ConnectionError ("Unable to authenticate with Lightning Cloud. Check your credentials." )
234-
235- # print(f"Rank {self.trainer.local_rank} downloads model checkpoint '{model_registry}'")
236- model_files = download_model (model_registry , download_dir = local_model_dir )
237- # print(f"Model checkpoint '{model_registry}' was downloaded to '{local_model_dir}'")
238- if not model_files :
239- raise RuntimeError (f"Download model failed - { model_registry } " )
240-
241- # wait for all to catch up
242- self .trainer .strategy .barrier ("_CheckpointConnector._download_model_registry" )
243-
244- # todo: resolve if there are multiple checkpoints
245- folder_files = [fn for fn in os .listdir (local_model_dir ) if fn .endswith (".ckpt" )]
246- if not folder_files :
247- raise RuntimeError (f"Parsing files from downloaded model failed - { model_registry } " )
248- # print(f"local RANK {self.trainer.local_rank}: using model files: {folder_files}")
249- return os .path .join (local_model_dir , folder_files [0 ])
250-
251219 def resume_end (self ) -> None :
252220 """Signal the connector that all states have resumed and memory for the checkpoint object can be released."""
253221 assert self .trainer .state .fn is not None
0 commit comments