1717
1818 from ragas .llms .prompt import PromptValue
1919
20- from typing import Any , Protocol
2120
22-
23- class HasSegmentMethod (Protocol ):
24- def segment (self , text ) -> Any :
21+ class HasSegmentMethod (t .Protocol ):
22+ def segment (self , text ) -> t .Any :
2523 ...
2624
2725
@@ -316,6 +314,8 @@ def save(self, cache_dir: t.Optional[str] = None) -> None:
316314@dataclass
317315class FaithulnesswithHHEM (Faithfulness ):
318316 name : str = "faithfulness_with_hhem" # type: ignore
317+ device : str = "cpu"
318+ batch_size : int = 10
319319
320320 def __post_init__ (self ):
321321 try :
@@ -327,6 +327,7 @@ def __post_init__(self):
327327 self .nli_classifier = AutoModelForSequenceClassification .from_pretrained (
328328 "vectara/hallucination_evaluation_model" , trust_remote_code = True
329329 )
330+ self .nli_classifier .to (self .device )
330331 super ().__post_init__ ()
331332
332333 def _create_pairs (
@@ -339,6 +340,13 @@ def _create_pairs(
339340 pairs = [(premise , statement ) for statement in statements ]
340341 return pairs
341342
343+ def _create_batch (
344+ self , pairs : t .List [t .Tuple [str , str ]]
345+ ) -> t .Generator [t .List [t .Tuple [str , str ]], None , None ]:
346+ length_of_pairs = len (pairs )
347+ for ndx in range (0 , length_of_pairs , self .batch_size ):
348+ yield pairs [ndx : min (ndx + self .batch_size , length_of_pairs )]
349+
342350 async def _ascore (self : t .Self , row : t .Dict , callbacks : Callbacks ) -> float :
343351 """
344352 returns the NLI score for each (q, c, a) pair
@@ -362,9 +370,14 @@ async def _ascore(self: t.Self, row: t.Dict, callbacks: Callbacks) -> float:
362370
363371 assert isinstance (statements , t .List ), "statements must be a list"
364372
373+ scores = []
365374 pairs = self ._create_pairs (row , statements )
366- scores = self .nli_classifier .predict (pairs ).detach ().numpy ().round ()
367- return scores .sum () / len (scores )
375+ for input_pairs in self ._create_batch (pairs ): # to avoid OOM
376+ batch_scores = (
377+ self .nli_classifier .predict (input_pairs ).cpu ().detach ().round ()
378+ )
379+ scores += batch_scores
380+ return sum (scores ) / len (scores )
368381
369382
370383faithfulness = Faithfulness ()
0 commit comments