Skip to content

Commit fa864a6

Browse files
Miaoranmmm, Miaoran, and jjmachan
authored
load HHEM on specified device (#1235)
Allow users to specify the device on which HHEM is loaded, and add `_create_batch` so that pairs are scored in batches to avoid running out of memory (OOM).
---------
Co-authored-by: Miaoran <[email protected]>
Co-authored-by: jjmachan <[email protected]>
1 parent d58dc01 commit fa864a6

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

src/ragas/metrics/_faithfulness.py

Lines changed: 19 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,9 @@
1717

1818
from ragas.llms.prompt import PromptValue
1919

20-
from typing import Any, Protocol
2120

22-
23-
class HasSegmentMethod(Protocol):
24-
def segment(self, text) -> Any:
21+
class HasSegmentMethod(t.Protocol):
22+
def segment(self, text) -> t.Any:
2523
...
2624

2725

@@ -316,6 +314,8 @@ def save(self, cache_dir: t.Optional[str] = None) -> None:
316314
@dataclass
317315
class FaithulnesswithHHEM(Faithfulness):
318316
name: str = "faithfulness_with_hhem" # type: ignore
317+
device: str = "cpu"
318+
batch_size: int = 10
319319

320320
def __post_init__(self):
321321
try:
@@ -327,6 +327,7 @@ def __post_init__(self):
327327
self.nli_classifier = AutoModelForSequenceClassification.from_pretrained(
328328
"vectara/hallucination_evaluation_model", trust_remote_code=True
329329
)
330+
self.nli_classifier.to(self.device)
330331
super().__post_init__()
331332

332333
def _create_pairs(
@@ -339,6 +340,13 @@ def _create_pairs(
339340
pairs = [(premise, statement) for statement in statements]
340341
return pairs
341342

343+
def _create_batch(
344+
self, pairs: t.List[t.Tuple[str, str]]
345+
) -> t.Generator[t.List[t.Tuple[str, str]], None, None]:
346+
length_of_pairs = len(pairs)
347+
for ndx in range(0, length_of_pairs, self.batch_size):
348+
yield pairs[ndx : min(ndx + self.batch_size, length_of_pairs)]
349+
342350
async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
343351
"""
344352
returns the NLI score for each (q, c, a) pair
@@ -362,9 +370,14 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
362370

363371
assert isinstance(statements, t.List), "statements must be a list"
364372

373+
scores = []
365374
pairs = self._create_pairs(row, statements)
366-
scores = self.nli_classifier.predict(pairs).detach().numpy().round()
367-
return scores.sum() / len(scores)
375+
for input_pairs in self._create_batch(pairs): # to avoid OOM
376+
batch_scores = (
377+
self.nli_classifier.predict(input_pairs).cpu().detach().round()
378+
)
379+
scores += batch_scores
380+
return sum(scores) / len(scores)
368381

369382

370383
faithfulness = Faithfulness()

0 commit comments

Comments
 (0)