33import warnings
44from abc import ABC
55from pathlib import Path
6- from typing import TYPE_CHECKING , Optional
6+ from typing import TYPE_CHECKING , Optional , Tuple
77
88from litmodels import download_model , upload_model
99
@@ -35,6 +35,18 @@ def pull_from_registry(cls, name: str, version: Optional[str] = None, temp_folde
3535 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
3636 """
3737
38+ def _setup (self , name : Optional [str ] = None , temp_folder : Optional [str ] = None ) -> Tuple [str , str , str ]:
39+ """Parse and validate the model name and temporary folder."""
40+ if name is None :
41+ name = model_name = self .__class__ .__name__
42+ elif ":" in name :
43+ raise ValueError (f"Invalid model name: '{ name } '. It should not contain ':' associated with version." )
44+ else :
45+ model_name = name .split ("/" )[- 1 ]
46+ if temp_folder is None :
47+ temp_folder = tempfile .mkdtemp ()
48+ return name , model_name , temp_folder
49+
3850
3951class PickleRegistryMixin (ModelRegistryMixin ):
4052 """Mixin for pickle registry integration."""
@@ -49,14 +61,7 @@ def push_to_registry(
4961 version: The version of the model. If None, the latest version is used.
5062 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
5163 """
52- if name is None :
53- name = model_name = self .__class__ .__name__
54- elif ":" in name :
55- raise ValueError (f"Invalid model name: '{ name } '. It should not contain ':' associated with version." )
56- else :
57- model_name = name .split ("/" )[- 1 ]
58- if temp_folder is None :
59- temp_folder = tempfile .mkdtemp ()
64+ name , model_name , temp_folder = self ._setup (name , temp_folder )
6065 pickle_path = Path (temp_folder ) / f"{ model_name } .pkl"
6166 with open (pickle_path , "wb" ) as fp :
6267 pickle .dump (self , fp , protocol = pickle .HIGHEST_PROTOCOL )
@@ -93,14 +98,6 @@ def pull_from_registry(cls, name: str, version: Optional[str] = None, temp_folde
9398class PyTorchRegistryMixin (ModelRegistryMixin ):
9499 """Mixin for PyTorch model registry integration."""
95100
96- def __post_init__ (self ) -> None :
97- """Post-initialization method to set up the model."""
98- import torch
99-
100- # Ensure that the model is in evaluation mode
101- if not isinstance (self , torch .nn .Module ):
102- raise TypeError (f"The model must be a PyTorch `nn.Module` but got: { type (self )} " )
103-
104101 def push_to_registry (
105102 self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Optional [str ] = None
106103 ) -> None :
@@ -116,11 +113,8 @@ def push_to_registry(
116113 if not isinstance (self , torch .nn .Module ):
117114 raise TypeError (f"The model must be a PyTorch `nn.Module` but got: { type (self )} " )
118115
119- if name is None :
120- name = self .__class__ .__name__
121- if temp_folder is None :
122- temp_folder = tempfile .mkdtemp ()
123- torch_path = Path (temp_folder ) / f"{ name } .pth"
116+ name , model_name , temp_folder = self ._setup (name , temp_folder )
117+ torch_path = Path (temp_folder ) / f"{ model_name } .pth"
124118 torch .save (self .state_dict (), torch_path )
125119 # todo: dump also object creation arguments so we can dump it and load with model for object instantiation
126120 model_registry = f"{ name } :{ version } " if version else name
0 commit comments