import joblib
import numpy as np
import numpy.typing as npt
+from peft import PeftModel
from pydantic import BaseModel
from sklearn.base import BaseEstimator
+from transformers import (  # type: ignore[attr-defined]
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)

from autointent import Embedder, Ranker, VectorIndex
from autointent.configs import CrossEncoderConfig, EmbedderConfig
@@ -34,6 +42,7 @@ class Dumper:
    pydantic_models: str = "pydantic"
    hf_models = "hf_models"
    hf_tokenizers = "hf_tokenizers"
+    ptuning_models = "ptuning_models"

    @staticmethod
    def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
@@ -52,6 +61,7 @@ def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
            path / Dumper.pydantic_models,
            path / Dumper.hf_models,
            path / Dumper.hf_tokenizers,
+            path / Dumper.ptuning_models,
        ]
        for subdir in subdirectories:
            subdir.mkdir(parents=True, exist_ok=exists_ok)
@@ -101,25 +111,38 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
                except Exception as e:
                    msg = f"Error dumping pydantic model {key}: {e}"
                    logging.exception(msg)
-            elif (key == "_model" or "model" in key.lower()) and hasattr(val, "save_pretrained"):
+            elif isinstance(val, PeftModel):
+                # dumping peft models is a nightmare...
+                # this might break with new versions of peft
+                try:
+                    if val._is_prompt_learning:  # noqa: SLF001
+                        # strategy to save prompt learning models: save prompt encoder and bert classifier separately
+                        model_path = path / Dumper.ptuning_models / key
+                        model_path.mkdir(parents=True, exist_ok=True)
+                        val.save_pretrained(str(model_path / "peft"))
+                        val.base_model.save_pretrained(model_path / "base_model")  # type: ignore[attr-defined]
+                    else:
+                        # strategy to save lora models: merge adapters and save as usual hugging face model
+                        model_path = path / Dumper.hf_models / key
+                        model_path.mkdir(parents=True, exist_ok=True)
+                        merged_model: PreTrainedModel = val.merge_and_unload()
+                        merged_model.save_pretrained(model_path)  # type: ignore[attr-defined]
+                except Exception as e:
+                    msg = f"Error dumping PeftModel {key}: {e}"
+                    logger.exception(msg)
+            elif isinstance(val, PreTrainedModel):
                model_path = path / Dumper.hf_models / key
                model_path.mkdir(parents=True, exist_ok=True)
                try:
-                    val.save_pretrained(model_path)
-                    class_info = {"module": val.__class__.__module__, "name": val.__class__.__name__}
-                    with (model_path / "class_info.json").open("w") as f:
-                        json.dump(class_info, f)
+                    val.save_pretrained(model_path)  # type: ignore[attr-defined]
                except Exception as e:
                    msg = f"Error dumping HF model {key}: {e}"
                    logger.exception(msg)
-            elif (key == "_tokenizer" or "tokenizer" in key.lower()) and hasattr(val, "save_pretrained"):
+            elif isinstance(val, PreTrainedTokenizer | PreTrainedTokenizerFast):
                tokenizer_path = path / Dumper.hf_tokenizers / key
                tokenizer_path.mkdir(parents=True, exist_ok=True)
                try:
-                    val.save_pretrained(tokenizer_path)
-                    class_info = {"module": val.__class__.__module__, "name": val.__class__.__name__}
-                    with (tokenizer_path / "class_info.json").open("w") as f:
-                        json.dump(class_info, f)
+                    val.save_pretrained(tokenizer_path)  # type: ignore[union-attr]
                except Exception as e:
                    msg = f"Error dumping HF tokenizer {key}: {e}"
                    logger.exception(msg)
@@ -202,29 +225,25 @@ def load( # noqa: C901, PLR0912, PLR0915
                        msg = f"Error loading Pydantic model from {model_dir}: {e}"
                        logger.exception(msg)
                        continue
+            elif child.name == Dumper.ptuning_models:
+                for model_dir in child.iterdir():
+                    try:
+                        model = AutoModelForSequenceClassification.from_pretrained(model_dir / "base_model")
+                        hf_models[model_dir.name] = PeftModel.from_pretrained(model, model_dir / "peft")
+                    except Exception as e:  # noqa: PERF203
+                        msg = f"Error loading PeftModel {model_dir.name}: {e}"
+                        logger.exception(msg)
            elif child.name == Dumper.hf_models:
                for model_dir in child.iterdir():
                    try:
-                        with (model_dir / "class_info.json").open("r") as f:
-                            class_info = json.load(f)
-
-                        module = __import__(class_info["module"], fromlist=[class_info["name"]])
-                        model_class = getattr(module, class_info["name"])
-
-                        hf_models[model_dir.name] = model_class.from_pretrained(model_dir)
+                        hf_models[model_dir.name] = AutoModelForSequenceClassification.from_pretrained(model_dir)
                    except Exception as e:  # noqa: PERF203
                        msg = f"Error loading HF model {model_dir.name}: {e}"
                        logger.exception(msg)
            elif child.name == Dumper.hf_tokenizers:
                for tokenizer_dir in child.iterdir():
                    try:
-                        with (tokenizer_dir / "class_info.json").open("r") as f:
-                            class_info = json.load(f)
-
-                        module = __import__(class_info["module"], fromlist=[class_info["name"]])
-                        tokenizer_class = getattr(module, class_info["name"])
-
-                        hf_tokenizers[tokenizer_dir.name] = tokenizer_class.from_pretrained(tokenizer_dir)
+                        hf_tokenizers[tokenizer_dir.name] = AutoTokenizer.from_pretrained(tokenizer_dir)
                    except Exception as e:  # noqa: PERF203
                        msg = f"Error loading HF tokenizer {tokenizer_dir.name}: {e}"
                        logger.exception(msg)
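
For reference, a minimal sketch of the round trip that the new ptuning_models branch implements, assuming the current peft and transformers APIs. The checkpoint name, label count, and dump directory are illustrative and not part of the Dumper API: the prompt encoder is saved with save_pretrained, the underlying classifier with base_model.save_pretrained, and loading restores the base model first and then re-attaches the adapter with PeftModel.from_pretrained.

from pathlib import Path

from peft import PeftModel, PromptTuningConfig, get_peft_model
from transformers import AutoModelForSequenceClassification

# illustrative base checkpoint and dump location (hypothetical, not Dumper defaults)
base = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=3)
peft_model = get_peft_model(base, PromptTuningConfig(task_type="SEQ_CLS", num_virtual_tokens=8))

model_path = Path("dump") / "ptuning_models" / "_model"
model_path.mkdir(parents=True, exist_ok=True)

# dump: prompt encoder and classifier are written separately, as in Dumper.dump
peft_model.save_pretrained(str(model_path / "peft"))
peft_model.base_model.save_pretrained(model_path / "base_model")

# load: restore the classifier, then attach the saved prompt encoder,
# as in the ptuning_models branch of Dumper.load
restored_base = AutoModelForSequenceClassification.from_pretrained(model_path / "base_model")
restored = PeftModel.from_pretrained(restored_base, model_path / "peft")

The LoRA path needs no such pairing: merge_and_unload produces a plain transformers checkpoint that the existing hf_models branch loads with AutoModelForSequenceClassification.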