
Commit 8307f66

feat: genetic algo based optimizer (#1724)
```python
metric = AspectCritic(
    name="answer_correctness",
    definition="Given the user_input, reference and response. Is the response correct compared with the reference",
    llm=llm_4o,
)
metric.train("alignment_sample.json")
```

[dummy data](https://github.com/user-attachments/files/17997460/alignment_sample.json)
1 parent e2cb28e commit 8307f66

11 files changed: +922 −29 lines changed

src/ragas/callbacks.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -133,12 +133,15 @@ def __str__(self):
 
 def parse_run_traces(
     traces: t.Dict[str, ChainRun],
+    parent_run_id: t.Optional[str] = None,
 ) -> t.List[t.Dict[str, t.Any]]:
+
     root_traces = [
         chain_trace
         for chain_trace in traces.values()
-        if chain_trace.parent_run_id is None
+        if chain_trace.parent_run_id == parent_run_id
     ]
+
     if len(root_traces) > 1:
         raise ValueError(
             "Multiple root traces found! This is a bug on our end, please file an issue and we will fix it ASAP :)"
@@ -159,7 +162,7 @@ def parse_run_traces(
         prompt_traces = {}
         for i, prompt_uuid in enumerate(metric_trace.children):
             prompt_trace = traces[prompt_uuid]
-            prompt_traces[f"{i}_{prompt_trace.name}"] = {
+            prompt_traces[f"{prompt_trace.name}"] = {
                 "input": prompt_trace.inputs.get("data", {}),
                 "output": prompt_trace.outputs.get("output", {}),
             }
```
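The new `parent_run_id` parameter scopes trace parsing to one evaluation run instead of assuming the only root is a parentless trace; `EvaluationResult` (below) passes its `run_id` through. A minimal standalone sketch of the filter, using a hypothetical `FakeChainRun` stand-in rather than the real `ragas.callbacks.ChainRun`:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeChainRun:
    """Hypothetical stand-in for ragas.callbacks.ChainRun."""
    name: str
    parent_run_id: Optional[str]

traces = {
    "a": FakeChainRun("evaluation", parent_run_id=None),
    "b": FakeChainRun("evaluation", parent_run_id="run-123"),
}

# Old behaviour: only parent_run_id=None could match as a root.
# New behaviour: the caller picks the root belonging to a specific run.
parent_run_id = "run-123"
root_traces = [t for t in traces.values() if t.parent_run_id == parent_run_id]
assert len(root_traces) == 1 and root_traces[0].parent_run_id == "run-123"
```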

src/ragas/config.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -5,7 +5,7 @@
 from ragas.embeddings import BaseRagasEmbeddings
 from ragas.llms import BaseRagasLLM
 from ragas.losses import Loss
-from ragas.optimizers import Optimizer
+from ragas.optimizers import GeneticOptimizer, Optimizer
 
 DEFAULT_OPTIMIZER_CONFIG = {"max_steps": 100}
 
@@ -20,7 +20,7 @@ class DemonstrationConfig(BaseModel):
 class InstructionConfig(BaseModel):
     enabled: bool = True
     loss: t.Optional[Loss] = None
-    optimizer: Optimizer
+    optimizer: Optimizer = GeneticOptimizer()
     optimizer_config: t.Dict[str, t.Any] = Field(
         default_factory=lambda: DEFAULT_OPTIMIZER_CONFIG
     )
```
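Because `optimizer` now defaults to `GeneticOptimizer()`, an `InstructionConfig` can be built with no arguments; `loss` stays `None` and is inferred later from the metric's output type (see `metrics/base.py` below). A sketch, assuming only the names visible in this diff:

```python
from ragas.config import DEFAULT_OPTIMIZER_CONFIG, InstructionConfig

config = InstructionConfig()  # previously required an explicit optimizer
print(type(config.optimizer).__name__)  # -> GeneticOptimizer
print(config.optimizer_config)          # -> {"max_steps": 100}
assert config.optimizer_config == DEFAULT_OPTIMIZER_CONFIG
```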

src/ragas/dataset_schema.py

Lines changed: 11 additions & 2 deletions
```diff
@@ -6,6 +6,7 @@
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass, field
+from uuid import UUID
 
 import numpy as np
 from datasets import Dataset as HFDataset
@@ -43,6 +44,13 @@ def get_features(self) -> t.List[str]:
         """
         return list(self.to_dict().keys())
 
+    def to_string(self) -> str:
+        """
+        Get the string representation of the sample.
+        """
+        sample_dict = self.to_dict()
+        return "".join(f"\n{key}:\n\t{val}\n" for key, val in sample_dict.items())
+
 
 class SingleTurnSample(BaseSample):
     """
@@ -378,6 +386,7 @@ class EvaluationResult:
     cost_cb: t.Optional[CostCallbackHandler] = None
     traces: t.List[t.Dict[str, t.Any]] = field(default_factory=list)
     ragas_traces: t.Dict[str, ChainRun] = field(default_factory=dict, repr=False)
+    run_id: t.Optional[UUID] = None
 
     def __post_init__(self):
         # transform scores from list of dicts to dict of lists
@@ -395,7 +404,8 @@ def __post_init__(self):
             values.append(value + 1e-10)
 
         # parse the traces
-        self.traces = parse_run_traces(self.ragas_traces)
+        run_id = str(self.run_id) if self.run_id is not None else None
+        self.traces = parse_run_traces(self.ragas_traces, run_id)
 
     def __repr__(self) -> str:
         score_strs = [f"'{k}': {v:0.4f}" for k, v in self._repr_dict.items()]
@@ -531,7 +541,6 @@ def upload(self, base_url: str = RAGAS_API_URL, verbose: bool = True) -> str:
         return evaluation_endpoint
 
 
-
 class PromptAnnotation(BaseModel):
     prompt_input: t.Dict[str, t.Any]
     prompt_output: t.Dict[str, t.Any]
```
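`to_string` renders a sample as an indented key/value listing. The formatting can be reproduced standalone (hypothetical sample data, independent of `BaseSample`):

```python
sample_dict = {
    "user_input": "What is the capital of France?",
    "response": "Paris",
}

# The same join BaseSample.to_string() applies to self.to_dict()
text = "".join(f"\n{key}:\n\t{val}\n" for key, val in sample_dict.items())
print(text)
# user_input:
#     What is the capital of France?
#
# response:
#     Paris
```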

src/ragas/evaluation.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -1,11 +1,13 @@
 from __future__ import annotations
 
 import typing as t
+from uuid import UUID
 
 from datasets import Dataset
 from langchain_core.callbacks import BaseCallbackHandler, BaseCallbackManager
 from langchain_core.embeddings import Embeddings as LangchainEmbeddings
 from langchain_core.language_models import BaseLanguageModel as LangchainLLM
+from tqdm.auto import tqdm
 
 from ragas._analytics import track_was_completed
 from ragas.callbacks import ChainType, RagasTracer, new_group
@@ -59,12 +61,14 @@ def evaluate(
     embeddings: t.Optional[BaseRagasEmbeddings | LangchainEmbeddings] = None,
     callbacks: Callbacks = None,
     in_ci: bool = False,
-    run_config: RunConfig = RunConfig(),
+    run_config: t.Optional[RunConfig] = None,
     token_usage_parser: t.Optional[TokenUsageParser] = None,
     raise_exceptions: bool = False,
     column_map: t.Optional[t.Dict[str, str]] = None,
     show_progress: bool = True,
     batch_size: t.Optional[int] = None,
+    _run_id: t.Optional[UUID] = None,
+    _pbar: t.Optional[tqdm] = None,
 ) -> EvaluationResult:
     """
     Run the evaluation on the dataset with different metrics
@@ -146,6 +150,7 @@ def evaluate(
     """
     column_map = column_map or {}
     callbacks = callbacks or []
+    run_config = run_config or RunConfig()
 
     if helicone_config.is_enabled:
         import uuid
@@ -226,6 +231,7 @@ def evaluate(
         run_config=run_config,
         show_progress=show_progress,
         batch_size=batch_size,
+        pbar=_pbar,
     )
 
     # Ragas Callbacks
@@ -333,6 +339,7 @@ def evaluate(
             cost_cb,
         ),
         ragas_traces=tracer.traces,
+        run_id=_run_id,
     )
     if not evaluation_group_cm.ended:
         evaluation_rm.on_chain_end({"scores": result.scores})
```
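Swapping `run_config: RunConfig = RunConfig()` for an `Optional[RunConfig]` avoids Python's shared-mutable-default pitfall: a default built in the signature is constructed once at import and reused by every call. A generic sketch of the pattern (with a hypothetical `RunCfg`, not the real `RunConfig`):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class RunCfg:
    """Hypothetical stand-in for ragas.run_config.RunConfig."""
    timeout: int = 60

def evaluate_old(cfg: RunCfg = RunCfg()):
    return cfg  # single instance created at definition time, shared

def evaluate_new(cfg: Optional[RunCfg] = None):
    cfg = cfg or RunCfg()  # fresh instance per call, as in the diff
    return cfg

assert evaluate_old() is evaluate_old()      # same object every call
assert evaluate_new() is not evaluate_new()  # distinct objects
```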

src/ragas/executor.py

Lines changed: 24 additions & 14 deletions
```diff
@@ -84,6 +84,7 @@ class Executor:
     batch_size: t.Optional[int] = None
     run_config: t.Optional[RunConfig] = field(default=None, repr=False)
     _nest_asyncio_applied: bool = field(default=False, repr=False)
+    pbar: t.Optional[tqdm] = None
 
     def wrap_callable_with_index(
         self, callable: t.Callable, counter: int
@@ -130,21 +131,22 @@ async def _process_jobs(self) -> t.List[t.Any]:
         results = []
 
         if not self.batch_size:
-            with tqdm(
-                total=len(self.jobs),
-                desc=self.desc,
-                disable=not self.show_progress,
-            ) as pbar:
-                # Create coroutines
-                coroutines = [
-                    afunc(*args, **kwargs) for afunc, args, kwargs, _ in self.jobs
-                ]
-                for future in await as_completed(coroutines, max_workers):
-                    result = await future
-                    results.append(result)
-                    pbar.update(1)
+            # Use external progress bar if provided, otherwise create one
+            if self.pbar is None:
+                with tqdm(
+                    total=len(self.jobs),
+                    desc=self.desc,
+                    disable=not self.show_progress,
+                ) as internal_pbar:
+                    await self._process_coroutines(
+                        self.jobs, internal_pbar, results, max_workers
+                    )
+            else:
+                await self._process_coroutines(
+                    self.jobs, self.pbar, results, max_workers
+                )
 
-            return results
+            return results
 
         # With batching, show nested progress bars
         batches = batched(self.jobs, self.batch_size)  # generator of job tuples
@@ -182,6 +184,14 @@ async def _process_jobs(self) -> t.List[t.Any]:
 
         return results
 
+    async def _process_coroutines(self, jobs, pbar, results, max_workers):
+        """Helper function to process coroutines and update the progress bar."""
+        coroutines = [afunc(*args, **kwargs) for afunc, args, kwargs, _ in jobs]
+        for future in await as_completed(coroutines, max_workers):
+            result = await future
+            results.append(result)
+            pbar.update(1)
+
     def results(self) -> t.List[t.Any]:
         """
         Execute all submitted jobs and return their results. The results are returned in the order of job submission.
```
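The `pbar` field lets a caller drive many `Executor` runs from one shared progress bar, which is what `evaluate(..., _pbar=...)` forwards so the optimizer can show a single bar across repeated evaluations. The general pattern, sketched with plain coroutines rather than the `Executor` API:

```python
import asyncio
from tqdm.auto import tqdm

async def job(i: int) -> int:
    await asyncio.sleep(0.01)
    return i * i

async def run_jobs(coros, pbar):
    # Mirror _process_coroutines: advance the caller's bar per finished job
    results = []
    for fut in asyncio.as_completed(coros):
        results.append(await fut)
        pbar.update(1)
    return results

async def main():
    with tqdm(total=6, desc="optimizing") as shared_pbar:
        # Two separate "runs" advance the same external bar
        await run_jobs([job(i) for i in range(3)], shared_pbar)
        await run_jobs([job(i) for i in range(3, 6)], shared_pbar)

asyncio.run(main())
```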

src/ragas/losses.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -1,6 +1,9 @@
 import typing as t
 from abc import ABC, abstractmethod
 
+from pydantic import GetCoreSchemaHandler
+from pydantic_core import CoreSchema, core_schema
+
 
 class Loss(ABC):
     """
@@ -11,6 +14,17 @@ class Loss(ABC):
     def __call__(self, predicted: t.List, actual: t.List) -> float:
         raise NotImplementedError
 
+    @classmethod
+    def __get_pydantic_core_schema__(
+        cls, source_type: t.Any, handler: GetCoreSchemaHandler
+    ) -> CoreSchema:
+        """
+        Define how Pydantic generates a schema for BaseRagasEmbeddings.
+        """
+        return core_schema.no_info_after_validator_function(
+            cls, core_schema.is_instance_schema(cls)  # The validator function
+        )
+
 
 class MSELoss(Loss):
     """
```

src/ragas/metrics/base.py

Lines changed: 71 additions & 5 deletions
```diff
@@ -12,8 +12,9 @@
 
 from ragas._analytics import EvaluationEvent, _analytics_batcher
 from ragas.callbacks import ChainType, new_group
-from ragas.dataset_schema import MultiTurnSample, SingleTurnSample
+from ragas.dataset_schema import MetricAnnotation, MultiTurnSample, SingleTurnSample
 from ragas.executor import is_event_loop_running
+from ragas.losses import BinaryMetricLoss, MSELoss
 from ragas.prompt import PromptMixin
 from ragas.run_config import RunConfig
 from ragas.utils import (
@@ -232,12 +233,77 @@ def init(self, run_config: RunConfig):
     def train(
         self,
         path: str,
-        demonstration_config: DemonstrationConfig,
-        instruction_config: InstructionConfig,
-        callbacks: Callbacks,
+        demonstration_config: t.Optional[DemonstrationConfig] = None,
+        instruction_config: t.Optional[InstructionConfig] = None,
+        callbacks: t.Optional[Callbacks] = None,
+        run_config: t.Optional[RunConfig] = None,
+        batch_size: t.Optional[int] = None,
+        with_debugging_logs=False,
+        raise_exceptions: bool = True,
     ) -> None:
 
-        raise NotImplementedError("Training is not implemented for this metric.")
+        if not path.endswith(".json"):
+            raise ValueError("Train data must be in json format")
+
+        if instruction_config is None:
+            from ragas.config import InstructionConfig
+
+            instruction_config = InstructionConfig()
+
+        if demonstration_config is None:
+            from ragas.config import DemonstrationConfig
+
+            demonstration_config = DemonstrationConfig()
+
+        dataset = MetricAnnotation.from_json(path, metric_name=self.name)
+
+        optimizer = instruction_config.optimizer
+        llm = instruction_config.llm or self.llm
+        if llm is None:
+            raise ValueError(
+                f"Metric '{self.name}' has no valid LLM provided (self.llm is None). Please initantiate a the metric with an LLM to run."  # noqa
+            )
+        if optimizer.llm is None:
+            optimizer.llm = llm
+
+        if instruction_config.loss is None:
+            if self.output_type is None:
+                raise ValueError(
+                    f"Output type for metric '{self.name}' is not defined. Please set the output type in the metric or in the instruction config."
+                )
+
+            if self.output_type.name == MetricOutputType.BINARY.name:
+                loss_fun = BinaryMetricLoss()
+            elif (
+                self.output_type.name == MetricOutputType.CONTINUOUS.name
+                or self.output_type.name == MetricOutputType.DISCRETE.name
+            ):
+                loss_fun = MSELoss()
+            else:
+                raise NotImplementedError(
+                    f"Output type '{self.output_type.name}' not implemented"
+                )
+        else:
+            loss_fun = instruction_config.loss
+
+        optimizer.metric = self
+
+        optimizer_config = instruction_config.optimizer_config or {}
+        optimized_prompts = optimizer.optimize(
+            dataset[self.name],
+            loss_fun,
+            optimizer_config,
+            callbacks=callbacks,
+            run_config=run_config,
+            batch_size=batch_size,
+            with_debugging_logs=with_debugging_logs,
+            raise_exceptions=raise_exceptions,
+        )
+        prompts = self.get_prompts()
+        for key, val in optimized_prompts.items():
+            prompts[key].instruction = val
+        self.set_prompts(**prompts)
+        return
 
 
 @dataclass
```
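Put together, this is what makes the one-liner from the commit message work: the JSON annotations are loaded via `MetricAnnotation.from_json`, a loss is inferred from the metric's output type (binary → `BinaryMetricLoss`, continuous/discrete → `MSELoss`), the default `GeneticOptimizer` rewrites the prompt instructions, and the result is written back with `set_prompts`. A usage sketch, assuming an already configured LLM wrapper named `llm_4o` (as in the commit message):

```python
from ragas.metrics import AspectCritic

metric = AspectCritic(
    name="answer_correctness",
    definition=(
        "Given the user_input, reference and response. "
        "Is the response correct compared with the reference"
    ),
    llm=llm_4o,  # assumed: a BaseRagasLLM wrapper defined elsewhere
)

# Defaults apply: InstructionConfig() with GeneticOptimizer(),
# loss inferred from the metric's binary output type.
metric.train("alignment_sample.json")

# The optimized instructions now live on the metric's prompts
for name, prompt in metric.get_prompts().items():
    print(name, prompt.instruction)
```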

src/ragas/optimizers/__init__.py

Lines changed: 6 additions & 2 deletions
```diff
@@ -1,3 +1,7 @@
-from .base import Optimizer
+from ragas.optimizers.base import Optimizer
+from ragas.optimizers.genetic import GeneticOptimizer
 
-__all__ = ["Optimizer"]
+__all__ = [
+    "Optimizer",
+    "GeneticOptimizer",
+]
```

src/ragas/optimizers/base.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -49,4 +49,4 @@ def optimize(
         Dict[str, str]
             The optimized prompts for given chain.
         """
-        pass
+        raise NotImplementedError("The method `optimize` must be implemented.")
```
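Raising `NotImplementedError` instead of `pass` (which silently returned `None`) makes an optimizer subclass that forgets to override `optimize` fail loudly at call time. A toy sketch of the contract (hypothetical classes, simplified signature):

```python
import typing as t

class ToyOptimizer:
    """Hypothetical mirror of the Optimizer base-class contract."""
    def optimize(self, dataset, loss, config) -> t.Dict[str, str]:
        raise NotImplementedError("The method `optimize` must be implemented.")

class NoopOptimizer(ToyOptimizer):
    def optimize(self, dataset, loss, config):
        # A real subclass (e.g. GeneticOptimizer) searches for better
        # prompt instructions here; this one returns a fixed mapping.
        return {"prompt_name": "original instruction"}

NoopOptimizer().optimize(None, None, {})   # -> dict of prompt name -> text
# ToyOptimizer().optimize(None, None, {})  # raises NotImplementedError
```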
