@@ -235,15 +235,18 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
235235
236236 return cls (EmbedderConfig (** kwargs ))
237237
238- def embed (self , utterances : list [str ], task_type : TaskTypeEnum | None = None ) -> npt .NDArray [np .float32 ]:
238+ def embed (
239+ self , utterances : list [str ], task_type : TaskTypeEnum | None = None , return_tensors : bool = False
240+ ) -> npt .NDArray [np .float32 ] | torch .Tensor :
239241 """Calculate embeddings for a list of utterances.
240242
241243 Args:
242244 utterances: List of input texts to calculate embeddings for.
243245 task_type: Type of task for which embeddings are calculated.
246+ return_tensors: If True, return a PyTorch tensor; otherwise, return a numpy array.
244247
245248 Returns:
246- A numpy array of embeddings.
249+ A numpy array or PyTorch tensor of embeddings.
247250 """
248251 if len (utterances ) == 0 :
249252 msg = "Empty input"
@@ -263,7 +266,10 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
263266 embeddings_path = _get_embeddings_path (hasher .hexdigest ())
264267 if embeddings_path .exists ():
265268 logger .debug ("loading embeddings from %s" , str (embeddings_path ))
266- return np .load (embeddings_path ) # type: ignore[no-any-return]
269+ embeddings_np = np .load (embeddings_path )
270+ if return_tensors :
271+ return torch .from_numpy (embeddings_np ).to (self .config .device )
272+ return embeddings_np # type: ignore[no-any-return]
267273
268274 self ._model = self ._load_model ()
269275
@@ -281,15 +287,19 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
281287
282288 embeddings = self ._model .encode (
283289 utterances ,
284- convert_to_numpy = True ,
290+ convert_to_numpy = not return_tensors ,
291+ convert_to_tensor = return_tensors ,
285292 batch_size = self .config .batch_size ,
286293 normalize_embeddings = True ,
287294 prompt = prompt ,
288295 )
289296
290297 if self .config .use_cache :
298+ embeddings_to_save = embeddings
299+ if return_tensors :
300+ embeddings_to_save = embeddings .cpu ().numpy ()
291301 embeddings_path .parent .mkdir (parents = True , exist_ok = True )
292- np .save (embeddings_path , embeddings )
302+ np .save (embeddings_path , embeddings_to_save )
293303
294304 return embeddings
295305
0 commit comments