22# Licensed under the Apache License, Version 2.0 (the "License");
33# http://www.apache.org/licenses/LICENSE-2.0
44#
5- import os
6- import tempfile
7- from pathlib import Path
8- from typing import TYPE_CHECKING , Optional , Tuple , Union
5+ from typing import TYPE_CHECKING , Optional , Tuple
96
107from lightning_sdk .api .teamspace_api import UploadedModelInfo
118from lightning_sdk .teamspace import Teamspace
129from lightning_sdk .utils import resolve as sdk_resolvers
1310from lightning_utilities import module_available
1411
1512if TYPE_CHECKING :
16- from torch . nn import Module
13+ pass
1714
1815if module_available ("torch" ):
19- import torch
20- from torch .nn import Module
16+ pass
2117else :
2218 torch = None
2319
@@ -63,60 +59,18 @@ def _get_teamspace(name: str, organization: str) -> Teamspace:
6359 return Teamspace (** teamspaces [requested_teamspace ])
6460
6561
66- def upload_model (
67- model : Union [str , Path , "Module" ],
68- name : str ,
69- progress_bar : bool = True ,
70- cluster_id : Optional [str ] = None ,
71- staging_dir : Optional [str ] = None ,
72- ) -> UploadedModelInfo :
73- """Upload a checkpoint to the model store.
74-
75- Args:
76- model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
77- name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
78- where entity is either your username or the name of an organization you are part of.
79- progress_bar: Whether to show a progress bar for the upload.
80- cluster_id: The name of the cluster to use. Only required if it can't be determined
81- automatically.
82- staging_dir: A directory where the model can be saved temporarily. If not provided, a temporary directory will
83- be created and used.
84-
85- """
86- if not staging_dir :
87- staging_dir = tempfile .mkdtemp ()
88- # if LightningModule and isinstance(model, LightningModule):
89- # path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
90- # model.save_checkpoint(path)
91- if torch and isinstance (model , Module ):
92- path = os .path .join (staging_dir , f"{ model .__class__ .__name__ } .pth" )
93- torch .save (model .state_dict (), path )
94- elif isinstance (model , str ):
95- path = model
96- elif isinstance (model , Path ):
97- path = str (model )
98- else :
99- raise ValueError (f"Unsupported model type { type (model )} " )
100- return upload_model_files (
101- path = path ,
102- name = name ,
103- progress_bar = progress_bar ,
104- cluster_id = cluster_id ,
105- )
106-
107-
10862def upload_model_files (
109- path : str ,
11063 name : str ,
64+ path : str ,
11165 progress_bar : bool = True ,
11266 cluster_id : Optional [str ] = None ,
11367) -> UploadedModelInfo :
11468 """Upload a local checkpoint file to the model store.
11569
11670 Args:
117- path: Path to the model file to upload.
118- name: Name tag of the model to upload. Must be in the format 'organization/teamspace/modelname'
71+ name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
11972 where entity is either your username or the name of an organization you are part of.
73+ path: Path to the model file to upload.
12074 progress_bar: Whether to show a progress bar for the upload.
12175 cluster_id: The name of the cluster to use. Only required if it can't be determined
12276 automatically.
@@ -132,15 +86,15 @@ def upload_model_files(
13286 )
13387
13488
135- def download_model (
89+ def download_model_file (
13690 name : str ,
13791 download_dir : str = "." ,
13892 progress_bar : bool = True ,
13993) -> str :
14094 """Download a checkpoint from the model store.
14195
14296 Args:
143- name: Name tag of the model to download. Must be in the format 'organization/teamspace/modelname'
97+ name: Name of the model to download. Must be in the format 'organization/teamspace/modelname'
14498 where entity is either your username or the name of an organization you are part of.
14599 download_dir: A path to directory where the model should be downloaded. Defaults
146100 to the current working directory.
0 commit comments