|
| 1 | +import pickle |
| 2 | +import tempfile |
| 3 | +from abc import ABC |
| 4 | +from pathlib import Path |
| 5 | +from typing import Optional |
| 6 | + |
| 7 | +from litmodels import download_model, upload_model |
| 8 | + |
| 9 | + |
| 10 | +class ModelRegistryMixin(ABC): |
| 11 | + """Mixin for model registry integration.""" |
| 12 | + |
| 13 | + def push_to_registry( |
| 14 | + self, model_name: Optional[str] = None, model_version: Optional[str] = None, temp_folder: Optional[str] = None |
| 15 | + ) -> None: |
| 16 | + """Push the model to the registry. |
| 17 | +
|
| 18 | + Args: |
| 19 | + model_name: The name of the model. If not use the class name. |
| 20 | + model_version: The version of the model. If None, the latest version is used. |
| 21 | + temp_folder: The temporary folder to save the model. If None, a default temporary folder is used. |
| 22 | + """ |
| 23 | + |
| 24 | + @classmethod |
| 25 | + def pull_from_registry( |
| 26 | + cls, model_name: str, model_version: Optional[str] = None, temp_folder: Optional[str] = None |
| 27 | + ) -> object: |
| 28 | + """Pull the model from the registry. |
| 29 | +
|
| 30 | + Args: |
| 31 | + model_name: The name of the model. |
| 32 | + model_version: The version of the model. If None, the latest version is used. |
| 33 | + temp_folder: The temporary folder to save the model. If None, a default temporary folder is used. |
| 34 | + """ |
| 35 | + |
| 36 | + |
| 37 | +class PickleRegistryMixin(ABC): |
| 38 | + """Mixin for pickle registry integration.""" |
| 39 | + |
| 40 | + def push_to_registry( |
| 41 | + self, model_name: Optional[str] = None, model_version: Optional[str] = None, temp_folder: Optional[str] = None |
| 42 | + ) -> None: |
| 43 | + """Push the model to the registry. |
| 44 | +
|
| 45 | + Args: |
| 46 | + model_name: The name of the model. If not use the class name. |
| 47 | + model_version: The version of the model. If None, the latest version is used. |
| 48 | + temp_folder: The temporary folder to save the model. If None, a default temporary folder is used. |
| 49 | + """ |
| 50 | + if model_name is None: |
| 51 | + model_name = self.__class__.__name__ |
| 52 | + if temp_folder is None: |
| 53 | + temp_folder = tempfile.gettempdir() |
| 54 | + pickle_path = Path(temp_folder) / f"{model_name}.pkl" |
| 55 | + with open(pickle_path, "wb") as fp: |
| 56 | + pickle.dump(self, fp, protocol=pickle.HIGHEST_PROTOCOL) |
| 57 | + model_registry = f"{model_name}:{model_version}" if model_version else model_name |
| 58 | + upload_model(name=model_registry, model=pickle_path) |
| 59 | + |
| 60 | + @classmethod |
| 61 | + def pull_from_registry( |
| 62 | + cls, model_name: str, model_version: Optional[str] = None, temp_folder: Optional[str] = None |
| 63 | + ) -> object: |
| 64 | + """Pull the model from the registry. |
| 65 | +
|
| 66 | + Args: |
| 67 | + model_name: The name of the model. |
| 68 | + model_version: The version of the model. If None, the latest version is used. |
| 69 | + temp_folder: The temporary folder to save the model. If None, a default temporary folder is used. |
| 70 | + """ |
| 71 | + if temp_folder is None: |
| 72 | + temp_folder = tempfile.gettempdir() |
| 73 | + model_registry = f"{model_name}:{model_version}" if model_version else model_name |
| 74 | + files = download_model(name=model_registry, download_dir=temp_folder) |
| 75 | + pkl_files = [f for f in files if f.endswith(".pkl")] |
| 76 | + if not pkl_files: |
| 77 | + raise RuntimeError(f"No pickle file found for model: {model_registry} with {files}") |
| 78 | + if len(pkl_files) > 1: |
| 79 | + raise RuntimeError(f"Multiple pickle files found for model: {model_registry} with {pkl_files}") |
| 80 | + pkl_path = Path(temp_folder) / pkl_files[0] |
| 81 | + with open(pkl_path, "rb") as fp: |
| 82 | + obj = pickle.load(fp) |
| 83 | + if not isinstance(obj, cls): |
| 84 | + raise RuntimeError(f"Unpickled object is not of type {cls.__name__}: {type(obj)}") |
| 85 | + return obj |
0 commit comments