@@ -77,6 +77,7 @@ class Embedder:
     _weights_dir_name: str = "sentence_transformer"
     _dump_dir: Path | None = None
     _trained: bool = False
+    _model: SentenceTransformer
 
     def __init__(self, embedder_config: EmbedderConfig) -> None:
         """Initialize the Embedder.
@@ -97,15 +98,15 @@ def _get_hash(self) -> int:
             commit_hash = _get_latest_commit_hash(self.config.model_name)
             hasher.update(commit_hash)
         else:
-            self.embedding_model = self._load_model()
-            for parameter in self.embedding_model.parameters():
+            self._model = self._load_model()
+            for parameter in self._model.parameters():
                 hasher.update(parameter.detach().cpu().numpy())
         hasher.update(self.config.tokenizer_config.max_length)
         return hasher.intdigest()
 
     def _load_model(self) -> SentenceTransformer:
         """Load sentence transformers model to device."""
-        if not hasattr(self, "embedding_model"):
+        if not hasattr(self, "_model"):
             res = SentenceTransformer(
                 self.config.model_name,
                 device=self.config.device,
@@ -114,7 +115,7 @@ def _load_model(self) -> SentenceTransformer:
                 trust_remote_code=self.config.trust_remote_code,
             )
         else:
-            res = self.embedding_model
+            res = self._model
         return res
 
     def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:
@@ -133,7 +134,8 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:
             logger.warning(msg)
             return
 
-        self._load_model()
+        self._model = self._load_model()
+
         if config.early_stopping:
             x_train, x_val, y_train, y_val = train_test_split(utterances, labels, test_size=0.1, random_state=42)
             tr_ds = Dataset.from_dict({"text": x_train, "label": y_train})
@@ -142,7 +144,7 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:
             tr_ds = Dataset.from_dict({"text": utterances, "label": labels})
             val_ds = None
 
-        loss = BatchAllTripletLoss(model=self.embedding_model, margin=config.margin)
+        loss = BatchAllTripletLoss(model=self._model, margin=config.margin)
         with tempfile.TemporaryDirectory() as tmp_dir:
             args = SentenceTransformerTrainingArguments(
                 save_strategy="epoch",
@@ -169,7 +171,7 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:
                 )
             )
             trainer = SentenceTransformerTrainer(
-                model=self.embedding_model,
+                model=self._model,
                 args=args,
                 train_dataset=tr_ds,
                 eval_dataset=val_ds,
@@ -181,7 +183,7 @@ def train(self, utterances: list[str], labels: ListOfLabels, config: EmbedderFineTuningConfig) -> None:
 
         # use temporary path for re-usage
         model_path = str(Path(tempfile.mkdtemp("autointent_embedders")) / str(uuid4()))
-        self.embedding_model.save(model_path)
+        self._model.save(model_path)
         self.config.model_name = model_path
 
         self._trained = True
@@ -190,8 +192,8 @@ def clear_ram(self) -> None:
190192 """Move the embedding model to CPU and delete it from memory."""
191193 if hasattr (self , "embedding_model" ):
192194 logger .debug ("Clearing embedder %s from memory" , self .config .model_name )
193- self .embedding_model .cpu ()
194- del self .embedding_model
195+ self ._model .cpu ()
196+ del self ._model
195197 torch .cuda .empty_cache ()
196198
197199 def delete (self ) -> None :
@@ -208,7 +210,7 @@ def dump(self, path: Path) -> None:
208210 """
209211 if self ._trained :
210212 model_path = str ((path / self ._weights_dir_name ).resolve ())
211- self .embedding_model .save (model_path , create_model_card = False )
213+ self ._model .save (model_path , create_model_card = False )
212214 self .config .model_name = model_path
213215
214216 self ._dump_dir = path
@@ -248,6 +250,11 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
         Returns:
             A numpy array of embeddings.
         """
+        if len(utterances) == 0:
+            msg = "Empty input"
+            logger.error(msg)
+            raise ValueError(msg)
+
         prompt = self.config.get_prompt(task_type)
 
         if self.config.use_cache:
@@ -263,7 +270,7 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
             logger.debug("loading embeddings from %s", str(embeddings_path))
             return np.load(embeddings_path)  # type: ignore[no-any-return]
 
-        self.embedding_model = self._load_model()
+        self._model = self._load_model()
 
         logger.debug(
             "Calculating embeddings with model %s, batch_size=%d, max_seq_length=%s, embedder_device=%s, prompt=%s",
@@ -275,9 +282,9 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
         )
 
         if self.config.tokenizer_config.max_length is not None:
-            self.embedding_model.max_seq_length = self.config.tokenizer_config.max_length
+            self._model.max_seq_length = self.config.tokenizer_config.max_length
 
-        embeddings = self.embedding_model.encode(
+        embeddings = self._model.encode(
             utterances,
             convert_to_numpy=True,
             batch_size=self.config.batch_size,
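
The rename from embedding_model to _model interacts with the lazy hasattr-based loading: the string passed to hasattr must match the attribute name exactly, which is also why the clear_ram guard above has to check "_model". Below is a minimal, self-contained sketch of this lazy-load pattern; it is not part of the PR, and the class name and model checkpoint are illustrative assumptions:

import numpy as np
from sentence_transformers import SentenceTransformer


class LazyEmbedder:
    # Declared on the class so type checkers know the attribute exists;
    # the instance attribute is only set on first use, as with _model above.
    _model: SentenceTransformer

    def _load_model(self) -> SentenceTransformer:
        # Load once, then reuse the cached instance on subsequent calls.
        if not hasattr(self, "_model"):
            self._model = SentenceTransformer("all-MiniLM-L6-v2")  # assumed checkpoint
        return self._model

    def clear_ram(self) -> None:
        # The hasattr string must match the attribute name exactly;
        # checking the stale name "embedding_model" would make this a no-op.
        if hasattr(self, "_model"):
            self._model.cpu()
            del self._model

    def embed(self, utterances: list[str]) -> np.ndarray:
        # Same guard the PR adds to embed(): fail fast on empty input.
        if len(utterances) == 0:
            raise ValueError("Empty input")
        return self._load_model().encode(utterances, convert_to_numpy=True)

With this pattern, repeated embed() calls reuse one loaded model, and clear_ram() only attempts to free a model that was actually created.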