
Commit d432ed0

feat: few shot example optimizer (#1739)
Optimize with few-shot examples:

```py
from ragas.metrics import AspectCritic
from ragas.llms import llm_factory

# define metric
llm = llm_factory("gpt-4o")
metric = AspectCritic(
    name="answer_correctness",
    definition="Given the user_input, reference and response. Is the response correct compared with the reference",
    llm=llm,
)

# optimize with annotation
from ragas.config import DemonstrationConfig

demonstration_config = DemonstrationConfig()
metric.train(
    "alignment_sample.json",
    demonstration_config=demonstration_config,
)
```
1 parent 9f5cccc commit d432ed0

9 files changed: +334 −57 lines changed

pyproject.toml

Lines changed: 6 additions & 0 deletions
```diff
@@ -91,3 +91,9 @@ addopts = "-n 0"
 asyncio_default_fixture_loop_scope = "function"
 [pytest]
 testpaths = ["tests"]
+
+[dependency-groups]
+dev = [
+    "arize-phoenix>=6.1.0",
+    "openinference-instrumentation-langchain>=0.1.29",
+]
```

src/ragas/config.py

Lines changed: 17 additions & 5 deletions
```diff
@@ -1,27 +1,39 @@
+from __future__ import annotations
+
 import typing as t
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 
-from ragas.embeddings import BaseRagasEmbeddings
-from ragas.llms import BaseRagasLLM
+from ragas.embeddings.base import BaseRagasEmbeddings
+from ragas.llms.base import BaseRagasLLM
 from ragas.losses import Loss
 from ragas.optimizers import GeneticOptimizer, Optimizer
 
 DEFAULT_OPTIMIZER_CONFIG = {"max_steps": 100}
 
 
 class DemonstrationConfig(BaseModel):
+    embedding: t.Any  # this has to be of type Any because BaseRagasEmbedding is an ABC
     enabled: bool = True
     top_k: int = 3
+    threshold: float = 0.7
     technique: t.Literal["random", "similarity"] = "similarity"
-    embedding: t.Optional[BaseRagasEmbeddings] = None
+
+    @field_validator("embedding")
+    def validate_embedding(cls, v):
+        if not isinstance(v, BaseRagasEmbeddings):
+            raise ValueError("embedding must be an instance of BaseRagasEmbeddings")
+        return v
 
 
 class InstructionConfig(BaseModel):
+    llm: BaseRagasLLM
     enabled: bool = True
     loss: t.Optional[Loss] = None
     optimizer: Optimizer = GeneticOptimizer()
     optimizer_config: t.Dict[str, t.Any] = Field(
         default_factory=lambda: DEFAULT_OPTIMIZER_CONFIG
     )
-    llm: t.Optional[BaseRagasLLM] = None
+
+
+InstructionConfig.model_rebuild()
```
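With these changes, `DemonstrationConfig` carries the embedding model used for similarity-based example selection, and `InstructionConfig` carries the LLM used by the optimizer. A minimal construction sketch, assuming the `llm_factory` and `embedding_factory` helpers from ragas and placeholder model names:

```py
from ragas.config import DemonstrationConfig, InstructionConfig
from ragas.embeddings import embedding_factory  # assumed helper for a default embedding model
from ragas.llms import llm_factory

# embedding is validated to be a BaseRagasEmbeddings instance
demo_config = DemonstrationConfig(
    embedding=embedding_factory(),
    top_k=3,          # number of examples retrieved per prompt
    threshold=0.7,    # similarity cutoff for retrieved examples
)

# llm drives the instruction optimizer (GeneticOptimizer by default)
inst_config = InstructionConfig(llm=llm_factory("gpt-4o"))
```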

src/ragas/dataset_schema.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -554,7 +554,7 @@ class PromptAnnotation(BaseModel):
     prompt_input: t.Dict[str, t.Any]
     prompt_output: t.Dict[str, t.Any]
     is_accepted: bool
-    edited_output: t.Union[t.Dict[str, t.Any], None]
+    edited_output: t.Optional[t.Dict[str, t.Any]] = None
 
     def __getitem__(self, key):
         return getattr(self, key)
@@ -801,3 +801,13 @@ def stratified_batches(
             all_batches.append(batch)
 
         return all_batches
+
+    def get_prompt_annotations(self) -> t.Dict[str, t.List[PromptAnnotation]]:
+        """
+        Get all the prompt annotations for each prompt as a list.
+        """
+        prompt_annotations = defaultdict(list)
+        for sample in self.samples:
+            for prompt_name, prompt_annotation in sample.prompts.items():
+                prompt_annotations[prompt_name].append(prompt_annotation)
+        return prompt_annotations
```
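The new `get_prompt_annotations` helper simply regroups the per-sample annotations by prompt name. A standalone sketch of the same grouping over hypothetical annotation data (the prompt name and inputs are placeholders), to illustrate the shape of the result:

```py
from collections import defaultdict

from ragas.dataset_schema import PromptAnnotation

# two hypothetical samples, each annotating the same prompt
samples = [
    {
        "aspect_critic_prompt": PromptAnnotation(  # hypothetical prompt name
            prompt_input={"user_input": "What is the capital of France?"},
            prompt_output={"verdict": 1},
            is_accepted=True,
            edited_output=None,
        )
    },
    {
        "aspect_critic_prompt": PromptAnnotation(
            prompt_input={"user_input": "Who wrote Hamlet?"},
            prompt_output={"verdict": 0},
            is_accepted=True,
            edited_output={"verdict": 1},  # reviewer corrected the output
        )
    },
]

# the same grouping the new method performs over sample.prompts
prompt_annotations = defaultdict(list)
for prompts in samples:
    for prompt_name, annotation in prompts.items():
        prompt_annotations[prompt_name].append(annotation)

# -> {"aspect_critic_prompt": [PromptAnnotation, PromptAnnotation]}
```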

src/ragas/metrics/base.py

Lines changed: 118 additions & 34 deletions
```diff
@@ -8,14 +8,16 @@
 from dataclasses import dataclass, field
 from enum import Enum
 
+from pydantic import ValidationError
 from pysbd import Segmenter
+from tqdm import tqdm
 
 from ragas._analytics import EvaluationEvent, _analytics_batcher
 from ragas.callbacks import ChainType, new_group
 from ragas.dataset_schema import MetricAnnotation, MultiTurnSample, SingleTurnSample
 from ragas.executor import is_event_loop_running
 from ragas.losses import BinaryMetricLoss, MSELoss
-from ragas.prompt import PromptMixin
+from ragas.prompt import FewShotPydanticPrompt, PromptMixin
 from ragas.run_config import RunConfig
 from ragas.utils import (
     RAGAS_SUPPORTED_LANGUAGE_CODES,
@@ -230,48 +232,30 @@ def init(self, run_config: RunConfig):
             )
         self.llm.set_run_config(run_config)
 
-    def train(
+    def _optimize_instruction(
         self,
-        path: str,
-        demonstration_config: t.Optional[DemonstrationConfig] = None,
-        instruction_config: t.Optional[InstructionConfig] = None,
-        callbacks: t.Optional[Callbacks] = None,
-        run_config: t.Optional[RunConfig] = None,
-        batch_size: t.Optional[int] = None,
-        with_debugging_logs=False,
-        raise_exceptions: bool = True,
-    ) -> None:
-
-        if not path.endswith(".json"):
-            raise ValueError("Train data must be in json format")
-
-        if instruction_config is None:
-            from ragas.config import InstructionConfig
-
-            instruction_config = InstructionConfig()
-
-        if demonstration_config is None:
-            from ragas.config import DemonstrationConfig
-
-            demonstration_config = DemonstrationConfig()
-
-        dataset = MetricAnnotation.from_json(path, metric_name=self.name)
-
-        optimizer = instruction_config.optimizer
-        llm = instruction_config.llm or self.llm
-        if llm is None:
+        instruction_config: InstructionConfig,
+        dataset: MetricAnnotation,
+        callbacks: Callbacks,
+        run_config: RunConfig,
+        batch_size: t.Optional[int],
+        with_debugging_logs: bool,
+        raise_exceptions: bool,
+    ):
+        if self.llm is None:
             raise ValueError(
                 f"Metric '{self.name}' has no valid LLM provided (self.llm is None). Please initantiate a the metric with an LLM to run."  # noqa
             )
+        optimizer = instruction_config.optimizer
         if optimizer.llm is None:
-            optimizer.llm = llm
+            optimizer.llm = instruction_config.llm
 
+        # figure out the loss function
         if instruction_config.loss is None:
             if self.output_type is None:
                 raise ValueError(
                     f"Output type for metric '{self.name}' is not defined. Please set the output type in the metric or in the instruction config."
                 )
-
             if self.output_type.name == MetricOutputType.BINARY.name:
                 loss_fun = BinaryMetricLoss()
             elif (
@@ -286,8 +270,8 @@ def train(
         else:
             loss_fun = instruction_config.loss
 
+        # Optimize the prompts
         optimizer.metric = self
-
        optimizer_config = instruction_config.optimizer_config or {}
         optimized_prompts = optimizer.optimize(
             dataset[self.name],
@@ -299,11 +283,111 @@ def train(
             with_debugging_logs=with_debugging_logs,
             raise_exceptions=raise_exceptions,
         )
+
+        # replace the instruction in the metric with the optimized instruction
         prompts = self.get_prompts()
         for key, val in optimized_prompts.items():
             prompts[key].instruction = val
         self.set_prompts(**prompts)
-        return
+
+    def _optimize_demonstration(
+        self, demonstration_config: DemonstrationConfig, dataset: MetricAnnotation
+    ):
+        # get the prompt annotations for this metric
+        prompt_annotations = dataset[self.name].get_prompt_annotations()
+        prompts = self.get_prompts()
+        for prompt_name, prompt_annotation_list in prompt_annotations.items():
+            # create a new FewShotPydanticPrompt with these annotations
+            if prompt_name not in prompts:
+                raise ValueError(
+                    f"Prompt '{prompt_name}' not found in metric '{self.name}'. Please check the prompt names in the annotation dataset."
+                )
+            pydantic_prompt = prompts[prompt_name]
+            input_model, output_model = (
+                pydantic_prompt.input_model,
+                pydantic_prompt.output_model,
+            )
+            # convert annotations into examples
+            input_examples, output_examples = [], []
+            for i, prompt_annotation in enumerate(prompt_annotation_list):
+                try:
+                    # skip if the prompt is not accepted
+                    if not prompt_annotation.is_accepted:
+                        continue
+                    input_examples.append(
+                        input_model.model_validate(prompt_annotation.prompt_input)
+                    )
+                    # use the edited output if it is provided
+                    if prompt_annotation.edited_output is not None:
+                        output_examples.append(
+                            output_model.model_validate(prompt_annotation.edited_output)
+                        )
+                    else:
+                        output_examples.append(
+                            output_model.model_validate(prompt_annotation.prompt_output)
+                        )
+                except ValidationError as e:
+                    logger.warning(
+                        f"Skipping prompt '{prompt_name}' example {i} because of validation error: {e}"
+                    )
+                    continue
+            embedding_model = demonstration_config.embedding
+            few_shot_prompt = FewShotPydanticPrompt.from_pydantic_prompt(
+                pydantic_prompt=pydantic_prompt,
+                embeddings=embedding_model,
+            )
+
+            # add the top k examples to the few shot prompt
+            few_shot_prompt.top_k_for_examples = demonstration_config.top_k
+            few_shot_prompt.threshold_for_examples = demonstration_config.threshold
+
+            # add examples to the few shot prompt
+            for input_example, output_example in tqdm(
+                zip(input_examples, output_examples),
+                total=len(input_examples),
+                desc=f"Few-shot examples [{prompt_name}]",
+            ):
+                few_shot_prompt.add_example(input_example, output_example)
+            prompts[prompt_name] = few_shot_prompt
+        self.set_prompts(**prompts)
+
+    def train(
+        self,
+        path: str,
+        demonstration_config: t.Optional[DemonstrationConfig] = None,
+        instruction_config: t.Optional[InstructionConfig] = None,
+        callbacks: t.Optional[Callbacks] = None,
+        run_config: t.Optional[RunConfig] = None,
+        batch_size: t.Optional[int] = None,
+        with_debugging_logs=False,
+        raise_exceptions: bool = True,
+    ) -> None:
+        run_config = run_config or RunConfig()
+        callbacks = callbacks or []
+
+        # load the dataset from path
+        if not path.endswith(".json"):
+            raise ValueError("Train data must be in json format")
+        dataset = MetricAnnotation.from_json(path, metric_name=self.name)
+
+        # only optimize the instruction if instruction_config is provided
+        if instruction_config is not None:
+            self._optimize_instruction(
+                instruction_config=instruction_config,
+                dataset=dataset,
+                callbacks=callbacks,
+                run_config=run_config,
+                batch_size=batch_size,
+                with_debugging_logs=with_debugging_logs,
+                raise_exceptions=raise_exceptions,
+            )
+
+        # if demonstration_config is provided, optimize the demonstrations
+        if demonstration_config is not None:
+            self._optimize_demonstration(
+                demonstration_config=demonstration_config,
+                dataset=dataset,
+            )
 
 
 @dataclass
```
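With this refactor, `train` only runs instruction optimization when an `instruction_config` is passed and only builds few-shot demonstrations when a `demonstration_config` is passed. A usage sketch under those assumptions; the annotation file name and model names are placeholders, and `embedding_factory` is assumed to return a default `BaseRagasEmbeddings`:

```py
from ragas.config import DemonstrationConfig, InstructionConfig
from ragas.embeddings import embedding_factory
from ragas.llms import llm_factory
from ragas.metrics import AspectCritic

metric = AspectCritic(
    name="answer_correctness",
    definition="Given the user_input, reference and response. Is the response correct compared with the reference",
    llm=llm_factory("gpt-4o"),
)

# demonstrations only: accepted annotations become few-shot examples on each prompt
metric.train(
    "alignment_sample.json",
    demonstration_config=DemonstrationConfig(embedding=embedding_factory()),
)

# instructions as well: the optimizer (GeneticOptimizer by default) rewrites prompt instructions
metric.train(
    "alignment_sample.json",
    demonstration_config=DemonstrationConfig(embedding=embedding_factory()),
    instruction_config=InstructionConfig(llm=llm_factory("gpt-4o")),
)
```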
