|
1 | 1 | import pickle |
2 | 2 | import tempfile |
| 3 | +import warnings |
3 | 4 | from abc import ABC |
4 | 5 | from pathlib import Path |
5 | | -from typing import Optional |
| 6 | +from typing import TYPE_CHECKING, Optional |
6 | 7 |
|
7 | 8 | from litmodels import download_model, upload_model |
8 | 9 |
|
| 10 | +if TYPE_CHECKING: |
| 11 | + import torch |
| 12 | + |
9 | 13 |
|
10 | 14 | class ModelRegistryMixin(ABC): |
11 | 15 | """Mixin for model registry integration.""" |
@@ -83,3 +87,81 @@ def pull_from_registry( |
83 | 87 | if not isinstance(obj, cls): |
84 | 88 | raise RuntimeError(f"Unpickled object is not of type {cls.__name__}: {type(obj)}") |
85 | 89 | return obj |
| 90 | + |
| 91 | + |
| 92 | +class PyTorchRegistryMixin(ABC): |
| 93 | + """Mixin for PyTorch model registry integration.""" |
| 94 | + |
| 95 | + def __post_init__(self) -> None: |
| 96 | + """Post-initialization method to set up the model.""" |
| 97 | + import torch |
| 98 | + |
| 99 | + # Ensure that the model is in evaluation mode |
| 100 | + if not isinstance(self, torch.nn.Module): |
| 101 | + raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(self)}") |
| 102 | + |
| 103 | + def push_to_registry( |
| 104 | + self, model_name: Optional[str] = None, model_version: Optional[str] = None, temp_folder: Optional[str] = None |
| 105 | + ) -> None: |
| 106 | + """Push the model to the registry. |
| 107 | +
|
| 108 | + Args: |
| 109 | + model_name: The name of the model. If not use the class name. |
| 110 | + model_version: The version of the model. If None, the latest version is used. |
| 111 | + temp_folder: The temporary folder to save the model. If None, a default temporary folder is used. |
| 112 | + """ |
| 113 | + import torch |
| 114 | + |
| 115 | + if not isinstance(self, torch.nn.Module): |
| 116 | + raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(self)}") |
| 117 | + |
| 118 | + if model_name is None: |
| 119 | + model_name = self.__class__.__name__ |
| 120 | + if temp_folder is None: |
| 121 | + temp_folder = tempfile.gettempdir() |
| 122 | + torch_path = Path(temp_folder) / f"{model_name}.pth" |
| 123 | + torch.save(self.state_dict(), torch_path) |
| 124 | + # todo: dump also object creation arguments so we can dump it and load with model for object instantiation |
| 125 | + model_registry = f"{model_name}:{model_version}" if model_version else model_name |
| 126 | + upload_model(name=model_registry, model=torch_path) |
| 127 | + |
| 128 | + @classmethod |
| 129 | + def pull_from_registry( |
| 130 | + cls, |
| 131 | + model_name: str, |
| 132 | + model_version: Optional[str] = None, |
| 133 | + temp_folder: Optional[str] = None, |
| 134 | + torch_load_kwargs: Optional[dict] = None, |
| 135 | + ) -> "torch.nn.Module": |
| 136 | + """Pull the model from the registry. |
| 137 | +
|
| 138 | + Args: |
| 139 | + model_name: The name of the model. |
| 140 | + model_version: The version of the model. If None, the latest version is used. |
| 141 | + temp_folder: The temporary folder to save the model. If None, a default temporary folder is used. |
| 142 | + torch_load_kwargs: Additional arguments to pass to `torch.load()`. |
| 143 | + """ |
| 144 | + import torch |
| 145 | + |
| 146 | + if temp_folder is None: |
| 147 | + temp_folder = tempfile.gettempdir() |
| 148 | + model_registry = f"{model_name}:{model_version}" if model_version else model_name |
| 149 | + files = download_model(name=model_registry, download_dir=temp_folder) |
| 150 | + torch_files = [f for f in files if f.endswith(".pth")] |
| 151 | + if not torch_files: |
| 152 | + raise RuntimeError(f"No torch file found for model: {model_registry} with {files}") |
| 153 | + if len(torch_files) > 1: |
| 154 | + raise RuntimeError(f"Multiple torch files found for model: {model_registry} with {torch_files}") |
| 155 | + state_dict_path = Path(temp_folder) / torch_files[0] |
| 156 | + # ignore future warning about changed default |
| 157 | + with warnings.catch_warnings(): |
| 158 | + warnings.simplefilter("ignore", category=FutureWarning) |
| 159 | + state_dict = torch.load(state_dict_path, **(torch_load_kwargs if torch_load_kwargs else {})) |
| 160 | + |
| 161 | + # Create a new model instance without calling __init__ |
| 162 | + instance = cls() # todo: we need to add args used when created dumped model |
| 163 | + if not isinstance(instance, torch.nn.Module): |
| 164 | + raise TypeError(f"The model must be a PyTorch `nn.Module` but got: {type(instance)}") |
| 165 | + # Now load the state dict on the instance |
| 166 | + instance.load_state_dict(state_dict, strict=True) |
| 167 | + return instance |
0 commit comments