22# Licensed under the Apache License, Version 2.0 (the "License");
33# http://www.apache.org/licenses/LICENSE-2.0
44#
5-
6- from typing import Optional , Tuple
5+ import os
6+ import tempfile
7+ from pathlib import Path
8+ from typing import TYPE_CHECKING , Optional , Tuple , Union
79
810from lightning_sdk .api .teamspace_api import UploadedModelInfo
911from lightning_sdk .teamspace import Teamspace
1012from lightning_sdk .utils import resolve as sdk_resolvers
13+ from lightning_utilities import module_available
14+
15+ if TYPE_CHECKING :
16+ from torch import nn
17+
18+ if module_available ("torch" ):
19+ import torch
20+ from torch import nn
21+ else :
22+ torch = None
23+
24+ # if module_available("lightning"):
25+ # from lightning import LightningModule
26+ # elif module_available("pytorch_lightning"):
27+ # from pytorch_lightning import LightningModule
28+ # else:
29+ # LightningModule = None
1130
1231
1332def _parse_name (name : str ) -> Tuple [str , str , str ]:
@@ -45,6 +64,48 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:
4564
4665
4766def upload_model (
67+ model : Union [str , Path , nn .Module ],
68+ name : str ,
69+ progress_bar : bool = True ,
70+ cluster_id : Optional [str ] = None ,
71+ staging_dir : Optional [str ] = None ,
72+ ) -> UploadedModelInfo :
73+ """Upload a local checkpoint file to the model store.
74+
75+ Args:
76+ model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
77+ name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
78+ where entity is either your username or the name of an organization you are part of.
79+ progress_bar: Whether to show a progress bar for the upload.
80+ cluster_id: The name of the cluster to use. Only required if it can't be determined
81+ automatically.
82+ staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
83+ be created and used.
84+
85+ """
86+ if not staging_dir :
87+ staging_dir = tempfile .mkdtemp ()
88+ # if LightningModule and isinstance(model, LightningModule):
89+ # path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
90+ # model.save_checkpoint(path)
91+ elif torch and isinstance (model , nn .Module ):
92+ path = os .path .join (staging_dir , f"{ model .__class__ .__name__ } .pth" )
93+ torch .save (model .state_dict (), path )
94+ elif isinstance (model , str ):
95+ path = model
96+ elif isinstance (model , Path ):
97+ path = str (model )
98+ else :
99+ raise ValueError (f"Unsupported model type { type (model )} " )
100+ return upload_model_files (
101+ path = path ,
102+ name = name ,
103+ progress_bar = progress_bar ,
104+ cluster_id = cluster_id ,
105+ )
106+
107+
108+ def upload_model_files (
48109 path : str ,
49110 name : str ,
50111 progress_bar : bool = True ,
@@ -71,7 +132,7 @@ def upload_model(
71132 )
72133
73134
74- def download_model (
135+ def download_model_files (
75136 name : str ,
76137 download_dir : str = "." ,
77138 progress_bar : bool = True ,
0 commit comments