@@ -4,6 +4,7 @@
 #
 import pandas as pd
 from typing import List, Dict, Optional
+from itertools import product
 from dynaconf import Dynaconf
 from .data_utils import write_dataset, load_latest_dataset
 from .prompts_judge import generate_input_prompt, generate_system_prompt, get_rubric, rewrite_prompt
@@ -24,7 +25,7 @@ def extract_score(text: str) -> Optional[int]:
     match = re.search(r"<score>(\d+)</score>", text)
     return int(match.group(1)) if match else None
 
-def validate_dataset(d: pd.DataFrame) -> bool:
+def validate_dataset(results: pd.DataFrame) -> bool:
     """Validate required columns exist in dataset.
 
     Args:
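Side note: the regex in `extract_score` only requires the judge completion to contain one XML-style `<score>` tag; everything else in the completion is ignored. A minimal sketch, with invented completion text:

```python
# Hypothetical judge outputs; only the <score>N</score> tag is parsed.
assert extract_score("Faithful and concise. <score>4</score>") == 4
assert extract_score("No tag in this completion") is None
```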
@@ -34,14 +35,14 @@ def validate_dataset(d: pd.DataFrame) -> bool:
         True if valid, False otherwise
     """
     required_columns = {"synthetic_data", "rewrite", "tone"}
-    if not all(col in d.columns for col in required_columns):
+    if not all(col in results.columns for col in required_columns):
         print(f"Missing required columns. Required: {required_columns}")
         return False
     return True
 
 def process_tone_data(
     settings: Dynaconf,
-    d: pd.DataFrame,
+    results: pd.DataFrame,
     tone: str,
     model_name: str,
     client: boto3.client,
@@ -63,14 +64,14 @@ def process_tone_data(
     if settings.custom_prompts == True:
         from wraval.custom_prompts.prompts_judge import generate_input_prompt, generate_system_prompt
 
-    dmt = d[d.tone == tone].copy()
+    temp_results = results.copy()
     rubrics = list(tone_rubrics.keys())
 
     # Generate prompts
     user_prompts = []
     sys_prompts = []
 
-    for q, a in zip(dmt["synthetic_data"], dmt["rewrite"]):
+    for q, a in zip(temp_results["synthetic_data"], temp_results["rewrite"]):
         for rubric in rubrics:
             user_prompts.append(generate_input_prompt(q, a, tone))
             sys_prompts.append(generate_system_prompt(tone_rubrics[rubric]))
@@ -88,16 +89,16 @@ def process_tone_data(
 
     # Process scores
     for i, rubric in enumerate(rubrics):
-        dmt[rubric] = completions[i::len(rubrics)]
-        dmt[f'{rubric}_score'] = dmt[rubric].apply(extract_score)
+        temp_results[rubric] = completions[i::len(rubrics)]
+        temp_results[f'{rubric}_score'] = temp_results[rubric].apply(extract_score)
 
     # Move all score columns to the right
     score_columns = [f'{r}_score' for r in rubrics]
-    other_columns = [col for col in dmt.columns if col not in score_columns]
-    dmt = dmt[other_columns + score_columns]
+    other_columns = [col for col in temp_results.columns if col not in score_columns]
+    temp_results = temp_results[other_columns + score_columns]
 
-    dmt['overall_score'] = dmt[score_columns].mean(axis=1)
-    return dmt
+    temp_results['overall_score'] = temp_results[score_columns].mean(axis=1)
+    return temp_results
 
 def judge(
     settings: Dynaconf,
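The stride slice `completions[i::len(rubrics)]` above works because the prompt loop earlier in `process_tone_data` iterates rubrics inside rows, so completions come back row-major, one per (row, rubric) pair. A self-contained sketch of that recovery, with invented names:

```python
rubrics = ["faithfulness", "tone_match"]   # hypothetical rubric names
rows = ["row0", "row1", "row2"]
# Completions arrive in the same nested order the prompts were built in:
# row0/rubric0, row0/rubric1, row1/rubric0, ...
completions = [f"{row}:{rubric}" for row in rows for rubric in rubrics]
for i, rubric in enumerate(rubrics):
    # Every len(rubrics)-th element is one rubric's answer across all rows.
    assert completions[i::len(rubrics)] == [f"{row}:{rubric}" for row in rows]
```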
@@ -121,30 +122,39 @@ def judge(
         from wraval.custom_prompts.prompts_judge import get_rubric
 
     try:
-        d = load_latest_dataset(settings.data_dir)
-        print(f"Loaded dataset with {len(d)} rows")
+        results = load_latest_dataset(settings.data_dir)
+        print(f"Loaded dataset with {len(results)} rows")
     except FileNotFoundError:
         print("No dataset found. Please generate data first.")
         return
 
-    if not validate_dataset(d):
+    if not validate_dataset(results):
         return
 
-    tones = d["tone"].unique()
+    tones = results["tone"].unique()
+    inf_models = results["inference_model"].unique()
     print(f"Found tones: {tones}")
+    print(f"Found inference_models: {inf_models}")
+
+    if settings.type != "all":
+        tones = [settings.type]
 
-    for tone in tones:
-        print(f"\n{'=' * 20}\n{tone}\n{'=' * 20}")
+    # Process each tone-model combination that needs scoring
+    for tone, inf_model in product(tones, inf_models):
+        mask = (results.inference_model == inf_model) & (results.tone == tone)
+        # Check whether any score is missing for this inference model and
+        # tone; if so, run the evaluation below, otherwise skip the pair.
+        if not results[mask].overall_score.isna().any():
+            continue
+
+        print(f"\n{'=' * 20}\n{tone} tone\nfor inference model {inf_model}\n{'=' * 20}")
 
         tone_rubrics = get_rubric(tone.upper())
-        dmt = process_tone_data(settings, d, tone, model_name, client, tone_rubrics)
-
-        # Update main dataframe
-        mask = (d.tone == tone)
-        d.loc[mask, dmt.columns] = dmt.values
+        temp_results = process_tone_data(settings, results[mask], tone, model_name, client, tone_rubrics)
+        results.loc[mask, temp_results.columns] = temp_results.values
 
     # Save results
-    write_dataset(d, settings.data_dir, "all-tones", "csv")
+    write_dataset(results, settings.data_dir, "all", "csv")
 
 def rewrite_judge(
     model_id: str,
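Worth flagging on the new skip logic above: `results[mask].overall_score.isna().any()` presupposes that an `overall_score` column already exists in the loaded dataset (it will after a first judging pass; on a dataset that has never been scored, attribute access on the missing column would raise an `AttributeError`). A toy sketch of the idempotent-rerun behavior the check buys, with made-up data:

```python
import pandas as pd

df = pd.DataFrame({
    "tone": ["witty", "witty", "formal"],
    "inference_model": ["m1", "m1", "m1"],
    "overall_score": [3.0, None, 4.0],   # one witty/m1 row still unscored
})
witty = (df.inference_model == "m1") & (df.tone == "witty")
formal = (df.inference_model == "m1") & (df.tone == "formal")
assert df[witty].overall_score.isna().any()        # re-judged on the next run
assert not df[formal].overall_score.isna().any()   # skipped on the next run
```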
@@ -163,12 +173,12 @@ def rewrite_judge(
     Returns:
         DataFrame with input, output, and scores
     """
-    d = pd.DataFrame({'input': queries, 'output': answers})
+    results = pd.DataFrame({'input': queries, 'output': answers})
     prompts = [rewrite_prompt(q, a) for q, a in zip(queries, answers)]
-    d['rewrite_score'] = batch_get_bedrock_completions(
+    results['rewrite_score'] = batch_get_bedrock_completions(
         model_id,
         bedrock_client,
         prompts,
         max_concurrent=len(prompts)
     )
-    return d
+    return results
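For completeness, a hedged sketch of how `rewrite_judge` might be invoked. The model ID and texts are placeholders, keyword arguments are used because the full parameter order isn't shown in this hunk, and `batch_get_bedrock_completions` is assumed to return one completion string per prompt:

```python
import boto3

client = boto3.client("bedrock-runtime")  # assumes Bedrock access is configured
queries = ["Make this friendlier: Send the report now."]
answers = ["Could you please send the report over when you get a chance?"]

scored = rewrite_judge(
    model_id="anthropic.claude-3-haiku-20240307-v1:0",  # placeholder model ID
    bedrock_client=client,
    queries=queries,
    answers=answers,
)
print(scored[["input", "output", "rewrite_score"]])
```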