77import joblib
88import numpy as np
99import numpy .typing as npt
10+ from peft import PeftModel
1011from pydantic import BaseModel
1112from sklearn .base import BaseEstimator
13+ from transformers import (
14+ AutoModelForSequenceClassification ,
15+ AutoTokenizer ,
16+ PreTrainedModel ,
17+ PreTrainedTokenizer ,
18+ PreTrainedTokenizerFast ,
19+ )
1220
1321from autointent import Embedder , Ranker , VectorIndex
1422from autointent .configs import CrossEncoderConfig , EmbedderConfig
@@ -34,6 +42,7 @@ class Dumper:
3442 pydantic_models : str = "pydantic"
3543 hf_models = "hf_models"
3644 hf_tokenizers = "hf_tokenizers"
45+ peft_models = "peft_models"
3746
3847 @staticmethod
3948 def make_subdirectories (path : Path , exists_ok : bool = False ) -> None :
@@ -52,6 +61,7 @@ def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
5261 path / Dumper .pydantic_models ,
5362 path / Dumper .hf_models ,
5463 path / Dumper .hf_tokenizers ,
64+ path / Dumper .peft_models ,
5565 ]
5666 for subdir in subdirectories :
5767 subdir .mkdir (parents = True , exist_ok = exists_ok )
@@ -101,25 +111,34 @@ def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]]
101111 except Exception as e :
102112 msg = f"Error dumping pydantic model { key } : { e } "
103113 logging .exception (msg )
104- elif (key == "_model" or "model" in key .lower ()) and hasattr (val , "save_pretrained" ):
114+ elif isinstance (val , PeftModel ):
115+ try :
116+ if val ._is_prompt_learning : # noqa: SLF001
117+ model_path = path / Dumper .peft_models / key
118+ model_path .mkdir (parents = True , exist_ok = True )
119+ val .save_pretrained (model_path / "peft" ) # save peft config and prompt encoder
120+ val .base_model .save_pretrained (model_path / "base_model" ) # save bert classifier
121+ else :
122+ model_path = path / Dumper .hf_models / key
123+ model_path .mkdir (parents = True , exist_ok = True )
124+ merged_model : PreTrainedModel = val .merge_and_unload ()
125+ merged_model .save_pretrained (model_path )
126+ except Exception as e :
127+ msg = f"Error dumping PeftModel { key } : { e } "
128+ logger .exception (msg )
129+ elif isinstance (val , PreTrainedModel ):
105130 model_path = path / Dumper .hf_models / key
106131 model_path .mkdir (parents = True , exist_ok = True )
107132 try :
108133 val .save_pretrained (model_path )
109- class_info = {"module" : val .__class__ .__module__ , "name" : val .__class__ .__name__ }
110- with (model_path / "class_info.json" ).open ("w" ) as f :
111- json .dump (class_info , f )
112134 except Exception as e :
113135 msg = f"Error dumping HF model { key } : { e } "
114136 logger .exception (msg )
115- elif ( key == "_tokenizer" or "tokenizer" in key . lower ()) and hasattr ( val , "save_pretrained" ):
137+ elif isinstance ( val , PreTrainedTokenizer | PreTrainedTokenizerFast ):
116138 tokenizer_path = path / Dumper .hf_tokenizers / key
117139 tokenizer_path .mkdir (parents = True , exist_ok = True )
118140 try :
119141 val .save_pretrained (tokenizer_path )
120- class_info = {"module" : val .__class__ .__module__ , "name" : val .__class__ .__name__ }
121- with (tokenizer_path / "class_info.json" ).open ("w" ) as f :
122- json .dump (class_info , f )
123142 except Exception as e :
124143 msg = f"Error dumping HF tokenizer { key } : { e } "
125144 logger .exception (msg )
@@ -202,29 +221,25 @@ def load( # noqa: C901, PLR0912, PLR0915
202221 msg = f"Error loading Pydantic model from { model_dir } : { e } "
203222 logger .exception (msg )
204223 continue
224+ elif child .name == Dumper .peft_models :
225+ for model_dir in child .iterdir ():
226+ try :
227+ model = AutoModelForSequenceClassification .from_pretrained (model_dir / "base_model" )
228+ hf_models [model_dir .name ] = PeftModel .from_pretrained (model , model_dir / "peft" )
229+ except Exception as e : # noqa: PERF203
230+ msg = f"Error loading PeftModel { model_dir .name } : { e } "
231+ logger .exception (msg )
205232 elif child .name == Dumper .hf_models :
206233 for model_dir in child .iterdir ():
207234 try :
208- with (model_dir / "class_info.json" ).open ("r" ) as f :
209- class_info = json .load (f )
210-
211- module = __import__ (class_info ["module" ], fromlist = [class_info ["name" ]])
212- model_class = getattr (module , class_info ["name" ])
213-
214- hf_models [model_dir .name ] = model_class .from_pretrained (model_dir )
235+ hf_models [model_dir .name ] = AutoModelForSequenceClassification .from_pretrained (model_dir )
215236 except Exception as e : # noqa: PERF203
216237 msg = f"Error loading HF model { model_dir .name } : { e } "
217238 logger .exception (msg )
218239 elif child .name == Dumper .hf_tokenizers :
219240 for tokenizer_dir in child .iterdir ():
220241 try :
221- with (tokenizer_dir / "class_info.json" ).open ("r" ) as f :
222- class_info = json .load (f )
223-
224- module = __import__ (class_info ["module" ], fromlist = [class_info ["name" ]])
225- tokenizer_class = getattr (module , class_info ["name" ])
226-
227- hf_tokenizers [tokenizer_dir .name ] = tokenizer_class .from_pretrained (tokenizer_dir )
242+ hf_tokenizers [tokenizer_dir .name ] = AutoTokenizer .from_pretrained (tokenizer_dir )
228243 except Exception as e : # noqa: PERF203
229244 msg = f"Error loading HF tokenizer { tokenizer_dir .name } : { e } "
230245 logger .exception (msg )
0 commit comments