@@ -33,6 +33,8 @@ class Dumper:
3333 estimators = "estimators"
3434 cross_encoders = "cross_encoders"
3535 pydantic_models : str = "pydantic"
36+ hf_models = "hf_models"
37+ hf_tokenizers = "hf_tokenizers"
3638
3739 @staticmethod
3840 def make_subdirectories (path : Path ) -> None :
@@ -48,12 +50,14 @@ def make_subdirectories(path: Path) -> None:
4850 path / Dumper .estimators ,
4951 path / Dumper .cross_encoders ,
5052 path / Dumper .pydantic_models ,
53+ path / Dumper .hf_models ,
54+ path / Dumper .hf_tokenizers ,
5155 ]
5256 for subdir in subdirectories :
5357 subdir .mkdir (parents = True , exist_ok = True )
5458
5559 @staticmethod
56- def dump (obj : Any , path : Path ) -> None : # noqa: ANN401, C901
60+ def dump (obj : Any , path : Path ) -> None : # noqa: ANN401, C901, PLR0912, PLR0915
5761 """Dump module attributes to the filesystem.
5862
5963 Args:
@@ -89,6 +93,28 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901
8993 except Exception as e :
9094 msg = f"Error dumping pydantic model { key } : { e } "
9195 logging .exception (msg )
96+ elif (key == "_model" or "model" in key .lower ()) and hasattr (val , "save_pretrained" ):
97+ model_path = path / Dumper .hf_models / key
98+ model_path .mkdir (parents = True , exist_ok = True )
99+ try :
100+ val .save_pretrained (model_path )
101+ class_info = {"module" : val .__class__ .__module__ , "name" : val .__class__ .__name__ }
102+ with (model_path / "class_info.json" ).open ("w" ) as f :
103+ json .dump (class_info , f )
104+ except Exception as e :
105+ msg = f"Error dumping HF model { key } : { e } "
106+ logger .exception (msg )
107+ elif (key == "_tokenizer" or "tokenizer" in key .lower ()) and hasattr (val , "save_pretrained" ):
108+ tokenizer_path = path / Dumper .hf_tokenizers / key
109+ tokenizer_path .mkdir (parents = True , exist_ok = True )
110+ try :
111+ val .save_pretrained (tokenizer_path )
112+ class_info = {"module" : val .__class__ .__module__ , "name" : val .__class__ .__name__ }
113+ with (tokenizer_path / "class_info.json" ).open ("w" ) as f :
114+ json .dump (class_info , f )
115+ except Exception as e :
116+ msg = f"Error dumping HF tokenizer { key } : { e } "
117+ logger .exception (msg )
92118 else :
93119 msg = f"Attribute { key } of type { type (val )} cannot be dumped to file system."
94120 logger .error (msg )
@@ -114,6 +140,8 @@ def load( # noqa: PLR0912, C901, PLR0915
114140 estimators : dict [str , Any ] = {}
115141 cross_encoders : dict [str , Any ] = {}
116142 pydantic_models : dict [str , Any ] = {}
143+ hf_models : dict [str , Any ] = {}
144+ hf_tokenizers : dict [str , Any ] = {}
117145
118146 for child in path .iterdir ():
119147 if child .name == Dumper .tags :
@@ -151,7 +179,6 @@ def load( # noqa: PLR0912, C901, PLR0915
151179 sig = inspect .signature (obj .__init__ )
152180 if variable_name in sig .parameters :
153181 model_type = sig .parameters [variable_name ].annotation
154-
155182 if model_type is None :
156183 msg = f"No type annotation found for { variable_name } "
157184 logger .error (msg )
@@ -174,9 +201,45 @@ def load( # noqa: PLR0912, C901, PLR0915
174201 continue
175202
176203 pydantic_models [variable_name ] = model_type (** content )
204+ elif child .name == Dumper .hf_models :
205+ for model_dir in child .iterdir ():
206+ try :
207+ with (model_dir / "class_info.json" ).open ("r" ) as f :
208+ class_info = json .load (f )
209+
210+ module = __import__ (class_info ["module" ], fromlist = [class_info ["name" ]])
211+ model_class = getattr (module , class_info ["name" ])
212+
213+ hf_models [model_dir .name ] = model_class .from_pretrained (model_dir )
214+ except Exception as e : # noqa: PERF203
215+ msg = f"Error loading HF model { model_dir .name } : { e } "
216+ logger .exception (msg )
217+ elif child .name == Dumper .hf_tokenizers :
218+ for tokenizer_dir in child .iterdir ():
219+ try :
220+ with (tokenizer_dir / "class_info.json" ).open ("r" ) as f :
221+ class_info = json .load (f )
222+
223+ module = __import__ (class_info ["module" ], fromlist = [class_info ["name" ]])
224+ tokenizer_class = getattr (module , class_info ["name" ])
225+
226+ hf_tokenizers [tokenizer_dir .name ] = tokenizer_class .from_pretrained (tokenizer_dir )
227+ except Exception as e : # noqa: PERF203
228+ msg = f"Error loading HF tokenizer { tokenizer_dir .name } : { e } "
229+ logger .exception (msg )
177230 else :
178231 msg = f"Found unexpected child { child } "
179232 logger .error (msg )
233+
180234 obj .__dict__ .update (
181- tags | simple_attrs | arrays | embedders | indexes | estimators | cross_encoders | pydantic_models
235+ tags
236+ | simple_attrs
237+ | arrays
238+ | embedders
239+ | indexes
240+ | estimators
241+ | cross_encoders
242+ | pydantic_models
243+ | hf_models
244+ | hf_tokenizers
182245 )
0 commit comments