File tree Expand file tree Collapse file tree 1 file changed +14
-9
lines changed
autointent/modules/scoring Expand file tree Collapse file tree 1 file changed +14
-9
lines changed Original file line number Diff line number Diff line change 33import tempfile
44from typing import Any
55
6+ import numpy as np
67import numpy .typing as npt
78import torch
89from datasets import Dataset
@@ -126,15 +127,19 @@ def predict(self, utterances: list[str]) -> npt.NDArray[Any]:
126127 msg = "Model is not trained. Call fit() first."
127128 raise RuntimeError (msg )
128129
129- inputs = self ._tokenizer (utterances , return_tensors = "pt" , ** self .model_config .tokenizer_config .model_dump ())
130-
131- with torch .no_grad ():
132- outputs = self ._model (** inputs )
133- logits = outputs .logits
134-
135- if self ._multilabel :
136- return torch .sigmoid (logits ).numpy ()
137- return torch .softmax (logits , dim = 1 ).numpy ()
130+ all_predictions = []
131+ for i in range (0 , len (utterances ), self .batch_size ):
132+ batch = utterances [i :i + self .batch_size ]
133+ inputs = self ._tokenizer (batch , return_tensors = "pt" , ** self .model_config .tokenizer_config .model_dump ())
134+ with torch .no_grad ():
135+ outputs = self ._model (** inputs )
136+ logits = outputs .logits
137+ if self ._multilabel :
138+ batch_predictions = torch .sigmoid (logits ).numpy ()
139+ else :
140+ batch_predictions = torch .softmax (logits , dim = 1 ).numpy ()
141+ all_predictions .append (batch_predictions )
142+ return np .vstack (all_predictions ) if all_predictions else np .array ([])
138143
139144 def clear_cache (self ) -> None :
140145 if hasattr (self , "_model" ):
You can’t perform that action at this time.
0 commit comments