Skip to content

Commit 3e3a588

Browse files
committed
_download_model_registry
1 parent 1c149bd commit 3e3a588

File tree

1 file changed

+0
-32
lines changed

1 file changed

+0
-32
lines changed

src/lightning/pytorch/trainer/connectors/checkpoint_connector.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -216,38 +216,6 @@ def _parse_ckpt_path(
216216
)
217217
return ckpt_path
218218

219-
def _download_model_registry(self, model_name: str, model_version: str) -> str:
    """Download a checkpoint from the model registry and return its local path.

    The registry identifier is ``model_name``, with ``:model_version`` appended
    when a version is given. Files are placed in a sub-directory of the
    trainer's ``default_root_dir`` named after that identifier, with ``/``
    replaced by ``_``.

    Only the process with ``local_rank == 0`` authenticates and downloads.
    Every process then meets at a strategy barrier before looking up the
    checkpoint file in the download directory.

    Args:
        model_name: Name of the model in the registry.
        model_version: Version to fetch; a falsy value means no version
            suffix is added to the identifier.

    Returns:
        Path to the alphabetically first ``.ckpt`` file in the download
        directory.

    Raises:
        ConnectionError: If authentication fails on local rank 0.
        RuntimeError: If the download returns no files, or the download
            directory contains no ``.ckpt`` file.
    """
    model_registry = f"{model_name}:{model_version}" if model_version else model_name
    # download the latest checkpoint from the model registry
    local_model_dir = os.path.join(self.trainer.default_root_dir, model_registry.replace("/", "_"))

    if self.trainer.local_rank == 0:
        # Imported lazily so these packages are only needed when a download actually happens.
        from lightning_sdk.lightning_cloud.login import Auth
        from litmodels import download_model

        try:  # authenticate before anything else starts
            Auth().authenticate()
        except Exception as err:
            # Chain the original failure so the root cause stays visible in the traceback.
            raise ConnectionError(
                "Unable to authenticate with Lightning Cloud. Check your credentials."
            ) from err

        model_files = download_model(model_registry, download_dir=local_model_dir)
        if not model_files:
            raise RuntimeError(f"Download model failed - {model_registry}")

    # wait for all to catch up
    # NOTE(review): if local rank 0 raises above it never reaches this barrier;
    # confirm the strategy handles that without leaving other ranks waiting.
    self.trainer.strategy.barrier("_CheckpointConnector._download_model_registry")

    # todo: resolve if there are multiple checkpoints
    # os.listdir() returns entries in arbitrary order; sort so that every rank
    # (and every run) selects the same file when several checkpoints exist.
    checkpoint_files = sorted(fn for fn in os.listdir(local_model_dir) if fn.endswith(".ckpt"))
    if not checkpoint_files:
        raise RuntimeError(f"Parsing files from downloaded model failed - {model_registry}")
    return os.path.join(local_model_dir, checkpoint_files[0])
250-
251219
def resume_end(self) -> None:
252220
"""Signal the connector that all states have resumed and memory for the checkpoint object can be released."""
253221
assert self.trainer.state.fn is not None

0 commit comments

Comments
 (0)