1+ import inspect
2+ import json
13import pickle
24import tempfile
35import warnings
46from abc import ABC
57from pathlib import Path
6- from typing import TYPE_CHECKING , Optional , Tuple
8+ from typing import TYPE_CHECKING , Any , Optional , Tuple , Union
79
8- from litmodels import download_model , upload_model
10+ from lightning_utilities .core .rank_zero import rank_zero_warn
11+
12+ from litmodels .io .cloud import download_model_files , upload_model_files
913
1014if TYPE_CHECKING :
1115 import torch
@@ -15,7 +19,7 @@ class ModelRegistryMixin(ABC):
1519 """Mixin for model registry integration."""
1620
1721 def push_to_registry (
18- self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Optional [str ] = None
22+ self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Union [str , Path , None ] = None
1923 ) -> None :
2024 """Push the model to the registry.
2125
@@ -26,7 +30,9 @@ def push_to_registry(
2630 """
2731
2832 @classmethod
29- def pull_from_registry (cls , name : str , version : Optional [str ] = None , temp_folder : Optional [str ] = None ) -> object :
33+ def pull_from_registry (
34+ cls , name : str , version : Optional [str ] = None , temp_folder : Union [str , Path , None ] = None
35+ ) -> object :
3036 """Pull the model from the registry.
3137
3238 Args:
@@ -35,7 +41,9 @@ def pull_from_registry(cls, name: str, version: Optional[str] = None, temp_folde
3541 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
3642 """
3743
38- def _setup (self , name : Optional [str ] = None , temp_folder : Optional [str ] = None ) -> Tuple [str , str , str ]:
44+ def _setup (
45+ self , name : Optional [str ] = None , temp_folder : Union [str , Path , None ] = None
46+ ) -> Tuple [str , str , Union [str , Path ]]:
3947 """Parse and validate the model name and temporary folder."""
4048 if name is None :
4149 name = model_name = self .__class__ .__name__
@@ -52,7 +60,7 @@ class PickleRegistryMixin(ModelRegistryMixin):
5260 """Mixin for pickle registry integration."""
5361
5462 def push_to_registry (
55- self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Optional [str ] = None
63+ self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Union [str , Path , None ] = None
5664 ) -> None :
5765 """Push the model to the registry.
5866
@@ -67,10 +75,12 @@ def push_to_registry(
6775 pickle .dump (self , fp , protocol = pickle .HIGHEST_PROTOCOL )
6876 if version :
6977 name = f"{ name } :{ version } "
70- upload_model (name = name , model = pickle_path )
78+ upload_model_files (name = name , path = pickle_path )
7179
7280 @classmethod
73- def pull_from_registry (cls , name : str , version : Optional [str ] = None , temp_folder : Optional [str ] = None ) -> object :
81+ def pull_from_registry (
82+ cls , name : str , version : Optional [str ] = None , temp_folder : Union [str , Path , None ] = None
83+ ) -> object :
7484 """Pull the model from the registry.
7585
7686 Args:
@@ -81,7 +91,7 @@ def pull_from_registry(cls, name: str, version: Optional[str] = None, temp_folde
8191 if temp_folder is None :
8292 temp_folder = tempfile .mkdtemp ()
8393 model_registry = f"{ name } :{ version } " if version else name
84- files = download_model (name = model_registry , download_dir = temp_folder )
94+ files = download_model_files (name = model_registry , download_dir = temp_folder )
8595 pkl_files = [f for f in files if f .endswith (".pkl" )]
8696 if not pkl_files :
8797 raise RuntimeError (f"No pickle file found for model: { model_registry } with { files } " )
@@ -98,8 +108,27 @@ def pull_from_registry(cls, name: str, version: Optional[str] = None, temp_folde
98108class PyTorchRegistryMixin (ModelRegistryMixin ):
99109 """Mixin for PyTorch model registry integration."""
100110
111+ def __new__ (cls , * args : Any , ** kwargs : Any ) -> "torch.nn.Module" :
112+ """Create a new instance of the class without calling __init__."""
113+ instance = super ().__new__ (cls )
114+
115+ # Get __init__ signature excluding 'self'
116+ init_sig = inspect .signature (cls .__init__ )
117+ params = list (init_sig .parameters .values ())[1 :] # Skip self
118+
119+ # Create temporary signature for binding
120+ temp_sig = init_sig .replace (parameters = params )
121+
122+ # Bind and apply defaults
123+ bound_args = temp_sig .bind (* args , ** kwargs )
124+ bound_args .apply_defaults ()
125+
126+ # Store unified kwargs
127+ instance .__init_kwargs = bound_args .arguments
128+ return instance
129+
101130 def push_to_registry (
102- self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Optional [str ] = None
131+ self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Union [str , Path , None ] = None
103132 ) -> None :
104133 """Push the model to the registry.
105134
@@ -110,22 +139,43 @@ def push_to_registry(
110139 """
111140 import torch
112141
142+ # Ensure that the model is in evaluation mode
113143 if not isinstance (self , torch .nn .Module ):
114144 raise TypeError (f"The model must be a PyTorch `nn.Module` but got: { type (self )} " )
115145
116146 name , model_name , temp_folder = self ._setup (name , temp_folder )
117- torch_path = Path (temp_folder ) / f"{ model_name } .pth"
118- torch .save (self .state_dict (), torch_path )
119- # todo: dump also object creation arguments so we can dump it and load with model for object instantiation
147+
148+ if self .__init_kwargs :
149+ try :
150+ # Save the model arguments to a JSON file
151+ init_kwargs_path = Path (temp_folder ) / f"{ model_name } __init_kwargs.json"
152+ with open (init_kwargs_path , "w" ) as fp :
153+ json .dump (self .__init_kwargs , fp )
154+ except Exception as e :
155+ raise RuntimeError (
156+ f"Failed to save model arguments: { e } ."
157+ " Ensure the model's arguments are JSON serializable or use `PickleRegistryMixin`."
158+ ) from e
159+ elif not hasattr (self , "__init_kwargs" ):
160+ rank_zero_warn (
161+ "The child class is missing `__init_kwargs`."
162+ " Ensure `PyTorchRegistryMixin` is first in the inheritance order"
163+ " or call `PyTorchRegistryMixin.__init__` explicitly in the child class."
164+ )
165+
166+ torch_state_dict_path = Path (temp_folder ) / f"{ model_name } .pth"
167+ torch .save (self .state_dict (), torch_state_dict_path )
120168 model_registry = f"{ name } :{ version } " if version else name
121- upload_model (name = model_registry , model = torch_path )
169+ # todo: consider creating another temp folder and copying these two files
170+ # todo: updating SDK to support uploading just specific files
171+ upload_model_files (name = model_registry , path = temp_folder )
122172
123173 @classmethod
124174 def pull_from_registry (
125175 cls ,
126176 name : str ,
127177 version : Optional [str ] = None ,
128- temp_folder : Optional [str ] = None ,
178+ temp_folder : Union [str , Path , None ] = None ,
129179 torch_load_kwargs : Optional [dict ] = None ,
130180 ) -> "torch.nn.Module" :
131181 """Pull the model from the registry.
@@ -141,7 +191,8 @@ def pull_from_registry(
141191 if temp_folder is None :
142192 temp_folder = tempfile .mkdtemp ()
143193 model_registry = f"{ name } :{ version } " if version else name
144- files = download_model (name = model_registry , download_dir = temp_folder )
194+ files = download_model_files (name = model_registry , download_dir = temp_folder )
195+
145196 torch_files = [f for f in files if f .endswith (".pth" )]
146197 if not torch_files :
147198 raise RuntimeError (f"No torch file found for model: { model_registry } with { files } " )
@@ -153,8 +204,18 @@ def pull_from_registry(
153204 warnings .simplefilter ("ignore" , category = FutureWarning )
154205 state_dict = torch .load (state_dict_path , ** (torch_load_kwargs if torch_load_kwargs else {}))
155206
207+ init_files = [fp for fp in files if fp .endswith ("__init_kwargs.json" )]
208+ if not init_files :
209+ init_kwargs = {}
210+ elif len (init_files ) > 1 :
211+ raise RuntimeError (f"Multiple init files found for model: { model_registry } with { init_files } " )
212+ else :
213+ init_kwargs_path = Path (temp_folder ) / init_files [0 ]
214+ with open (init_kwargs_path ) as fp :
215+ init_kwargs = json .load (fp )
216+
156217 # Create a new model instance without calling __init__
157- instance = cls () # todo: we need to add args used when created dumped model
218+ instance = cls (** init_kwargs )
158219 if not isinstance (instance , torch .nn .Module ):
159220 raise TypeError (f"The model must be a PyTorch `nn.Module` but got: { type (instance )} " )
160221 # Now load the state dict on the instance
0 commit comments