
Commit b7c1bdc

emredsj and jjmachan authored
Bugfix Summarization Score (#1164)
Hi! I was experimenting with the Summarization Score and found an issue: when the LLM generates no questions, the metric throws a division-by-zero error because there are zero generated answers.

## Notable Changes

- Added a guard to the keyphrase, question, and answer generation functions.
- Renamed the LLM response variable from `answer` to `response` to prevent possible confusion with the score-related `answer`.

Please let me know if there is something I can add. Cheers!

---------

Co-authored-by: jjmachan <[email protected]>
1 parent fdba9a2 commit b7c1bdc
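
For context on the failure mode: the QA part of this metric is a fraction of correct answers, so any `sum(...) / len(answers)`-style computation raises `ZeroDivisionError` as soon as the answer list is empty. A minimal, self-contained sketch of the problem (the function below is a simplified stand-in for the metric's internal `_compute_qa_score`, not the actual ragas code):

```python
import logging

logger = logging.getLogger(__name__)


def compute_qa_score(answers: list[str]) -> float:
    """Fraction of 'yes' answers -- raises ZeroDivisionError when `answers` is empty."""
    correct = sum(1 for a in answers if a.lower() == "yes")
    return correct / len(answers)


# What happened before this fix: the LLM produced no questions, so no answers
# were generated, and the division blew up. The fix guards upstream instead,
# logging an error and returning an empty list from the generation helpers.
try:
    compute_qa_score([])
except ZeroDivisionError:
    logger.error("No answers generated, unable to calculate the score.")
```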

File tree

7 files changed (+68, -51 lines changed)

docs/howtos/customisations/run_config.ipynb

Lines changed: 5 additions & 4 deletions
@@ -53,15 +53,16 @@
     "\n",
     "# load the dataset\n",
     "from datasets import load_dataset\n",
+    "\n",
     "amnesty_qa = load_dataset(\"explodinggradients/amnesty_qa\", \"english_v2\")\n",
     "\n",
     "# configure RunConfig\n",
     "from ragas.run_config import RunConfig\n",
     "\n",
     "_ = evaluate(\n",
-    "    dataset=amnesty_qa[\"eval\"], \n",
+    "    dataset=amnesty_qa[\"eval\"],\n",
     "    metrics=[faithfulness],\n",
-    "    run_config=RunConfig(max_workers=64), # increasing max_workers from default 16\n",
+    "    run_config=RunConfig(max_workers=64),  # increasing max_workers from default 16\n",
     ")"
    ]
   },
@@ -94,9 +95,9 @@
    ],
    "source": [
     "_ = evaluate(\n",
-    "    dataset=amnesty_qa[\"eval\"], \n",
+    "    dataset=amnesty_qa[\"eval\"],\n",
     "    metrics=[faithfulness],\n",
-    "    run_config=RunConfig(max_workers=2), # increasing max_workers from default 16\n",
+    "    run_config=RunConfig(max_workers=2),  # increasing max_workers from default 16\n",
     ")"
    ]
   },
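
Put together, the two notebook cells above boil down to the flow below; this is a condensed sketch assembled from the diff, and it assumes an environment where ragas can actually run (LLM credentials configured, `ragas` and `datasets` installed):

```python
from datasets import load_dataset

from ragas import evaluate
from ragas.metrics import faithfulness
from ragas.run_config import RunConfig

# load the evaluation dataset used in the notebook
amnesty_qa = load_dataset("explodinggradients/amnesty_qa", "english_v2")

# raise max_workers above the default of 16 for more concurrency,
# or drop it (e.g. max_workers=2) to stay within provider rate limits
_ = evaluate(
    dataset=amnesty_qa["eval"],
    metrics=[faithfulness],
    run_config=RunConfig(max_workers=64),
)
```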

src/ragas/metrics/_summarization.py

Lines changed: 18 additions & 9 deletions
@@ -186,7 +186,7 @@ def _compute_score(self, scores) -> float:
         """Returns average score of the different scores."""
         return sum(scores) / len(scores)
 
-    def _compute_qa_score(self, answers: t.List) -> float:
+    def _compute_qa_score(self, answers: t.List[str]) -> float:
         """Returns a score between 0 and 1 reflecting the fraction of
         correct answers, ie with a value 'yes'
         """
@@ -209,10 +209,15 @@ async def _extract_keyphrases(self, text: str, callbacks: Callbacks) -> t.List[s
             callbacks=callbacks,
         )
         result_text = result.generations[0][0].text
-        answer = await _output_parser_keyphrase_extraction.aparse(
+        response = await _output_parser_keyphrase_extraction.aparse(
             result_text, p_value, self.llm, self.max_retries
         )
-        return answer.keyphrases if answer else []
+
+        if not response or not response.keyphrases:
+            logging.error("No keyphrases generated, unable to calculate the score.")
+            return []
+
+        return response.keyphrases
 
     async def _get_questions(
         self, text: str, keyphrases: list[str], callbacks: Callbacks
@@ -225,13 +230,15 @@ async def _get_questions(
         )
 
         result_text = result.generations[0][0].text
-        answer = await _output_parser_question_generation.aparse(
+        response = await _output_parser_question_generation.aparse(
             result_text, p_value, self.llm, self.max_retries
         )
-        if answer is None:
+
+        if not response or not response.questions:
+            logging.error("No questions generated, unable to calculate the score.")
             return []
 
-        return answer.questions
+        return response.questions
 
     async def _get_answers(
         self, questions: t.List[str], summary: str, callbacks: Callbacks
@@ -244,13 +251,15 @@ async def _get_answers(
         )
 
         result_text = result.generations[0][0].text
-        answer = await _output_parser_answer_generation.aparse(
+        response = await _output_parser_answer_generation.aparse(
             result_text, p_value, self.llm, self.max_retries
         )
-        if answer is None:
+
+        if not response or not response.answers:
+            logger.error("No answers generated, unable to calculate the score.")
             return []
 
-        return answer.answers
+        return response.answers
 
 
     def adapt(self, language: str, cache_dir: str | None = None) -> None:
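
All three generation helpers now follow the same defensive pattern: parse the LLM output, and if parsing fails or the parsed field is empty, log an error and return an empty list rather than dereferencing `None`. A reduced, self-contained sketch of that pattern (the response class below is a stand-in for the parsed output object, not the actual ragas type):

```python
import logging
import typing as t
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)


@dataclass
class KeyphraseResponse:
    """Stand-in for the parsed LLM output; the parser may return None on failure."""

    keyphrases: t.List[str] = field(default_factory=list)


def handle_response(response: t.Optional[KeyphraseResponse]) -> t.List[str]:
    # the guard added in this commit: a failed parse (None) or an empty field
    # short-circuits with a logged error instead of raising later on
    if not response or not response.keyphrases:
        logger.error("No keyphrases generated, unable to calculate the score.")
        return []
    return response.keyphrases


assert handle_response(None) == []
assert handle_response(KeyphraseResponse(["renewable energy"])) == ["renewable energy"]
```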

src/ragas/metrics/base.py

Lines changed: 6 additions & 3 deletions
@@ -63,11 +63,13 @@ def get_required_columns(
 class Metric(ABC):
     @property
     @abstractmethod
-    def name(self) -> str: ...
+    def name(self) -> str:
+        ...
 
     @property
     @abstractmethod
-    def evaluation_mode(self) -> EvaluationMode: ...
+    def evaluation_mode(self) -> EvaluationMode:
+        ...
 
     @abstractmethod
     def init(self, run_config: RunConfig):
@@ -130,7 +132,8 @@ async def ascore(
         return score
 
     @abstractmethod
-    async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float: ...
+    async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
+        ...
 
 
 @dataclass

src/ragas/testset/prompts.py

Lines changed: 2 additions & 2 deletions
@@ -509,5 +509,5 @@ class EvolutionElimination(BaseModel):
     question_rewrite_prompt,
     context_scoring_prompt,
     filter_question_prompt,
-    evolution_elimination_prompt
-]
+    evolution_elimination_prompt,
+]

tests/unit/test_analytics.py

Lines changed: 3 additions & 1 deletion
@@ -1,5 +1,7 @@
 from __future__ import annotations
+
 import typing as t
+
 import pytest
 
 
@@ -130,7 +132,7 @@ def test_testset_generation_tracking(monkeypatch):
 
 
 def test_was_completed(monkeypatch):
-    from ragas._analytics import track_was_completed, IsCompleteEvent
+    from ragas._analytics import IsCompleteEvent, track_was_completed
 
     event_properties_list: t.List[IsCompleteEvent] = []
 
tests/unit/test_executor_in_jupyter.ipynb

Lines changed: 4 additions & 2 deletions
@@ -33,7 +33,7 @@
     "\n",
     "exec = Executor(raise_exceptions=True)\n",
     "for i in range(10):\n",
-    "    exec.submit(sleep, i/10)\n",
+    "    exec.submit(sleep, i / 10)\n",
     "\n",
     "assert exec.results(), \"didn't get anything from results\""
    ]
@@ -140,16 +140,18 @@
    "source": [
     "from ragas.metrics.base import Metric, EvaluationMode\n",
     "\n",
+    "\n",
     "class FakeMetric(Metric):\n",
     "    name = \"fake_metric\"\n",
     "    evaluation_mode = EvaluationMode.qa\n",
     "\n",
     "    def init(self):\n",
     "        pass\n",
     "\n",
-    "    async def _ascore(self, row, callbacks)->float:\n",
+    "    async def _ascore(self, row, callbacks) -> float:\n",
     "        return 0\n",
     "\n",
+    "\n",
     "fm = FakeMetric()"
    ]
   },

tests/unit/test_run_config.py

Lines changed: 30 additions & 30 deletions
@@ -1,36 +1,42 @@
-import sys, importlib
-from packaging.version import parse as parse_version
+import importlib
+import sys
 from platform import python_version
+
 import pytest
-from numpy.random import default_rng, Generator
+from numpy.random import Generator, default_rng
+from packaging.version import parse as parse_version
 
 from ragas.run_config import RunConfig
 
 if parse_version(python_version()) < parse_version("3.10"):
-    from typing import NewType, Callable
-    RandomComparison = NewType("RandomComparison", Callable[[Generator, Generator], bool])
+    from typing import Callable, NewType
+
+    RandomComparison = NewType(
+        "RandomComparison", Callable[[Generator, Generator], bool]
+    )
 elif parse_version(python_version()) >= parse_version("3.10"):
-    from typing import TypeAlias, Callable
+    from typing import Callable, TypeAlias
+
     RandomComparison: TypeAlias = Callable[[Generator, Generator], bool]
 
+
 @pytest.fixture(scope="function")
 def compare_rng() -> Callable[[Generator, Generator], bool]:
-    """Pytest fixture wrapper to check :py:cls:`numpy.random.Generator` object equivalence.
+    """Pytest fixture wrapper to check :py:cls:`numpy.random.Generator` object equivalence."""
 
-    """
-    def _compare_rng(rng_0:Generator, rng_1:Generator) -> bool:
+    def _compare_rng(rng_0: Generator, rng_1: Generator) -> bool:
         """Compare two :py:cls:`numpy.random.Generator`object.
-
+
         Args:
             rng_0 (numpy.random.Generator) : The first generator to compare with.
            rng_1 (numpy.random.Generator) : The second generator to compare with.
 
         Returns:
             bool: Whether the two generators are at the same state.
-
+
         """
         return rng_0.random() == rng_1.random()
-
+
     return _compare_rng
 
 
@@ -39,9 +45,11 @@ def _compare_rng(rng_0:Generator, rng_1:Generator) -> bool:
     (
         [42, True],
         [None, False],
-    )
+    ),
 )
-def test_random_num_generator(seed, compare_rng:RandomComparison, expected_equivalence):
+def test_random_num_generator(
+    seed, compare_rng: RandomComparison, expected_equivalence
+):
     """Check :py:mod:`numpy.random` functionality and seed behaviour control."""
     rc = RunConfig(seed=seed)
 
@@ -53,7 +61,7 @@ def test_random_num_generator(seed, compare_rng:RandomComparison, expected_equiv
     assert compare_rng(rc.rng, rng) == expected_equivalence
 
     # Check generation consistency
-    importlib.reload(sys.modules['numpy.random'])
+    importlib.reload(sys.modules["numpy.random"])
     new_rc = RunConfig(seed=seed)
     new_rng = default_rng(seed=seed)
 
@@ -63,22 +71,14 @@ def test_random_num_generator(seed, compare_rng:RandomComparison, expected_equiv
 
     # Check equivalence
     if expected_equivalence:
-        assert all(
-            list(
-                map(
-                    compare_rng,
-                    [rc.rng, new_rc.rng],
-                    [new_rng, rng]
-                )
-            )
-        )
+        assert all(list(map(compare_rng, [rc.rng, new_rc.rng], [new_rng, rng])))
     else:
         assert all(
-            list(
-                map(
-                    lambda x, y:not compare_rng(x, y),
-                    [rc.rng, new_rc.rng],
-                    [new_rng, rng]
-                )
+            list(
+                map(
+                    lambda x, y: not compare_rng(x, y),
+                    [rc.rng, new_rc.rng],
+                    [new_rng, rng],
                 )
             )
+        )
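
What these tests pin down is the seeding behaviour of `RunConfig`: with the same seed, its `rng` generator produces the same stream as a fresh `numpy.random.default_rng(seed)`; with `seed=None`, the streams are independent. A small sketch of that check outside pytest, using only the attributes the test itself exercises:

```python
from numpy.random import default_rng

from ragas.run_config import RunConfig

# identical seeds -> identical random streams
rc = RunConfig(seed=42)
rng = default_rng(seed=42)
assert rc.rng.random() == rng.random()

# no seed -> independent streams, so equality is not expected
unseeded_rc = RunConfig(seed=None)
print(unseeded_rc.rng.random(), default_rng().random())
```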
