Skip to content

Commit eced2fe

Browse files
committed
return_tensors
1 parent 190a824 commit eced2fe

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

src/autointent/_wrappers/embedder.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,15 +235,18 @@ def load(cls, path: Path | str, override_config: EmbedderConfig | None = None) -
235235

236236
return cls(EmbedderConfig(**kwargs))
237237

238-
def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) -> npt.NDArray[np.float32]:
238+
def embed(
239+
self, utterances: list[str], task_type: TaskTypeEnum | None = None, return_tensors: bool = False
240+
) -> npt.NDArray[np.float32] | torch.Tensor:
239241
"""Calculate embeddings for a list of utterances.
240242
241243
Args:
242244
utterances: List of input texts to calculate embeddings for.
243245
task_type: Type of task for which embeddings are calculated.
246+
return_tensors: If True, return a PyTorch tensor; otherwise, return a numpy array.
244247
245248
Returns:
246-
A numpy array of embeddings.
249+
A numpy array or PyTorch tensor of embeddings.
247250
"""
248251
if len(utterances) == 0:
249252
msg = "Empty input"
@@ -263,7 +266,10 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
263266
embeddings_path = _get_embeddings_path(hasher.hexdigest())
264267
if embeddings_path.exists():
265268
logger.debug("loading embeddings from %s", str(embeddings_path))
266-
return np.load(embeddings_path) # type: ignore[no-any-return]
269+
embeddings_np = np.load(embeddings_path)
270+
if return_tensors:
271+
return torch.from_numpy(embeddings_np).to(self.config.device)
272+
return embeddings_np # type: ignore[no-any-return]
267273

268274
self._model = self._load_model()
269275

@@ -281,15 +287,19 @@ def embed(self, utterances: list[str], task_type: TaskTypeEnum | None = None) ->
281287

282288
embeddings = self._model.encode(
283289
utterances,
284-
convert_to_numpy=True,
290+
convert_to_numpy=not return_tensors,
291+
convert_to_tensor=return_tensors,
285292
batch_size=self.config.batch_size,
286293
normalize_embeddings=True,
287294
prompt=prompt,
288295
)
289296

290297
if self.config.use_cache:
298+
embeddings_to_save = embeddings
299+
if return_tensors:
300+
embeddings_to_save = embeddings.cpu().numpy()
291301
embeddings_path.parent.mkdir(parents=True, exist_ok=True)
292-
np.save(embeddings_path, embeddings)
302+
np.save(embeddings_path, embeddings_to_save)
293303

294304
return embeddings
295305

src/autointent/modules/scoring/_gcn/gcn_scorer.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,13 @@ def fit(self, utterances: list[str], labels: ListOfLabels, descriptions: list[st
132132
self._embedder = Embedder(self.embedder_config)
133133
self._label_embedder = Embedder(self.label_embedder_config)
134134

135-
x_tensor = torch.tensor(self._embedder.embed(utterances, TaskTypeEnum.classification))
135+
x_tensor = self._embedder.embed(utterances, TaskTypeEnum.classification, return_tensors=True)
136136
y_tensor_dtype = torch.float if self._multilabel else torch.long
137137
y_tensor = torch.tensor(labels, dtype=y_tensor_dtype)
138138

139-
label_embeddings = torch.tensor(self._label_embedder.embed(descriptions, TaskTypeEnum.classification)).to(
140-
self.torch_config.device
141-
)
139+
label_embeddings = self._label_embedder.embed(
140+
descriptions, TaskTypeEnum.classification, return_tensors=True
141+
).to(self.torch_config.device)
142142

143143
self._model = TextMLGCN(
144144
num_classes=self._n_classes,
@@ -169,7 +169,7 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
169169
if not hasattr(self, "_model"):
170170
msg = "Model is not trained. Call fit() first."
171171
raise RuntimeError(msg)
172-
x_tensor = torch.tensor(self._embedder.embed(utterances, TaskTypeEnum.classification))
172+
x_tensor = self._embedder.embed(utterances, TaskTypeEnum.classification, return_tensors=True)
173173
return self._predict_tensors(x_tensor)
174174

175175
def clear_cache(self) -> None:

0 commit comments

Comments (0)