Skip to content

Commit 5cbf83e

Browse files
committed
bug fix
1 parent 903dfa7 commit 5cbf83e

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

autointent/_ranker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Can be used to rank retrieved sentences by meaning closeness to provided utterance.
44
"""
55

6+
import gc
67
import itertools as it
78
import json
89
import logging
@@ -274,6 +275,7 @@ def load(cls, path: Path) -> "Ranker":
274275
return cls(**metadata, classifier_head=clf)
275276

276277
def clear_ram(self) -> None:
277-
self.cross_encoder.cpu()
278+
self.cross_encoder.model.cpu()
278279
del self.cross_encoder
280+
gc.collect()
279281
torch.cuda.empty_cache()

autointent/modules/scoring/_knn/knn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,15 @@ def get_embedder_name(self) -> str:
124124
"""
125125
return self.embedder_name
126126

127-
def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
127+
def fit(self, utterances: list[str], labels: ListOfLabels, clear_cache: bool = False) -> None:
128128
"""
129129
Fit the scorer by training or loading the vector index.
130130
131131
:param utterances: List of training utterances.
132132
:param labels: List of labels corresponding to the utterances.
133133
:raises ValueError: If the vector index mismatches the provided utterances.
134134
"""
135-
if hasattr(self, "_vector_index"):
135+
if hasattr(self, "_vector_index") and clear_cache:
136136
self.clear_cache()
137137

138138
self._validate_task(labels)

autointent/modules/scoring/_knn/rerank_scorer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,10 +139,11 @@ def fit(self, utterances: list[str], labels: ListOfLabels) -> None:
139139
)
140140
self._scorer.fit(utterances, labels)
141141

142-
super().fit(utterances, labels)
142+
super().fit(utterances, labels, clear_cache=False)
143143

144144
def clear_cache(self) -> None:
145145
self._scorer.clear_ram()
146+
super().clear_cache()
146147

147148
def _predict(self, utterances: list[str]) -> tuple[npt.NDArray[Any], list[list[str]]]:
148149
"""

0 commit comments

Comments
 (0)