22# Licensed under the Apache License, Version 2.0 (the "License");
33# http://www.apache.org/licenses/LICENSE-2.0
44#
5- import os
6- import tempfile
7- from pathlib import Path
8- from typing import TYPE_CHECKING , Optional , Tuple , Union
5+ from typing import Optional , Tuple
96
107from lightning_sdk .api .teamspace_api import UploadedModelInfo
118from lightning_sdk .teamspace import Teamspace
129from lightning_sdk .utils import resolve as sdk_resolvers
13- from lightning_utilities import module_available
14-
15- if TYPE_CHECKING :
16- from torch .nn import Module
17-
18- if module_available ("torch" ):
19- import torch
20- from torch .nn import Module
21- else :
22- torch = None
2310
2411# if module_available("lightning"):
2512# from lightning import LightningModule
@@ -63,60 +50,18 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:
6350 return Teamspace (** teamspaces [requested_teamspace ])
6451
6552
66- def upload_model (
67- model : Union [str , Path , "Module" ],
68- name : str ,
69- progress_bar : bool = True ,
70- cluster_id : Optional [str ] = None ,
71- staging_dir : Optional [str ] = None ,
72- ) -> UploadedModelInfo :
73- """Upload a checkpoint to the model store.
74-
75- Args:
76- model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
77- name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
78- where entity is either your username or the name of an organization you are part of.
79- progress_bar: Whether to show a progress bar for the upload.
80- cluster_id: The name of the cluster to use. Only required if it can't be determined
81- automatically.
82- staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
83- be created and used.
84-
85- """
86- if not staging_dir :
87- staging_dir = tempfile .mkdtemp ()
88- # if LightningModule and isinstance(model, LightningModule):
89- # path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
90- # model.save_checkpoint(path)
91- if torch and isinstance (model , Module ):
92- path = os .path .join (staging_dir , f"{ model .__class__ .__name__ } .pth" )
93- torch .save (model .state_dict (), path )
94- elif isinstance (model , str ):
95- path = model
96- elif isinstance (model , Path ):
97- path = str (model )
98- else :
99- raise ValueError (f"Unsupported model type { type (model )} " )
100- return upload_model_files (
101- path = path ,
102- name = name ,
103- progress_bar = progress_bar ,
104- cluster_id = cluster_id ,
105- )
106-
107-
10853def upload_model_files (
109- path : str ,
11054 name : str ,
55+ path : str ,
11156 progress_bar : bool = True ,
11257 cluster_id : Optional [str ] = None ,
11358) -> UploadedModelInfo :
11459 """Upload a local checkpoint file to the model store.
11560
11661 Args:
117- path: Path to the model file to upload.
118- name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
62+ name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
11963 where entity is either your username or the name of an organization you are part of.
64+ path: Path to the model file to upload.
12065 progress_bar: Whether to show a progress bar for the upload.
12166 cluster_id: The name of the cluster to use. Only required if it can't be determined
12267 automatically.
@@ -132,15 +77,15 @@ def upload_model_files(
13277 )
13378
13479
135- def download_model (
80+ def download_model_files (
13681 name : str ,
13782 download_dir : str = "." ,
13883 progress_bar : bool = True ,
13984) -> str :
14085 """Download a checkpoint from the model store.
14186
14287 Args:
143- name: Name tag of the model to download. Must be in the format 'organization/teamspace/modelname'
88+ name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
14489 where entity is either your username or the name of an organization you are part of.
14590 download_dir: A path to directory where the model should be downloaded. Defaults
14691 to the current working directory.
0 commit comments