 from peft import PeftModel
 from pydantic import BaseModel
 from sklearn.base import BaseEstimator
+from torch import nn
 from transformers import (  # type: ignore[attr-defined]
     AutoModelForSequenceClassification,
     AutoTokenizer,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )

 from autointent import Embedder, Ranker, VectorIndex
+from autointent._wrappers import BaseTorchModule
 from autointent.configs import CrossEncoderConfig, EmbedderConfig
 from autointent.context.optimization_info import Artifact
 from autointent.schemas import TagsList

 ModuleSimpleAttributes = None | str | int | float | bool | list  # type: ignore[type-arg]

 ModuleAttributes: TypeAlias = (
-    ModuleSimpleAttributes | TagsList | np.ndarray | Embedder | VectorIndex | BaseEstimator | Ranker  # type: ignore[type-arg]
+    ModuleSimpleAttributes | TagsList | np.ndarray | Embedder | VectorIndex | BaseEstimator | Ranker | nn.Module  # type: ignore[type-arg]
 )

 logger = logging.getLogger(__name__)
@@ -43,6 +45,7 @@ class Dumper:
     pydantic_models: str = "pydantic"
     hf_models = "hf_models"
     hf_tokenizers = "hf_tokenizers"
+    torch_models = "torch_models"
     ptuning_models = "ptuning_models"

     @staticmethod
@@ -62,6 +65,7 @@ def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
             path / Dumper.pydantic_models,
             path / Dumper.hf_models,
             path / Dumper.hf_tokenizers,
+            path / Dumper.torch_models,
             path / Dumper.ptuning_models,
         ]
         for subdir in subdirectories:
@@ -139,6 +143,20 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
                 except Exception as e:
                     msg = f"Error dumping HF model {key}: {e}"
                     logger.exception(msg)
+            elif isinstance(val, BaseTorchModule):
+                model_path = path / Dumper.torch_models / key
+                model_path.mkdir(parents=True, exist_ok=True)
+                try:
+                    class_info = {
+                        "module": val.__class__.__module__,
+                        "name": val.__class__.__name__,
+                    }
+                    with (model_path / "class_info.json").open("w") as f:
+                        json.dump(class_info, f)
+                    val.dump(model_path)
+                except Exception as e:
+                    msg = f"Error dumping torch model {key}: {e}"
+                    logger.exception(msg)
             elif isinstance(val, PreTrainedTokenizer | PreTrainedTokenizerFast):
                 tokenizer_path = path / Dumper.hf_tokenizers / key
                 tokenizer_path.mkdir(parents=True, exist_ok=True)
@@ -174,6 +192,7 @@ def load( # noqa: C901, PLR0912, PLR0915
         pydantic_models: dict[str, Any] = {}
         hf_models: dict[str, Any] = {}
         hf_tokenizers: dict[str, Any] = {}
+        torch_models: dict[str, Any] = {}

         for child in path.iterdir():
             if child.name == Dumper.tags:
@@ -248,6 +267,18 @@ def load( # noqa: C901, PLR0912, PLR0915
                     except Exception as e:  # noqa: PERF203
                         msg = f"Error loading HF tokenizer {tokenizer_dir.name}: {e}"
                         logger.exception(msg)
+            elif child.name == Dumper.torch_models:
+                try:
+                    for model_dir in child.iterdir():
+                        with (model_dir / "class_info.json").open("r") as f:
+                            class_info = json.load(f)
+                        module = importlib.import_module(class_info["module"])
+                        model_class: type[BaseTorchModule] = getattr(module, class_info["name"])
+                        model = model_class.load(model_dir)
+                        torch_models[model_dir.name] = model
+                except Exception as e:
+                    msg = f"Error loading torch model {model_dir.name}: {e}"
+                    logger.exception(msg)
             else:
                 msg = f"Found unexpected child {child}"
                 logger.error(msg)
@@ -263,4 +294,5 @@ def load( # noqa: C901, PLR0912, PLR0915
             | pydantic_models
             | hf_models
             | hf_tokenizers
+            | torch_models
         )
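
Taken together, the two new branches give `Dumper` a round-trip for any `BaseTorchModule` attribute: `dump` writes a `class_info.json` holding the dotted import path of the concrete class next to whatever the module's own `dump` persists, and `load` re-imports that class and calls its `load`. Below is a minimal sketch of a subclass compatible with this scheme, assuming only what the diff shows: an instance-level `dump(path)` plus a classmethod-style `load(path)`. The actual contract of `autointent._wrappers.BaseTorchModule` may include more, so treat `TinyHead` as a hypothetical stand-in.

```python
from pathlib import Path

import torch
from torch import nn


class TinyHead(nn.Module):  # hypothetical stand-in for a BaseTorchModule subclass
    def __init__(self, dim: int = 8, n_classes: int = 2) -> None:
        super().__init__()
        self.linear = nn.Linear(dim, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

    def dump(self, path: Path) -> None:
        # Dumper has already created `path` and written class_info.json there;
        # the module only needs to persist its own weights inside that directory.
        torch.save(self.state_dict(), path / "state_dict.pt")

    @classmethod
    def load(cls, path: Path) -> "TinyHead":
        # Re-create the module and restore its weights. A production version
        # would also need to persist and restore constructor arguments.
        model = cls()
        model.load_state_dict(torch.load(path / "state_dict.pt"))
        return model
```

One consequence of serializing the import path rather than pickling the object: the concrete class must be importable under the same `module.name` at load time, so classes defined in `__main__` or renamed between dump and load will not round-trip.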