1818
1919def upload_model (
2020 name : str ,
21- model : Union [str , Path , "torch.nn.Module" , Any ],
21+ model : Union [str , Path ],
22+ progress_bar : bool = True ,
23+ cloud_account : Optional [str ] = None ,
24+ verbose : Union [bool , int ] = 1 ,
25+ metadata : Optional [Dict [str , str ]] = None ,
26+ ) -> "UploadedModelInfo" :
27+ """Upload a checkpoint to the model store.
28+
29+ Args:
30+ name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
31+ where entity is either your username or the name of an organization you are part of.
32+ model: The model to upload. Can be a path to a checkpoint file or a folder.
33+ progress_bar: Whether to show a progress bar for the upload.
34+ cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
35+ automatically.
36+ verbose: Whether to print some additional information about the uploaded model.
37+ metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
38+
39+ """
40+ if not isinstance (model , (str , Path )):
41+ raise ValueError (
42+ "The `model` argument should be a path to a file or folder, not an python object."
43+ " For smooth integrations with PyTorch model, Lightning model and many more, use `save_model` instead."
44+ )
45+
46+ return upload_model_files (
47+ path = model ,
48+ name = name ,
49+ progress_bar = progress_bar ,
50+ cloud_account = cloud_account ,
51+ verbose = verbose ,
52+ metadata = metadata ,
53+ )
54+
55+
56+ def save_model (
57+ name : str ,
58+ model : Union ["torch.nn.Module" , Any ],
2259 progress_bar : bool = True ,
2360 cloud_account : Optional [str ] = None ,
2461 staging_dir : Optional [str ] = None ,
@@ -30,7 +67,7 @@ def upload_model(
3067 Args:
3168 name: Name of the model to upload. Must be in the format 'organization/teamspace/modelname'
3269 where entity is either your username or the name of an organization you are part of.
33- model: The model to upload. Can be a path to a checkpoint file, a PyTorch model, or a Lightning model.
70+ model: The model to upload. Can be a PyTorch model, or a Lightning model a .
3471 progress_bar: Whether to show a progress bar for the upload.
3572 cloud_account: The name of the cloud account to store the Model in. Only required if it can't be determined
3673 automatically.
@@ -40,14 +77,18 @@ def upload_model(
4077 metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
4178
4279 """
80+ if isinstance (model , (str , Path )):
81+ raise ValueError (
82+ "The `model` argument should be a PyTorch model or a Lightning model, not a path to a file."
83+ " With file or folder path use `upload_model` instead."
84+ )
85+
4386 if not staging_dir :
4487 staging_dir = tempfile .mkdtemp ()
45- if isinstance (model , (str , Path )):
46- path = model
4788 # if LightningModule and isinstance(model, LightningModule):
4889 # path = os.path.join(staging_dir, f"{model.__class__.__name__}.ckpt")
4990 # model.save_checkpoint(path)
50- elif _PYTORCH_AVAILABLE and isinstance (model , torch .jit .ScriptModule ):
91+ if _PYTORCH_AVAILABLE and isinstance (model , torch .jit .ScriptModule ):
5192 path = os .path .join (staging_dir , f"{ model .__class__ .__name__ } .ts" )
5293 model .save (path )
5394 elif _PYTORCH_AVAILABLE and isinstance (model , torch .nn .Module ):
@@ -60,8 +101,12 @@ def upload_model(
60101 path = os .path .join (staging_dir , f"{ model .__class__ .__name__ } .pkl" )
61102 dump_pickle (model = model , path = path )
62103
63- return upload_model_files (
64- path = path ,
104+ if not metadata :
105+ metadata = {}
106+ metadata .update ({"litModels_integration" : "save_model" })
107+
108+ return upload_model (
109+ model = path ,
65110 name = name ,
66111 progress_bar = progress_bar ,
67112 cloud_account = cloud_account ,
0 commit comments