@@ -11,7 +11,6 @@
 from peft import PeftModel
 from pydantic import BaseModel
 from sklearn.base import BaseEstimator
-from torch import nn
 from transformers import (  # type: ignore[attr-defined]
     AutoModelForSequenceClassification,
     AutoTokenizer,
@@ -21,15 +20,22 @@
 )
 
 from autointent import Embedder, Ranker, VectorIndex
-from autointent._wrappers import BaseTorchModule
+from autointent._wrappers import BaseTorchModuleWithVocab
 from autointent.configs import CrossEncoderConfig, EmbedderConfig
 from autointent.context.optimization_info import Artifact
 from autointent.schemas import TagsList
 
 ModuleSimpleAttributes = None | str | int | float | bool | list  # type: ignore[type-arg]
 
 ModuleAttributes: TypeAlias = (
-    ModuleSimpleAttributes | TagsList | np.ndarray | Embedder | VectorIndex | BaseEstimator | Ranker | nn.Module  # type: ignore[type-arg]
+    ModuleSimpleAttributes
+    | TagsList
+    | np.ndarray  # type: ignore[type-arg]
+    | Embedder
+    | VectorIndex
+    | BaseEstimator
+    | Ranker
+    | BaseTorchModuleWithVocab
 )
 
 logger = logging.getLogger(__name__)
@@ -75,14 +81,21 @@ def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
             subdir.mkdir(parents=True, exist_ok=exists_ok)
 
     @staticmethod
-    def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]] | None = None) -> None:  # noqa: ANN401, C901, PLR0912, PLR0915
+    def dump(  # noqa: C901, PLR0912, PLR0915
+        obj: Any,  # noqa: ANN401
+        path: Path,
+        exists_ok: bool = False,
+        exclude: list[type[Any]] | None = None,
+        raise_errors: bool = False,
+    ) -> None:
         """Dump module's attributes to the filesystem.
 
         Args:
             obj: Object to dump
             path: Path to dump to
             exists_ok: If True, do not raise an error if the directory already exists
             exclude: List of types to exclude from dumping
+            raise_errors: Whether to re-raise dumping errors or only log them
         """
         attrs: dict[str, ModuleAttributes] = vars(obj)
         simple_attrs = {}
@@ -119,25 +132,29 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
                 except Exception as e:
                     msg = f"Error dumping pydantic model {key}: {e}"
                     logging.exception(msg)
+                    if raise_errors:
+                        raise
             elif isinstance(val, PeftModel):
                 # dumping peft models is a nightmare...
                 # this might break with new versions of peft
                 try:
                     if val._is_prompt_learning:  # noqa: SLF001
                         # strategy to save prompt learning models: save prompt encoder and bert classifier separately
                         model_path = path / Dumper.ptuning_models / key
-                        model_path.mkdir(parents=True, exist_ok=True)
+                        model_path.mkdir(parents=True, exist_ok=exists_ok)
                         val.save_pretrained(str(model_path / "peft"))
                         val.base_model.save_pretrained(model_path / "base_model")  # type: ignore[attr-defined]
                     else:
                         # strategy to save lora models: merge adapters and save as usual hugging face model
                         model_path = path / Dumper.hf_models / key
-                        model_path.mkdir(parents=True, exist_ok=True)
+                        model_path.mkdir(parents=True, exist_ok=exists_ok)
                         merged_model: PreTrainedModel = val.merge_and_unload()
                         merged_model.save_pretrained(model_path)  # type: ignore[attr-defined]
                 except Exception as e:
                     msg = f"Error dumping PeftModel {key}: {e}"
                     logger.exception(msg)
+                    if raise_errors:
+                        raise
             elif isinstance(val, PreTrainedModel):
                 model_path = path / Dumper.hf_models / key
                 model_path.mkdir(parents=True, exist_ok=True)
@@ -146,7 +163,9 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
                 except Exception as e:
                     msg = f"Error dumping HF model {key}: {e}"
                     logger.exception(msg)
-            elif isinstance(val, BaseTorchModule):
+                    if raise_errors:
+                        raise
+            elif isinstance(val, BaseTorchModuleWithVocab):
                 model_path = path / Dumper.torch_models / key
                 model_path.mkdir(parents=True, exist_ok=True)
                 try:
@@ -160,6 +179,8 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
                 except Exception as e:
                     msg = f"Error dumping torch model {key}: {e}"
                     logger.exception(msg)
+                    if raise_errors:
+                        raise
             elif isinstance(val, PreTrainedTokenizer | PreTrainedTokenizerFast):
                 tokenizer_path = path / Dumper.hf_tokenizers / key
                 tokenizer_path.mkdir(parents=True, exist_ok=True)
@@ -168,11 +189,15 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
                 except Exception as e:
                     msg = f"Error dumping HF tokenizer {key}: {e}"
                     logger.exception(msg)
+                    if raise_errors:
+                        raise
             elif isinstance(val, CatBoostClassifier):
                 val.save_model(str(path / Dumper.catboost_models / key), format="cbm")
             else:
                 msg = f"Attribute {key} of type {type(val)} cannot be dumped to file system."
                 logger.error(msg)
+                if raise_errors:
+                    raise TypeError(msg)
 
         with (path / Dumper.simple_attrs).open("w", encoding="utf-8") as file:
             json.dump(simple_attrs, file, ensure_ascii=False, indent=4)
@@ -185,6 +210,7 @@ def load( # noqa: C901, PLR0912, PLR0915
         path: Path,
         embedder_config: EmbedderConfig | None = None,
         cross_encoder_config: CrossEncoderConfig | None = None,
+        raise_errors: bool = False,
     ) -> None:
         """Load attributes from file system."""
         tags: dict[str, Any] = {}
@@ -250,7 +276,8 @@ def load( # noqa: C901, PLR0912, PLR0915
                     except Exception as e:
                         msg = f"Error loading Pydantic model from {model_dir}: {e}"
                         logger.exception(msg)
-                        continue
+                        if raise_errors:
+                            raise
             elif child.name == Dumper.ptuning_models:
                 for model_dir in child.iterdir():
                     try:
@@ -259,20 +286,26 @@ def load( # noqa: C901, PLR0912, PLR0915
                     except Exception as e:  # noqa: PERF203
                         msg = f"Error loading PeftModel {model_dir.name}: {e}"
                         logger.exception(msg)
+                        if raise_errors:
+                            raise
             elif child.name == Dumper.hf_models:
                 for model_dir in child.iterdir():
                     try:
                         hf_models[model_dir.name] = AutoModelForSequenceClassification.from_pretrained(model_dir)  # type: ignore[no-untyped-call]
                     except Exception as e:  # noqa: PERF203
                         msg = f"Error loading HF model {model_dir.name}: {e}"
                         logger.exception(msg)
+                        if raise_errors:
+                            raise
             elif child.name == Dumper.hf_tokenizers:
                 for tokenizer_dir in child.iterdir():
                     try:
                         hf_tokenizers[tokenizer_dir.name] = AutoTokenizer.from_pretrained(tokenizer_dir)
                     except Exception as e:  # noqa: PERF203
                         msg = f"Error loading HF tokenizer {tokenizer_dir.name}: {e}"
                         logger.exception(msg)
+                        if raise_errors:
+                            raise
             elif child.name == Dumper.catboost_models:
                 for model_file in child.iterdir():
                     try:
@@ -288,15 +321,19 @@ def load( # noqa: C901, PLR0912, PLR0915
                         with (model_dir / "class_info.json").open("r") as f:
                             class_info = json.load(f)
                         module = importlib.import_module(class_info["module"])
-                        model_class: BaseTorchModule = getattr(module, class_info["name"])
+                        model_class: BaseTorchModuleWithVocab = getattr(module, class_info["name"])
                         model = model_class.load(model_dir)
                         torch_models[model_dir.name] = model
                     except Exception as e:
                         msg = f"Error loading torch model {model_dir.name}: {e}"
                         logger.exception(msg)
+                        if raise_errors:
+                            raise
             else:
                 msg = f"Found unexpected child {child}"
                 logger.error(msg)
+                if raise_errors:
+                    raise ValueError(msg)
 
         obj.__dict__.update(
             tags