Skip to content

Commit f3a800d

Browse files
authored
Fix ensembler : handle edge cases (#972)
1 parent 6cf7773 commit f3a800d

File tree

4 files changed

+29
-6
lines changed

4 files changed

+29
-6
lines changed

src/ragas/metrics/_context_precision.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,11 @@ def reproducibility(self, value):
101101
if value < 1:
102102
logger.warning("reproducibility cannot be less than 1, setting to 1")
103103
value = 1
104+
elif value % 2 == 0:
105+
logger.warning(
106+
"reproducibility level cannot be set to even number, setting to odd"
107+
)
108+
value += 1
104109
self._reproducibility = value
105110

106111
def _get_row_attributes(self, row: t.Dict) -> t.Tuple[str, t.List[str], t.Any]:

src/ragas/metrics/_context_recall.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,11 @@ def reproducibility(self, value):
134134
if value < 1:
135135
logger.warning("reproducibility cannot be less than 1, setting to 1")
136136
value = 1
137+
elif value % 2 == 0:
138+
logger.warning(
139+
"reproducibility level cannot be set to even number, setting to odd"
140+
)
141+
value += 1
137142
self._reproducibility = value
138143

139144
def __post_init__(self) -> None:

src/ragas/metrics/_faithfulness.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,11 @@ def reproducibility(self, value):
184184
if value < 1:
185185
logger.warning("reproducibility cannot be less than 1, setting to 1")
186186
value = 1
187+
elif value % 2 == 0:
188+
logger.warning(
189+
"reproducibility level cannot be set to even number, setting to odd"
190+
)
191+
value += 1
187192
self._reproducibility = value
188193

189194
def __post_init__(self):

src/ragas/metrics/base.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from __future__ import annotations
99

1010
import asyncio
11+
import logging
1112
import typing as t
1213
from abc import ABC, abstractmethod
1314
from collections import Counter
@@ -26,6 +27,9 @@
2627
from pysbd import Segmenter
2728
from pysbd.languages import LANGUAGE_CODES
2829

30+
logger = logging.getLogger(__name__)
31+
32+
2933
LANGUAGE_CODES = {v.__name__.lower(): k for k, v in LANGUAGE_CODES.items()}
3034

3135
EvaluationMode = Enum("EvaluationMode", "qac qa qc gc ga qga qcg")
@@ -173,13 +177,17 @@ def from_discrete(self, inputs: list[list[t.Dict]], attribute: str):
173177
Simple majority voting for binary values, ie [0,0,1] -> 0
174178
inputs: list of list of dicts each containing verdict for a single input
175179
"""
176-
assert all(
177-
len(item) == len(inputs[0]) for item in inputs
178-
), "all inputs must have the same length"
179180

180-
assert all(
181-
attribute in item for input in inputs for item in input
182-
), "attribute not found in all items"
181+
if not isinstance(inputs, list):
182+
inputs = [inputs]
183+
184+
if not all(len(item) == len(inputs[0]) for item in inputs):
185+
logger.warning("All inputs must have the same length")
186+
return inputs[0]
187+
188+
if not all(attribute in item for input in inputs for item in input):
189+
logger.warning(f"All inputs must have {attribute} attribute")
190+
return inputs[0]
183191

184192
if len(inputs) == 1:
185193
return inputs[0]

0 commit comments

Comments
 (0)