
Commit 60b9e7c

fix: cleaned up some metrics (#2111)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
1 parent 437925f commit 60b9e7c

File tree: 13 files changed (+360, -756 lines)


experimental/ragas_experimental/cli.py

Lines changed: 1 addition & 1 deletion

@@ -550,7 +550,7 @@ def hello_world(
     )
 
 
-    @numeric_metric(name="accuracy_score", range=(0, 1))
+    @numeric_metric(name="accuracy_score", allowed_values=(0, 1))
     def accuracy_score(response: str, expected: str):
         """
        Is the response a good response to the query?
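
Note: the decorator keyword changes here from range to allowed_values. Below is a minimal usage sketch, not taken from cli.py; the import path and the .score(**kwargs) call convention are assumptions, modelled on the BaseMetric interface introduced later in this commit.

    from ragas_experimental.metrics import numeric_metric  # assumed import path

    @numeric_metric(name="accuracy_score", allowed_values=(0, 1))
    def accuracy_score(response: str, expected: str):
        """Is the response a good response to the query?"""
        # placeholder scoring logic, for illustration only
        return 1.0 if response.strip() == expected.strip() else 0.0

    # assumed call convention, mirroring BaseMetric.score(**kwargs) -> MetricResult
    result = accuracy_score.score(response="Paris", expected="Paris")
    print(result.value)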

experimental/ragas_experimental/dataset.py

Lines changed: 15 additions & 0 deletions

@@ -211,6 +211,21 @@ def save(self) -> None:
         else:
             self.backend.save_dataset(self.name, dict_data, data_model=self.data_model)
 
+    def reload(self) -> None:
+        # Backend always returns dicts
+        # Use the correct backend method based on the class type
+        if hasattr(self, "DATATABLE_TYPE") and self.DATATABLE_TYPE == "Experiment":
+            dict_data = self.backend.load_experiment(self.name)
+        else:
+            dict_data = self.backend.load_dataset(self.name)
+
+        if self.data_model:
+            # Validated mode - convert dicts to Pydantic models
+            self._data = [self.data_model(**d) for d in dict_data]
+        else:
+            # Unvalidated mode - keep as dicts but wrapped in Dataset API
+            self._data = dict_data  # type: ignore
+
     def validate_with(self, data_model: t.Type[T]) -> Self:
         """Apply validation to an unvalidated dataset"""
         if self.data_model is not None:
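
The new reload() mirrors save() in the opposite direction: it pulls the backend copy of the rows back into _data, converting them to the Pydantic data_model when one is attached. A short usage sketch, assuming ds is an existing Dataset already bound to a backend (construction not shown):

    # `ds` is an existing ragas_experimental Dataset instance (creation omitted)
    ds.save()        # push the in-memory rows to the configured backend
    # ... the backend copy may change elsewhere (another process, manual edits) ...
    ds.reload()      # pull the backend copy back into memory
    print(len(ds))   # refreshed row count; Metric.align() relies on len(dataset) below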

experimental/ragas_experimental/metric/__init__.py

Lines changed: 0 additions & 13 deletions
This file was deleted.

experimental/ragas_experimental/metric/result.py

Lines changed: 0 additions & 248 deletions
This file was deleted.

Lines changed: 16 additions & 0 deletions

@@ -0,0 +1,16 @@
+from .base import Metric
+from .discrete import DiscreteMetric, discrete_metric
+from .numeric import NumericMetric, numeric_metric
+from .ranking import RankingMetric, ranking_metric
+from .result import MetricResult
+
+__all__ = [
+    "MetricResult",
+    "Metric",
+    "DiscreteMetric",
+    "NumericMetric",
+    "RankingMetric",
+    "discrete_metric",
+    "numeric_metric",
+    "ranking_metric",
+]
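
This hunk adds the __init__ of the new package (presumably ragas_experimental.metrics, given the metric/base.py to metrics/base.py rename below). It consolidates the public surface that the deleted metric package used to provide, so imports would look roughly like this, with the package path being an assumption:

    # assumed package path after the metric -> metrics rename
    from ragas_experimental.metrics import (
        Metric,
        MetricResult,
        DiscreteMetric,
        NumericMetric,
        RankingMetric,
        discrete_metric,
        numeric_metric,
        ranking_metric,
    )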

experimental/ragas_experimental/metric/base.py renamed to experimental/ragas_experimental/metrics/base.py

Lines changed: 45 additions & 3 deletions

@@ -21,6 +21,37 @@
     from ragas_experimental.dataset import Dataset
 
 
+@dataclass
+class BaseMetric(ABC):
+    name: str
+
+    @abstractmethod
+    def score(self, **kwargs) -> MetricResult:
+        pass
+
+    @abstractmethod
+    async def ascore(self, **kwargs) -> MetricResult:
+        pass
+
+    def batch_score(
+        self,
+        inputs: t.List[t.Dict[str, t.Any]],
+    ) -> t.List[MetricResult]:
+        return [self.score(**input_dict) for input_dict in inputs]
+
+    async def abatch_score(
+        self,
+        inputs: t.List[t.Dict[str, t.Any]],
+    ) -> t.List[MetricResult]:
+        async_tasks = []
+        for input_dict in inputs:
+            # Process input asynchronously
+            async_tasks.append(self.ascore(**input_dict))
+
+        # Run all tasks concurrently and return results
+        return await asyncio.gather(*async_tasks)
+
+
 @dataclass
 class Metric(ABC):
     """Base class for all metrics in the LLM evaluation library."""

@@ -48,7 +79,12 @@ def get_variables(self) -> t.List[str]:
     def score(self, llm: RagasLLM, **kwargs) -> MetricResult:
         traces = {}
         traces["input"] = kwargs
+
+        # get prompt
+        if not self.prompt:
+            raise Exception("prompt not passed")
         prompt_input = self.prompt.format(**kwargs)
+
         response = llm.generate(prompt_input, response_model=self._response_model)
         traces["output"] = response.model_dump()
         result = MetricResult(**response.model_dump())

@@ -58,7 +94,11 @@ def score(self, llm: RagasLLM, **kwargs) -> MetricResult:
     async def ascore(self, llm: RagasLLM, **kwargs) -> MetricResult:
         traces = {}
 
+        # get prompt
+        if not self.prompt:
+            raise Exception("prompt not passed")
         prompt_input = self.prompt.format(**kwargs)
+
         traces["input"] = prompt_input
         response = await llm.agenerate(
             prompt_input,

@@ -137,11 +177,13 @@ def align(
         Align the metric with the specified experiments by different optimization methods.
         """
 
-        assert isinstance(self.prompt, Prompt)
+        # get prompt
+        if not self.prompt:
+            raise Exception("prompt not passed")
         self.prompt = DynamicFewShotPrompt.from_prompt(
             self.prompt, embedding_model, **kwargs
         )
-        dataset.load()
+        dataset.reload()
         total_items = len(dataset)
         input_vars = self.get_variables()
         output_vars = [self.name, f"{self.name}_reason"]

@@ -188,7 +230,7 @@ def validate_alignment(
             for v in self.get_variables()
         }
         score = self.score(llm=llm, **values)
-        pred_scores.append(score.result)
+        pred_scores.append(score.value)
 
         df = test_dataset.to_pandas()
         df[f"{self.name}_pred"] = pred_scores
