@@ -15,66 +15,71 @@ class ModelRegistryMixin(ABC):
1515 """Mixin for model registry integration."""
1616
1717 def push_to_registry (
18- self , model_name : Optional [str ] = None , model_version : Optional [str ] = None , temp_folder : Optional [str ] = None
18+ self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Optional [str ] = None
1919 ) -> None :
2020 """Push the model to the registry.
2121
2222 Args:
23- model_name : The name of the model. If not use the class name.
24- model_version : The version of the model. If None, the latest version is used.
23+ name : The name of the model. If not use the class name.
24+ version : The version of the model. If None, the latest version is used.
2525 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
2626 """
2727
2828 @classmethod
2929 def pull_from_registry (
30- cls , model_name : str , model_version : Optional [str ] = None , temp_folder : Optional [str ] = None
30+ cls , name : str , version : Optional [str ] = None , temp_folder : Optional [str ] = None
3131 ) -> object :
3232 """Pull the model from the registry.
3333
3434 Args:
35- model_name : The name of the model.
36- model_version : The version of the model. If None, the latest version is used.
35+ name : The name of the model.
36+ version : The version of the model. If None, the latest version is used.
3737 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
3838 """
3939
4040
41- class PickleRegistryMixin (ABC ):
41+ class PickleRegistryMixin (ModelRegistryMixin ):
4242 """Mixin for pickle registry integration."""
4343
4444 def push_to_registry (
45- self , model_name : Optional [str ] = None , model_version : Optional [str ] = None , temp_folder : Optional [str ] = None
45+ self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Optional [str ] = None
4646 ) -> None :
4747 """Push the model to the registry.
4848
4949 Args:
50- model_name : The name of the model. If not use the class name.
51- model_version : The version of the model. If None, the latest version is used.
50+ name : The name of the model. If not use the class name.
51+ version : The version of the model. If None, the latest version is used.
5252 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
5353 """
54- if model_name is None :
55- model_name = self .__class__ .__name__
54+ if ":" in name :
55+ raise ValueError (f"Invalid model name: '{ name } '. It should not contain ':' associated with version." )
56+ if name is None :
57+ name = model_name = self .__class__ .__name__
58+ else :
59+ model_name = name .split ("/" )[- 1 ]
5660 if temp_folder is None :
57- temp_folder = tempfile .gettempdir ()
61+ temp_folder = tempfile .mkdtemp ()
5862 pickle_path = Path (temp_folder ) / f"{ model_name } .pkl"
5963 with open (pickle_path , "wb" ) as fp :
6064 pickle .dump (self , fp , protocol = pickle .HIGHEST_PROTOCOL )
61- model_registry = f"{ model_name } :{ model_version } " if model_version else model_name
62- upload_model (name = model_registry , model = pickle_path )
65+ if version :
66+ name = f"{ name } :{ version } "
67+ upload_model (name = name , model = pickle_path )
6368
6469 @classmethod
6570 def pull_from_registry (
66- cls , model_name : str , model_version : Optional [str ] = None , temp_folder : Optional [str ] = None
71+ cls , name : str , version : Optional [str ] = None , temp_folder : Optional [str ] = None
6772 ) -> object :
6873 """Pull the model from the registry.
6974
7075 Args:
71- model_name : The name of the model.
72- model_version : The version of the model. If None, the latest version is used.
76+ name : The name of the model.
77+ version : The version of the model. If None, the latest version is used.
7378 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
7479 """
7580 if temp_folder is None :
76- temp_folder = tempfile .gettempdir ()
77- model_registry = f"{ model_name } :{ model_version } " if model_version else model_name
81+ temp_folder = tempfile .mkdtemp ()
82+ model_registry = f"{ name } :{ version } " if version else name
7883 files = download_model (name = model_registry , download_dir = temp_folder )
7984 pkl_files = [f for f in files if f .endswith (".pkl" )]
8085 if not pkl_files :
@@ -89,7 +94,7 @@ def pull_from_registry(
8994 return obj
9095
9196
92- class PyTorchRegistryMixin (ABC ):
97+ class PyTorchRegistryMixin (ModelRegistryMixin ):
9398 """Mixin for PyTorch model registry integration."""
9499
95100 def __post_init__ (self ) -> None :
@@ -101,51 +106,51 @@ def __post_init__(self) -> None:
101106 raise TypeError (f"The model must be a PyTorch `nn.Module` but got: { type (self )} " )
102107
103108 def push_to_registry (
104- self , model_name : Optional [str ] = None , model_version : Optional [str ] = None , temp_folder : Optional [str ] = None
109+ self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Optional [str ] = None
105110 ) -> None :
106111 """Push the model to the registry.
107112
108113 Args:
109- model_name : The name of the model. If not use the class name.
110- model_version : The version of the model. If None, the latest version is used.
114+ name : The name of the model. If not use the class name.
115+ version : The version of the model. If None, the latest version is used.
111116 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
112117 """
113118 import torch
114119
115120 if not isinstance (self , torch .nn .Module ):
116121 raise TypeError (f"The model must be a PyTorch `nn.Module` but got: { type (self )} " )
117122
118- if model_name is None :
119- model_name = self .__class__ .__name__
123+ if name is None :
124+ name = self .__class__ .__name__
120125 if temp_folder is None :
121- temp_folder = tempfile .gettempdir ()
122- torch_path = Path (temp_folder ) / f"{ model_name } .pth"
126+ temp_folder = tempfile .mkdtemp ()
127+ torch_path = Path (temp_folder ) / f"{ name } .pth"
123128 torch .save (self .state_dict (), torch_path )
124129 # todo: dump also object creation arguments so we can dump it and load with model for object instantiation
125- model_registry = f"{ model_name } :{ model_version } " if model_version else model_name
130+ model_registry = f"{ name } :{ version } " if version else name
126131 upload_model (name = model_registry , model = torch_path )
127132
128133 @classmethod
129134 def pull_from_registry (
130135 cls ,
131- model_name : str ,
132- model_version : Optional [str ] = None ,
136+ name : str ,
137+ version : Optional [str ] = None ,
133138 temp_folder : Optional [str ] = None ,
134139 torch_load_kwargs : Optional [dict ] = None ,
135140 ) -> "torch.nn.Module" :
136141 """Pull the model from the registry.
137142
138143 Args:
139- model_name : The name of the model.
140- model_version : The version of the model. If None, the latest version is used.
144+ name : The name of the model.
145+ version : The version of the model. If None, the latest version is used.
141146 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
142147 torch_load_kwargs: Additional arguments to pass to `torch.load()`.
143148 """
144149 import torch
145150
146151 if temp_folder is None :
147- temp_folder = tempfile .gettempdir ()
148- model_registry = f"{ model_name } :{ model_version } " if model_version else model_name
152+ temp_folder = tempfile .mkdtemp ()
153+ model_registry = f"{ name } :{ version } " if version else name
149154 files = download_model (name = model_registry , download_dir = temp_folder )
150155 torch_files = [f for f in files if f .endswith (".pth" )]
151156 if not torch_files :
0 commit comments