
Commit 30cc0ce

batches
1 parent 92b7f61 commit 30cc0ce

File tree

1 file changed: +14 -9 lines changed


autointent/modules/scoring/_bert.py

Lines changed: 14 additions & 9 deletions
@@ -3,6 +3,7 @@
 import tempfile
 from typing import Any
 
+import numpy as np
 import numpy.typing as npt
 import torch
 from datasets import Dataset
@@ -126,15 +127,19 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
             msg = "Model is not trained. Call fit() first."
             raise RuntimeError(msg)
 
-        inputs = self._tokenizer(utterances, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
-
-        with torch.no_grad():
-            outputs = self._model(**inputs)
-        logits = outputs.logits
-
-        if self._multilabel:
-            return torch.sigmoid(logits).numpy()
-        return torch.softmax(logits, dim=1).numpy()
+        all_predictions = []
+        for i in range(0, len(utterances), self.batch_size):
+            batch = utterances[i:i + self.batch_size]
+            inputs = self._tokenizer(batch, return_tensors="pt", **self.model_config.tokenizer_config.model_dump())
+            with torch.no_grad():
+                outputs = self._model(**inputs)
+            logits = outputs.logits
+            if self._multilabel:
+                batch_predictions = torch.sigmoid(logits).numpy()
+            else:
+                batch_predictions = torch.softmax(logits, dim=1).numpy()
+            all_predictions.append(batch_predictions)
+        return np.vstack(all_predictions) if all_predictions else np.array([])
 
     def clear_cache(self) -> None:
         if hasattr(self, "_model"):
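For context, a minimal standalone sketch (not part of the commit) of the batching pattern this diff introduces: score utterances in fixed-size chunks so peak memory stays bounded, then stack the per-batch probabilities. The checkpoint name `prajjwal1/bert-tiny`, `num_labels`, and `batch_size` below are illustrative placeholders, not values taken from autointent.

```python
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Placeholder model/tokenizer; autointent configures these via its own model_config.
tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-tiny")
model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-tiny", num_labels=3)
model.eval()

utterances = ["book a flight", "play some jazz", "what's the weather"]
batch_size = 2  # assumed value for illustration

all_predictions = []
for i in range(0, len(utterances), batch_size):
    batch = utterances[i:i + batch_size]
    inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        logits = model(**inputs).logits
    # Multiclass head: softmax over classes; a multilabel head would use sigmoid instead.
    all_predictions.append(torch.softmax(logits, dim=1).numpy())

probs = np.vstack(all_predictions) if all_predictions else np.array([])
print(probs.shape)  # (3, num_labels)
```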
