@@ -15,66 +15,67 @@ class ModelRegistryMixin(ABC):
1515 """Mixin for model registry integration."""
1616
1717 def push_to_registry (
18- self , model_name : Optional [str ] = None , model_version : Optional [str ] = None , temp_folder : Optional [str ] = None
18+ self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Optional [str ] = None
1919 ) -> None :
2020 """Push the model to the registry.
2121
2222 Args:
23- model_name : The name of the model. If not use the class name.
24- model_version : The version of the model. If None, the latest version is used.
23+ name : The name of the model. If not use the class name.
24+ version : The version of the model. If None, the latest version is used.
2525 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
2626 """
2727
2828 @classmethod
29- def pull_from_registry (
30- cls , model_name : str , model_version : Optional [str ] = None , temp_folder : Optional [str ] = None
31- ) -> object :
29+ def pull_from_registry (cls , name : str , version : Optional [str ] = None , temp_folder : Optional [str ] = None ) -> object :
3230 """Pull the model from the registry.
3331
3432 Args:
35- model_name : The name of the model.
36- model_version : The version of the model. If None, the latest version is used.
33+ name : The name of the model.
34+ version : The version of the model. If None, the latest version is used.
3735 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
3836 """
3937
4038
41- class PickleRegistryMixin (ABC ):
39+ class PickleRegistryMixin (ModelRegistryMixin ):
4240 """Mixin for pickle registry integration."""
4341
4442 def push_to_registry (
45- self , model_name : Optional [str ] = None , model_version : Optional [str ] = None , temp_folder : Optional [str ] = None
43+ self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Optional [str ] = None
4644 ) -> None :
4745 """Push the model to the registry.
4846
4947 Args:
50- model_name : The name of the model. If not use the class name.
51- model_version : The version of the model. If None, the latest version is used.
48+ name : The name of the model. If not use the class name.
49+ version : The version of the model. If None, the latest version is used.
5250 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
5351 """
54- if model_name is None :
55- model_name = self .__class__ .__name__
52+ if name is None :
53+ name = model_name = self .__class__ .__name__
54+ elif ":" in name :
55+ raise ValueError (f"Invalid model name: '{ name } '. It should not contain ':' associated with version." )
56+ else :
57+ model_name = name .split ("/" )[- 1 ]
5658 if temp_folder is None :
57- temp_folder = tempfile .gettempdir ()
59+ temp_folder = tempfile .mkdtemp ()
5860 pickle_path = Path (temp_folder ) / f"{ model_name } .pkl"
5961 with open (pickle_path , "wb" ) as fp :
6062 pickle .dump (self , fp , protocol = pickle .HIGHEST_PROTOCOL )
61- model_registry = f"{ model_name } :{ model_version } " if model_version else model_name
62- upload_model (name = model_registry , model = pickle_path )
63+ if version :
64+ name = f"{ name } :{ version } "
65+ upload_model (name = name , model = pickle_path )
6366
6467 @classmethod
65- def pull_from_registry (
66- cls , model_name : str , model_version : Optional [str ] = None , temp_folder : Optional [str ] = None
67- ) -> object :
68+ def pull_from_registry (cls , name : str , version : Optional [str ] = None , temp_folder : Optional [str ] = None ) -> object :
6869 """Pull the model from the registry.
6970
7071 Args:
71- model_name : The name of the model.
72- model_version : The version of the model. If None, the latest version is used.
72+ name : The name of the model.
73+ version : The version of the model. If None, the latest version is used.
7374 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
7475 """
7576 if temp_folder is None :
76- temp_folder = tempfile .gettempdir ()
77- model_registry = f"{ model_name } :{ model_version } " if model_version else model_name
77+ temp_folder = tempfile .mkdtemp ()
78+ model_registry = f"{ name } :{ version } " if version else name
7879 files = download_model (name = model_registry , download_dir = temp_folder )
7980 pkl_files = [f for f in files if f .endswith (".pkl" )]
8081 if not pkl_files :
@@ -89,7 +90,7 @@ def pull_from_registry(
8990 return obj
9091
9192
92- class PyTorchRegistryMixin (ABC ):
93+ class PyTorchRegistryMixin (ModelRegistryMixin ):
9394 """Mixin for PyTorch model registry integration."""
9495
9596 def __post_init__ (self ) -> None :
@@ -101,51 +102,51 @@ def __post_init__(self) -> None:
101102 raise TypeError (f"The model must be a PyTorch `nn.Module` but got: { type (self )} " )
102103
103104 def push_to_registry (
104- self , model_name : Optional [str ] = None , model_version : Optional [str ] = None , temp_folder : Optional [str ] = None
105+ self , name : Optional [str ] = None , version : Optional [str ] = None , temp_folder : Optional [str ] = None
105106 ) -> None :
106107 """Push the model to the registry.
107108
108109 Args:
109- model_name : The name of the model. If not use the class name.
110- model_version : The version of the model. If None, the latest version is used.
110+ name : The name of the model. If not use the class name.
111+ version : The version of the model. If None, the latest version is used.
111112 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
112113 """
113114 import torch
114115
115116 if not isinstance (self , torch .nn .Module ):
116117 raise TypeError (f"The model must be a PyTorch `nn.Module` but got: { type (self )} " )
117118
118- if model_name is None :
119- model_name = self .__class__ .__name__
119+ if name is None :
120+ name = self .__class__ .__name__
120121 if temp_folder is None :
121- temp_folder = tempfile .gettempdir ()
122- torch_path = Path (temp_folder ) / f"{ model_name } .pth"
122+ temp_folder = tempfile .mkdtemp ()
123+ torch_path = Path (temp_folder ) / f"{ name } .pth"
123124 torch .save (self .state_dict (), torch_path )
124125 # todo: dump also object creation arguments so we can dump it and load with model for object instantiation
125- model_registry = f"{ model_name } :{ model_version } " if model_version else model_name
126+ model_registry = f"{ name } :{ version } " if version else name
126127 upload_model (name = model_registry , model = torch_path )
127128
128129 @classmethod
129130 def pull_from_registry (
130131 cls ,
131- model_name : str ,
132- model_version : Optional [str ] = None ,
132+ name : str ,
133+ version : Optional [str ] = None ,
133134 temp_folder : Optional [str ] = None ,
134135 torch_load_kwargs : Optional [dict ] = None ,
135136 ) -> "torch.nn.Module" :
136137 """Pull the model from the registry.
137138
138139 Args:
139- model_name : The name of the model.
140- model_version : The version of the model. If None, the latest version is used.
140+ name : The name of the model.
141+ version : The version of the model. If None, the latest version is used.
141142 temp_folder: The temporary folder to save the model. If None, a default temporary folder is used.
142143 torch_load_kwargs: Additional arguments to pass to `torch.load()`.
143144 """
144145 import torch
145146
146147 if temp_folder is None :
147- temp_folder = tempfile .gettempdir ()
148- model_registry = f"{ model_name } :{ model_version } " if model_version else model_name
148+ temp_folder = tempfile .mkdtemp ()
149+ model_registry = f"{ name } :{ version } " if version else name
149150 files = download_model (name = model_registry , download_dir = temp_folder )
150151 torch_files = [f for f in files if f .endswith (".pth" )]
151152 if not torch_files :
0 commit comments