 from dataclasses import dataclass, field
 from enum import Enum

+from pydantic import ValidationError
 from pysbd import Segmenter
+from tqdm import tqdm

 from ragas._analytics import EvaluationEvent, _analytics_batcher
 from ragas.callbacks import ChainType, new_group
 from ragas.dataset_schema import MetricAnnotation, MultiTurnSample, SingleTurnSample
 from ragas.executor import is_event_loop_running
 from ragas.losses import BinaryMetricLoss, MSELoss
-from ragas.prompt import PromptMixin
+from ragas.prompt import FewShotPydanticPrompt, PromptMixin
 from ragas.run_config import RunConfig
 from ragas.utils import (
     RAGAS_SUPPORTED_LANGUAGE_CODES,
@@ -230,48 +232,30 @@ def init(self, run_config: RunConfig):
         )
         self.llm.set_run_config(run_config)

-    def train(
+    def _optimize_instruction(
         self,
-        path: str,
-        demonstration_config: t.Optional[DemonstrationConfig] = None,
-        instruction_config: t.Optional[InstructionConfig] = None,
-        callbacks: t.Optional[Callbacks] = None,
-        run_config: t.Optional[RunConfig] = None,
-        batch_size: t.Optional[int] = None,
-        with_debugging_logs=False,
-        raise_exceptions: bool = True,
-    ) -> None:
-
-        if not path.endswith(".json"):
-            raise ValueError("Train data must be in json format")
-
-        if instruction_config is None:
-            from ragas.config import InstructionConfig
-
-            instruction_config = InstructionConfig()
-
-        if demonstration_config is None:
-            from ragas.config import DemonstrationConfig
-
-            demonstration_config = DemonstrationConfig()
-
-        dataset = MetricAnnotation.from_json(path, metric_name=self.name)
-
-        optimizer = instruction_config.optimizer
-        llm = instruction_config.llm or self.llm
-        if llm is None:
+        instruction_config: InstructionConfig,
+        dataset: MetricAnnotation,
+        callbacks: Callbacks,
+        run_config: RunConfig,
+        batch_size: t.Optional[int],
+        with_debugging_logs: bool,
+        raise_exceptions: bool,
+    ):
+        if self.llm is None:
             raise ValueError(
264247 f"Metric '{ self .name } ' has no valid LLM provided (self.llm is None). Please initantiate a the metric with an LLM to run." # noqa
             )
+        optimizer = instruction_config.optimizer
         if optimizer.llm is None:
-            optimizer.llm = llm
+            optimizer.llm = instruction_config.llm

+        # figure out the loss function
         if instruction_config.loss is None:
             if self.output_type is None:
                 raise ValueError(
                     f"Output type for metric '{self.name}' is not defined. Please set the output type in the metric or in the instruction config."
                 )
-
             if self.output_type.name == MetricOutputType.BINARY.name:
                 loss_fun = BinaryMetricLoss()
             elif (
@@ -286,8 +270,8 @@ def train(
             else:
                 loss_fun = instruction_config.loss

+        # Optimize the prompts
         optimizer.metric = self
-
         optimizer_config = instruction_config.optimizer_config or {}
         optimized_prompts = optimizer.optimize(
             dataset[self.name],
@@ -299,11 +283,111 @@ def train(
             with_debugging_logs=with_debugging_logs,
             raise_exceptions=raise_exceptions,
         )
+
+        # replace the instruction in the metric with the optimized instruction
         prompts = self.get_prompts()
         for key, val in optimized_prompts.items():
             prompts[key].instruction = val
         self.set_prompts(**prompts)
-        return
+
+    def _optimize_demonstration(
+        self, demonstration_config: DemonstrationConfig, dataset: MetricAnnotation
+    ):
+        # get the prompt annotations for this metric
+        prompt_annotations = dataset[self.name].get_prompt_annotations()
+        prompts = self.get_prompts()
+        for prompt_name, prompt_annotation_list in prompt_annotations.items():
+            # create a new FewShotPydanticPrompt with these annotations
+            if prompt_name not in prompts:
+                raise ValueError(
+                    f"Prompt '{prompt_name}' not found in metric '{self.name}'. Please check the prompt names in the annotation dataset."
+                )
+            pydantic_prompt = prompts[prompt_name]
+            input_model, output_model = (
+                pydantic_prompt.input_model,
+                pydantic_prompt.output_model,
+            )
+            # convert annotations into examples
+            input_examples, output_examples = [], []
+            for i, prompt_annotation in enumerate(prompt_annotation_list):
+                try:
+                    # skip if the prompt is not accepted
+                    if not prompt_annotation.is_accepted:
+                        continue
+                    input_examples.append(
+                        input_model.model_validate(prompt_annotation.prompt_input)
+                    )
+                    # use the edited output if it is provided
+                    if prompt_annotation.edited_output is not None:
+                        output_examples.append(
+                            output_model.model_validate(prompt_annotation.edited_output)
+                        )
+                    else:
+                        output_examples.append(
+                            output_model.model_validate(prompt_annotation.prompt_output)
+                        )
+                except ValidationError as e:
+                    logger.warning(
+                        f"Skipping prompt '{prompt_name}' example {i} because of validation error: {e}"
+                    )
+                    continue
+            embedding_model = demonstration_config.embedding
+            few_shot_prompt = FewShotPydanticPrompt.from_pydantic_prompt(
+                pydantic_prompt=pydantic_prompt,
+                embeddings=embedding_model,
+            )
+
+            # add the top k examples to the few shot prompt
+            few_shot_prompt.top_k_for_examples = demonstration_config.top_k
+            few_shot_prompt.threshold_for_examples = demonstration_config.threshold
+
+            # add examples to the few shot prompt
+            for input_example, output_example in tqdm(
+                zip(input_examples, output_examples),
+                total=len(input_examples),
+                desc=f"Few-shot examples [{prompt_name}]",
+            ):
+                few_shot_prompt.add_example(input_example, output_example)
+            prompts[prompt_name] = few_shot_prompt
+        self.set_prompts(**prompts)
+
+    def train(
+        self,
+        path: str,
+        demonstration_config: t.Optional[DemonstrationConfig] = None,
+        instruction_config: t.Optional[InstructionConfig] = None,
+        callbacks: t.Optional[Callbacks] = None,
+        run_config: t.Optional[RunConfig] = None,
+        batch_size: t.Optional[int] = None,
+        with_debugging_logs=False,
+        raise_exceptions: bool = True,
+    ) -> None:
+        run_config = run_config or RunConfig()
+        callbacks = callbacks or []
+
+        # load the dataset from path
+        if not path.endswith(".json"):
+            raise ValueError("Train data must be in json format")
+        dataset = MetricAnnotation.from_json(path, metric_name=self.name)
+
+        # only optimize the instruction if instruction_config is provided
+        if instruction_config is not None:
+            self._optimize_instruction(
+                instruction_config=instruction_config,
+                dataset=dataset,
+                callbacks=callbacks,
+                run_config=run_config,
+                batch_size=batch_size,
+                with_debugging_logs=with_debugging_logs,
+                raise_exceptions=raise_exceptions,
+            )
+
+        # if demonstration_config is provided, optimize the demonstrations
+        if demonstration_config is not None:
+            self._optimize_demonstration(
+                demonstration_config=demonstration_config,
+                dataset=dataset,
+            )


 @dataclass
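
A minimal usage sketch of the refactored train() entry point, written under stated assumptions rather than taken from this commit: `metric` stands for an already-constructed PromptMixin-based metric with an LLM attached, "annotations.json" is a hypothetical path to an exported MetricAnnotation file, and the no-argument config constructors are assumed to provide workable defaults (as the removed fallback code did).

from ragas.config import DemonstrationConfig, InstructionConfig

# `metric` is assumed to be an existing ragas metric instance with metric.llm set;
# its construction is omitted here.
metric.train(
    path="annotations.json",  # hypothetical annotation export; train() requires .json
    # optimize each prompt's instruction against the annotated examples
    instruction_config=InstructionConfig(),
    # turn accepted annotations into few-shot examples via FewShotPydanticPrompt
    demonstration_config=DemonstrationConfig(),
)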