
Commit 99d1a47

feat: train your own evaluators (#1701)
- [x] Config
- [x] Loss
- [x] Optimizer base
1 parent 643ab66 commit 99d1a47

File tree

7 files changed: 208 additions, 0 deletions

src/ragas/config.py

Lines changed: 27 additions & 0 deletions
import typing as t

from pydantic import BaseModel, Field

from ragas.embeddings import BaseRagasEmbeddings
from ragas.llms import BaseRagasLLM
from ragas.losses import Loss
from ragas.optimizers import Optimizer

DEFAULT_OPTIMIZER_CONFIG = {"max_steps": 100}


class DemonstrationConfig(BaseModel):
    enabled: bool = True
    top_k: int = 3
    technique: t.Literal["random", "similarity"] = "similarity"
    embedding: t.Optional[BaseRagasEmbeddings] = None


class InstructionConfig(BaseModel):
    enabled: bool = True
    loss: t.Optional[Loss] = None
    optimizer: Optimizer
    optimizer_config: t.Dict[str, t.Any] = Field(
        default_factory=lambda: DEFAULT_OPTIMIZER_CONFIG
    )
    llm: t.Optional[BaseRagasLLM] = None
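
A minimal usage sketch for these two config models, not part of the commit. DemonstrationConfig can be built directly from its defaults; InstructionConfig requires a concrete Optimizer implementation, which this commit does not yet ship, so that part is left as a commented placeholder (MyOptimizer is hypothetical):

# Hypothetical usage sketch, assuming ragas with this commit installed.
from ragas.config import DemonstrationConfig, InstructionConfig
from ragas.losses import MSELoss

# Few-shot demonstration selection: keep the 3 most similar annotated examples.
demo_config = DemonstrationConfig(top_k=3, technique="similarity")

# Instruction optimization needs a concrete Optimizer subclass; `MyOptimizer`
# below is a placeholder, so the construction stays commented out.
# instruction_config = InstructionConfig(
#     loss=MSELoss(),
#     optimizer=MyOptimizer(),
#     optimizer_config={"max_steps": 50},
# )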

src/ragas/dataset_schema.py

Lines changed: 1 addition & 0 deletions
The only change here is a single blank line inserted before PromptAnnotation:

@@ -531,6 +531,7 @@ def upload(self, base_url: str = RAGAS_API_URL, verbose: bool = True) -> str:
         return evaluation_endpoint


+
 class PromptAnnotation(BaseModel):
     prompt_input: t.Dict[str, t.Any]
     prompt_output: t.Dict[str, t.Any]

src/ragas/embeddings/base.py

Lines changed: 14 additions & 0 deletions
@@ -10,11 +10,14 @@
 from langchain_core.embeddings import Embeddings
 from langchain_openai.embeddings import OpenAIEmbeddings
 from pydantic.dataclasses import dataclass
+from pydantic_core import CoreSchema, core_schema

 from ragas.run_config import RunConfig, add_async_retry, add_retry

 if t.TYPE_CHECKING:
     from llama_index.core.base.embeddings.base import BaseEmbedding
+    from pydantic import GetCoreSchemaHandler
+

 DEFAULT_MODEL_NAME = "BAAI/bge-small-en-v1.5"

@@ -64,6 +67,17 @@ def set_run_config(self, run_config: RunConfig):
         """
         self.run_config = run_config

+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls, source_type: t.Any, handler: GetCoreSchemaHandler
+    ) -> CoreSchema:
+        """
+        Define how Pydantic generates a schema for BaseRagasEmbeddings.
+        """
+        return core_schema.no_info_after_validator_function(
+            cls, core_schema.is_instance_schema(cls)  # The validator function
+        )
+

 class LangchainEmbeddingsWrapper(BaseRagasEmbeddings):
     """

src/ragas/losses.py

Lines changed: 99 additions & 0 deletions
import typing as t
from abc import ABC, abstractmethod


class Loss(ABC):
    """
    Abstract base class for all loss functions.
    """

    @abstractmethod
    def __call__(self, predicted: t.List, actual: t.List) -> float:
        raise NotImplementedError


class MSELoss(Loss):
    """
    Mean Squared Error loss function.
    """

    reduction: t.Literal["mean", "sum"] = "mean"

    def __call__(self, predicted: t.List[float], actual: t.List[float]) -> float:
        errors = [(p - a) ** 2 for p, a in zip(predicted, actual)]
        if self.reduction == "mean":
            return sum(errors) / len(errors)
        elif self.reduction == "sum":
            return sum(errors)
        else:
            raise ValueError(f"Invalid reduction method: {self.reduction}")


class BinaryMetricLoss(Loss):
    """
    Computes the loss for binary metrics.
    Supports accuracy and F1-score.
    """

    metric: t.Literal["accuracy", "f1_score"] = "accuracy"

    def __call__(self, predicted: t.List[int], actual: t.List[int]) -> float:
        """
        Computes the loss using the specified metric.

        Parameters
        ----------
        predicted : list[int]
            List of predicted binary values (0 or 1).
        actual : list[int]
            List of actual binary values (0 or 1).

        Returns
        -------
        float
            The computed loss based on the chosen metric.
        """
        if len(predicted) != len(actual):
            raise ValueError("Predicted and actual lists must have the same length.")

        if self.metric == "accuracy":
            return self._accuracy(predicted, actual)
        elif self.metric == "f1_score":
            return self._f1_score(predicted, actual)
        else:
            raise ValueError(f"Unsupported metric type: {self.metric}")

    def _accuracy(self, predicted: t.List[int], actual: t.List[int]) -> float:
        """
        Computes accuracy.

        Returns
        -------
        float
            Accuracy (proportion of correct predictions).
        """
        correct = sum(p == a for p, a in zip(predicted, actual))
        return correct / len(actual)

    def _f1_score(self, predicted: t.List[int], actual: t.List[int]) -> float:
        """
        Computes the F1-score.

        Returns
        -------
        float
            The F1-score.
        """
        tp = sum(p == 1 and a == 1 for p, a in zip(predicted, actual))
        fp = sum(p == 1 and a == 0 for p, a in zip(predicted, actual))
        fn = sum(p == 0 and a == 1 for p, a in zip(predicted, actual))

        precision = tp / (tp + fp) if tp + fp > 0 else 0
        recall = tp / (tp + fn) if tp + fn > 0 else 0
        f1 = (
            (2 * precision * recall) / (precision + recall)
            if precision + recall > 0
            else 0
        )
        return f1
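
A quick sanity check of the two losses on toy numbers, worked out by hand (not part of the commit):

from ragas.losses import BinaryMetricLoss, MSELoss

mse = MSELoss()                           # reduction defaults to "mean"
print(mse([0.9, 0.4], [1.0, 0.0]))        # ((-0.1)**2 + 0.4**2) / 2 ≈ 0.085

acc = BinaryMetricLoss()                  # metric defaults to "accuracy"
print(acc([1, 0, 1, 1], [1, 0, 0, 1]))    # 3 of 4 correct -> 0.75

f1 = BinaryMetricLoss()
f1.metric = "f1_score"                    # class-level attribute, no __init__ argument
print(f1([1, 1, 0], [1, 0, 0]))           # tp=1, fp=1, fn=0 -> P=0.5, R=1.0, F1 ≈ 0.667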

src/ragas/metrics/base.py

Lines changed: 12 additions & 0 deletions
@@ -26,8 +26,10 @@
 if t.TYPE_CHECKING:
     from langchain_core.callbacks import Callbacks

+    from ragas.config import DemonstrationConfig, InstructionConfig
     from ragas.embeddings import BaseRagasEmbeddings
     from ragas.llms import BaseRagasLLM
+
 logger = logging.getLogger(__name__)

@@ -227,6 +229,16 @@ def init(self, run_config: RunConfig):
             )
         self.llm.set_run_config(run_config)

+    def train(
+        self,
+        path: str,
+        demonstration_config: DemonstrationConfig,
+        instruction_config: InstructionConfig,
+        callbacks: Callbacks,
+    ) -> None:
+        raise NotImplementedError("Training is not implemented for this metric.")
+

 @dataclass
 class MetricWithEmbeddings(Metric):
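
In this commit the new train() hook only fixes the call shape; the base implementation raises NotImplementedError, so a metric has to override it before training can run. A hedged sketch of how a caller would eventually invoke it (the metric object, the annotation file path, and instruction_config are placeholders):

from ragas.config import DemonstrationConfig, InstructionConfig

demo_config = DemonstrationConfig()
# instruction_config = InstructionConfig(optimizer=...)  # needs a concrete Optimizer

# metric.train(
#     path="annotated_samples.json",        # hypothetical path to annotated data
#     demonstration_config=demo_config,
#     instruction_config=instruction_config,
#     callbacks=None,
# )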

src/ragas/optimizers/__init__.py

Lines changed: 3 additions & 0 deletions
from .base import Optimizer

__all__ = ["Optimizer"]

src/ragas/optimizers/base.py

Lines changed: 52 additions & 0 deletions
import typing as t
from abc import ABC, abstractmethod
from dataclasses import dataclass

from langchain_core.callbacks import Callbacks

from ragas.dataset_schema import SingleMetricAnnotation
from ragas.llms.base import BaseRagasLLM
from ragas.losses import Loss
from ragas.metrics.base import MetricWithLLM
from ragas.run_config import RunConfig


@dataclass
class Optimizer(ABC):
    """
    Abstract base class for all optimizers.
    """

    metric: t.Optional[MetricWithLLM] = None
    llm: t.Optional[BaseRagasLLM] = None

    @abstractmethod
    def optimize(
        self,
        dataset: SingleMetricAnnotation,
        loss: Loss,
        config: t.Dict[t.Any, t.Any],
        run_config: t.Optional[RunConfig] = None,
        batch_size: t.Optional[int] = None,
        callbacks: t.Optional[Callbacks] = None,
        with_debugging_logs: bool = False,
        raise_exceptions: bool = True,
    ) -> t.Dict[str, str]:
        """
        Optimizes the prompts for the given metric.

        Parameters
        ----------
        dataset : SingleMetricAnnotation
            The annotated training data.
        loss : Loss
            The loss function to minimize.
        config : dict
            The optimizer configuration.

        Returns
        -------
        Dict[str, str]
            The optimized prompts for the given chain.
        """
        pass
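
A minimal, hypothetical concrete subclass, only to show what implementers must provide; it is not part of the commit and simply returns no updated prompts:

import typing as t

from langchain_core.callbacks import Callbacks

from ragas.losses import Loss
from ragas.optimizers import Optimizer
from ragas.run_config import RunConfig


class NoOpOptimizer(Optimizer):
    """Placeholder optimizer that performs no search and changes no prompts."""

    def optimize(
        self,
        dataset,
        loss: Loss,
        config: t.Dict[t.Any, t.Any],
        run_config: t.Optional[RunConfig] = None,
        batch_size: t.Optional[int] = None,
        callbacks: t.Optional[Callbacks] = None,
        with_debugging_logs: bool = False,
        raise_exceptions: bool = True,
    ) -> t.Dict[str, str]:
        # A real optimizer would score candidate instructions against `dataset`
        # using `loss` and return the best instruction per prompt name.
        return {}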
