Skip to content

Commit e5e4543

Browse files
authored
fix: sentence segmenter for multi languages (#946)
#947 There was a bug when using the pysbd sentence segmenter with languages other than English; this PR fixes it.
1 parent 2d79365 commit e5e4543

File tree

3 files changed

+31
-9
lines changed

3 files changed

+31
-9
lines changed

src/ragas/metrics/_answer_correctness.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import numpy as np
88
from langchain_core.pydantic_v1 import BaseModel
9-
from pysbd import Segmenter
109

1110
from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
1211
from ragas.llms.prompt import Prompt, PromptValue
@@ -16,7 +15,12 @@
1615
HasSegmentMethod,
1716
_statements_output_parser,
1817
)
19-
from ragas.metrics.base import EvaluationMode, MetricWithEmbeddings, MetricWithLLM
18+
from ragas.metrics.base import (
19+
EvaluationMode,
20+
MetricWithEmbeddings,
21+
MetricWithLLM,
22+
get_segmenter,
23+
)
2024
from ragas.run_config import RunConfig
2125

2226
if t.TYPE_CHECKING:
@@ -176,7 +180,7 @@ def __post_init__(self: t.Self):
176180

177181
if self.sentence_segmenter is None:
178182
language = self.long_form_answer_prompt.language
179-
self.sentence_segmenter = Segmenter(language=language, clean=False)
183+
self.sentence_segmenter = get_segmenter(language=language, clean=False)
180184

181185
def init(self, run_config: RunConfig):
182186
super().init(run_config)

src/ragas/metrics/_faithfulness.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,10 @@
77

88
import numpy as np
99
from langchain_core.pydantic_v1 import BaseModel, Field
10-
from pysbd import Segmenter
1110

1211
from ragas.llms.output_parser import RagasoutputParser, get_json_format_instructions
1312
from ragas.llms.prompt import Prompt
14-
from ragas.metrics.base import EvaluationMode, MetricWithLLM, ensembler
13+
from ragas.metrics.base import EvaluationMode, MetricWithLLM, ensembler, get_segmenter
1514

1615
if t.TYPE_CHECKING:
1716
from langchain_core.callbacks import Callbacks
@@ -81,7 +80,7 @@ def dicts(self) -> t.List[t.Dict]:
8180
],
8281
input_keys=["question", "answer", "sentences"],
8382
output_key="analysis",
84-
language="en",
83+
language="english",
8584
)
8685

8786

@@ -160,7 +159,7 @@ def dicts(self) -> t.List[t.Dict]:
160159
input_keys=["context", "statements"],
161160
output_key="answer",
162161
output_type="json",
163-
language="en",
162+
language="english",
164163
) # noqa: E501
165164

166165

@@ -190,7 +189,7 @@ def reproducibility(self, value):
190189
def __post_init__(self):
191190
if self.sentence_segmenter is None:
192191
language = self.nli_statements_message.language
193-
self.sentence_segmenter = Segmenter(language=language, clean=False)
192+
self.sentence_segmenter = get_segmenter(language=language, clean=False)
194193

195194
def _create_nli_prompt(self, row: t.Dict, statements: t.List[str]) -> PromptValue:
196195
assert self.llm is not None, "llm must be set to compute score"

src/ragas/metrics/base.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
from ragas.embeddings import BaseRagasEmbeddings
2424
from ragas.llms import BaseRagasLLM
2525

26+
from pysbd import Segmenter
27+
from pysbd.languages import LANGUAGE_CODES
28+
29+
LANGUAGE_CODES = {v.__name__.lower(): k for k, v in LANGUAGE_CODES.items()}
2630

2731
EvaluationMode = Enum("EvaluationMode", "qac qa qc gc ga qga qcg")
2832

@@ -191,5 +195,20 @@ def from_discrete(self, inputs: list[list[t.Dict]], attribute: str):
191195
return verdict_agg
192196

193197

194-
ensembler = Ensember()
198+
def get_segmenter(
199+
language: str = "english", clean: bool = False, char_span: bool = False
200+
):
201+
"""
202+
Get a sentence segmenter for a given language
203+
"""
204+
language = language.lower()
205+
if language not in LANGUAGE_CODES:
206+
raise ValueError(
207+
f"Language '{language}' not supported. Supported languages: {LANGUAGE_CODES.keys()}"
208+
)
209+
return Segmenter(
210+
language=LANGUAGE_CODES[language], clean=clean, char_span=char_span
211+
)
212+
195213

214+
ensembler = Ensember()

0 commit comments

Comments
 (0)