class HuggingFaceSentenceEmbedder(TransformerSentenceEmbedder):
    """Sentence embedder backed by a Hugging Face model.

    Keeps the raw ``config_string`` passed at construction so that
    ``to_json()``/``load()`` round-trip cleanly even when the string is
    later resolved to a different local model path.
    """

    def __init__(self, config_string: str, batch_size: int = 128):
        super().__init__(config_string, batch_size)
        # Remember the original identifier for serialization round-trips.
        self.config_string = config_string

    @staticmethod
    def load(embedder: dict) -> "HuggingFaceSentenceEmbedder":
        """Rebuild an embedder from a ``to_json()`` dictionary.

        A config string that already names an existing local path is used
        as-is; anything else is resolved via ``request_util.get_model_path``.
        """
        raw = embedder["config_string"]
        resolved = raw if os.path.exists(raw) else request_util.get_model_path(raw)
        return HuggingFaceSentenceEmbedder(
            config_string=resolved, batch_size=embedder["batch_size"]
        )

    def to_json(self) -> dict:
        """Serialize the constructor arguments for a later ``load()``."""
        return {
            "cls": "HuggingFaceSentenceEmbedder",
            "config_string": self.config_string,
            "batch_size": self.batch_size,
        }
@@ -239,7 +243,9 @@ def _encode(
239243 self , documents : List [Union [str , Doc ]], fit_model : bool
240244 ) -> Generator [List [List [float ]], None , None ]:
241245 for documents_batch in util .batch (documents , self .batch_size ):
242- documents_batch = [self ._trim_length (doc .replace ("\n " , " " )) for doc in documents_batch ]
246+ documents_batch = [
247+ self ._trim_length (doc .replace ("\n " , " " )) for doc in documents_batch
248+ ]
243249 try :
244250 response = self .openai_client .embeddings .create (
245251 input = documents_batch , model = self .model_name
@@ -270,11 +276,13 @@ def dump(self, project_id: str, embedding_id: str) -> None:
270276 export_file .parent .mkdir (parents = True , exist_ok = True )
271277 util .write_json (self .to_json (), export_file , indent = 2 )
272278
273- def _trim_length (self , text : str , max_length : int = 512 ) -> str :
279+ def _trim_length (self , text : str , max_length : int = 512 ) -> str :
274280 tokens = self ._auto_tokenizer (
275281 text ,
276282 truncation = True ,
277283 max_length = max_length ,
278- return_tensors = None # No tensors needed for just truncating
284+ return_tensors = None , # No tensors needed for just truncating
285+ )
286+ return self ._auto_tokenizer .decode (
287+ tokens ["input_ids" ], skip_special_tokens = True
279288 )
280- return self ._auto_tokenizer .decode (tokens ["input_ids" ], skip_special_tokens = True )
0 commit comments