@@ -223,11 +223,7 @@ async def _create_statements(
223223
224224 text , question = row ["response" ], row ["user_input" ]
225225 sentences = self .sentence_segmenter .segment (text )
226- sentences_with_index = {
227- i : sentence
228- for i , sentence in enumerate (sentences )
229- if sentence .strip ().endswith (("." , "。" , "!" , "!" ))
230- }
226+ sentences_with_index = {i : sentence for i , sentence in enumerate (sentences )}
231227
232228 statements_simplified = await self .statement_prompt .generate (
233229 llm = self .llm ,
@@ -320,7 +316,7 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
320316 assert self .llm is not None , "LLM is not set"
321317
322318 statements_simplified = await self ._create_statements (row , callbacks )
323- if statements_simplified is None :
319+ if len ( statements_simplified . sentences ) == 0 :
324320 return np .nan
325321
326322 statements = []
@@ -334,7 +330,9 @@ async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
334330 batch_scores = (
335331 self .nli_classifier .predict (input_pairs ).cpu ().detach ().round ()
336332 )
337- scores += batch_scores
333+ # convert tensor to list of floats
334+ scores .extend (batch_scores .tolist ())
335+
338336 return sum (scores ) / len (scores )
339337
340338
0 commit comments