55import warnings
66from abc import ABC
77from pathlib import Path
8- from typing import TYPE_CHECKING , Any , Optional , Tuple , Union
8+ from typing import TYPE_CHECKING , Any , List , Optional , Tuple , Union
99
1010from lightning_utilities .core .rank_zero import rank_zero_warn
1111
@@ -55,27 +55,45 @@ def _setup(
5555 temp_folder = tempfile .mkdtemp ()
5656 return name , model_name , temp_folder
5757
58+ def _upload_model_files (
59+ self , name : str , path : Union [str , Path , List [Union [str , Path ]]], metadata : Optional [dict ] = None
60+ ) -> None :
61+ """Upload the model files to the registry."""
62+ if not metadata :
63+ metadata = {}
64+ # Add the integration name to the metadata
65+ mro = inspect .getmro (type (self ))
66+ abc_index = mro .index (ModelRegistryMixin )
67+ mixin_class = mro [abc_index - 1 ]
68+ metadata .update ({"litModels_integration" : mixin_class .__name__ })
69+ upload_model_files (name = name , path = path , metadata = metadata )
70+
5871
5972class PickleRegistryMixin (ModelRegistryMixin ):
6073 """Mixin for pickle registry integration."""
6174
6275 def upload_model (
63- self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Union [str , Path , None ] = None
76+ self ,
77+ name : Optional [str ] = None ,
78+ version : Optional [str ] = None ,
79+ temp_folder : Union [str , Path , None ] = None ,
80+ metadata : Optional [dict ] = None ,
6481 ) -> None :
6582 """Push the model to the registry.
6683
6784 Args:
6885 name: The name of the model. If not use the class name.
6986 version: The version of the model. If None, the latest version is used.
7087 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
88+ metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
7189 """
7290 name , model_name , temp_folder = self ._setup (name , temp_folder )
7391 pickle_path = Path (temp_folder ) / f"{ model_name } .pkl"
7492 with open (pickle_path , "wb" ) as fp :
7593 pickle .dump (self , fp , protocol = pickle .HIGHEST_PROTOCOL )
7694 if version :
7795 name = f"{ name } :{ version } "
78- upload_model_files (name = name , path = pickle_path )
96+ self . _upload_model_files (name = name , path = pickle_path , metadata = metadata )
7997
8098 @classmethod
8199 def download_model (
@@ -128,14 +146,19 @@ def __new__(cls, *args: Any, **kwargs: Any) -> "torch.nn.Module":
128146 return instance
129147
130148 def upload_model (
131- self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Union [str , Path , None ] = None
149+ self ,
150+ name : Optional [str ] = None ,
151+ version : Optional [str ] = None ,
152+ temp_folder : Union [str , Path , None ] = None ,
153+ metadata : Optional [dict ] = None ,
132154 ) -> None :
133155 """Push the model to the registry.
134156
135157 Args:
136158 name: The name of the model. If not use the class name.
137159 version: The version of the model. If None, the latest version is used.
138160 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
161+ metadata: Optional metadata to attach to the model. If not provided, a default metadata will be used.
139162 """
140163 import torch
141164
@@ -145,17 +168,18 @@ def upload_model(
145168
146169 name , model_name , temp_folder = self ._setup (name , temp_folder )
147170
171+ init_kwargs_path = None
148172 if self .__init_kwargs :
149173 try :
150174 # Save the model arguments to a JSON file
151175 init_kwargs_path = Path (temp_folder ) / f"{ model_name } __init_kwargs.json"
152176 with open (init_kwargs_path , "w" ) as fp :
153177 json .dump (self .__init_kwargs , fp )
154- except Exception as e :
178+ except Exception as ex :
155179 raise RuntimeError (
156- f"Failed to save model arguments: { e } ."
180+ f"Failed to save model arguments: { ex } ."
157181 " Ensure the model's arguments are JSON serializable or use `PickleRegistryMixin`."
158- ) from e
182+ ) from ex
159183 elif not hasattr (self , "__init_kwargs" ):
160184 rank_zero_warn (
161185 "The child class is missing `__init_kwargs`."
@@ -168,7 +192,10 @@ def upload_model(
168192 model_registry = f"{ name } :{ version } " if version else name
169193 # todo: consider creating another temp folder and copying these two files
170194 # todo: updating SDK to support uploading just specific files
171- upload_model_files (name = model_registry , path = [torch_state_dict_path , init_kwargs_path ])
195+ uploaded_files = [torch_state_dict_path ]
196+ if init_kwargs_path :
197+ uploaded_files .append (init_kwargs_path )
198+ self ._upload_model_files (name = model_registry , path = uploaded_files , metadata = metadata )
172199
173200 @classmethod
174201 def download_model (
0 commit comments