Commit 86fb753

Merge pull request #15 from amazon-science/llm-judge-by-model-and-tone
LLM judge when there are multiple models and tones
2 parents 3720c2c + e17343a, commit 86fb753

File tree

6 files changed: 47 additions & 35 deletions

README.md

Lines changed: 4 additions & 3 deletions

```diff
@@ -106,12 +106,13 @@ You can use the [CloudFormation yaml](src/cloudformation.yml) to start a Sagemak
 - [x] data
 - [x] 1. data generation -> prompt library
 - [x] 2.b. LLM -> implement this in a modular way in format_prompt_as_xml
+- [x] merge generate_all_datasets and generate_specific_datasets
+- [x] add a model_router.py
+- [x] uv
+- [x] from main.py to setup.py
 - [ ] transfer args to settings
-- [ ] merge generate_all_datasets and generate_specific_datasets
 - [ ] batch processing for Bedrock
 - [ ] batch processing for Sagemaker endpoint
-- [x] uv
-- [x] from main.py to setup.py
 - [ ] better sagemaker inference output parsing
 - [x] add a model_router.py
 - [ ] check if model exists in settings.toml to avoid AttributeError: 'Settings' object has no attribute 'ENDPOINT_TYPE'
```

config/settings.toml

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,7 +1,7 @@
 [default]
 region = 'us-east-1'
-data_dir = "./data"
-# 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
+data_dir = 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
+# "./data"
 deploy_bucket_name = 'llm-finetune-us-east-1-{aws_account}'
 deploy_bucket_prefix = 'models'
 sagemaker_execution_role_arn = 'arn:aws:iam::{aws_account}:role/sagemaker-execution-role-us-east-1'
```
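
The default `data_dir` now points at S3 and keeps `{aws_account}` as a template placeholder. A minimal sketch of how such a placeholder could be filled at runtime, assuming the account id comes from STS (`resolve_data_dir` is a hypothetical helper for illustration, not wraval's actual resolution logic):

```python
import boto3

def resolve_data_dir(raw: str) -> str:
    # Hypothetical helper: fill the {aws_account} placeholder with the
    # calling identity's AWS account id.
    account = boto3.client("sts").get_caller_identity()["Account"]
    return raw.format(aws_account=account)

# resolve_data_dir('s3://llm-finetune-us-east-1-{aws_account}/eval/tones/')
# -> 's3://llm-finetune-us-east-1-123456789012/eval/tones/'
```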

src/wraval/actions/action_llm_judge.py

Lines changed: 36 additions & 26 deletions

```diff
@@ -4,6 +4,7 @@
 #
 import pandas as pd
 from typing import List, Dict, Optional
+from itertools import product
 from dynaconf import Dynaconf
 from .data_utils import write_dataset, load_latest_dataset
 from .prompts_judge import generate_input_prompt, generate_system_prompt, get_rubric, rewrite_prompt
@@ -24,7 +25,7 @@ def extract_score(text: str) -> Optional[int]:
     match = re.search(r"<score>(\d+)</score>", text)
     return int(match.group(1)) if match else None

-def validate_dataset(d: pd.DataFrame) -> bool:
+def validate_dataset(results: pd.DataFrame) -> bool:
     """Validate required columns exist in dataset.

     Args:
@@ -34,14 +35,14 @@ def validate_dataset(d: pd.DataFrame) -> bool:
         True if valid, False otherwise
     """
     required_columns = {"synthetic_data", "rewrite", "tone"}
-    if not all(col in d.columns for col in required_columns):
+    if not all(col in results.columns for col in required_columns):
         print(f"Missing required columns. Required: {required_columns}")
         return False
     return True

 def process_tone_data(
     settings: Dynaconf,
-    d: pd.DataFrame,
+    results: pd.DataFrame,
     tone: str,
     model_name: str,
     client: boto3.client,
@@ -63,14 +64,14 @@ def process_tone_data(
     if settings.custom_prompts == True:
         from wraval.custom_prompts.prompts_judge import generate_input_prompt, generate_system_prompt

-    dmt = d[d.tone == tone].copy()
+    temp_results = results.copy()
     rubrics = list(tone_rubrics.keys())

     # Generate prompts
     user_prompts = []
     sys_prompts = []

-    for q, a in zip(dmt["synthetic_data"], dmt["rewrite"]):
+    for q, a in zip(temp_results["synthetic_data"], temp_results["rewrite"]):
         for rubric in rubrics:
             user_prompts.append(generate_input_prompt(q, a, tone))
             sys_prompts.append(generate_system_prompt(tone_rubrics[rubric]))
@@ -88,16 +89,16 @@ def process_tone_data(

     # Process scores
     for i, rubric in enumerate(rubrics):
-        dmt[rubric] = completions[i::len(rubrics)]
-        dmt[f'{rubric}_score'] = dmt[rubric].apply(extract_score)
+        temp_results[rubric] = completions[i::len(rubrics)]
+        temp_results[f'{rubric}_score'] = temp_results[rubric].apply(extract_score)

     # Move all score columns to the right
     score_columns = [f'{r}_score' for r in rubrics]
-    other_columns = [col for col in dmt.columns if col not in score_columns]
-    dmt = dmt[other_columns + score_columns]
+    other_columns = [col for col in temp_results.columns if col not in score_columns]
+    temp_results = temp_results[other_columns + score_columns]

-    dmt['overall_score'] = dmt[score_columns].mean(axis=1)
-    return dmt
+    temp_results['overall_score'] = temp_results[score_columns].mean(axis=1)
+    return temp_results

 def judge(
     settings: Dynaconf,
@@ -121,30 +122,39 @@ def judge(
         from wraval.custom_prompts.prompts_judge import get_rubric

     try:
-        d = load_latest_dataset(settings.data_dir)
-        print(f"Loaded dataset with {len(d)} rows")
+        results = load_latest_dataset(settings.data_dir)
+        print(f"Loaded dataset with {len(results)} rows")
     except FileNotFoundError:
         print("No dataset found. Please generate data first.")
         return

-    if not validate_dataset(d):
+    if not validate_dataset(results):
         return

-    tones = d["tone"].unique()
+    tones = results["tone"].unique()
+    inf_models = results["inference_model"].unique()
     print(f"Found tones: {tones}")
+    print(f"Found inference_models: {inf_models}")
+
+    if settings.type != "all":
+        tones = [settings.type]

-    for tone in tones:
-        print(f"\n{'='*20}\n{tone}\n{'='*20}")
+    # Process each tone-model combination that needs scoring
+    for tone, inf_model in product(tones, inf_models):
+        mask = (results.inference_model == inf_model) & (results.tone == tone)
+        # check if any score is missing for this inference model and this tone
+        # If yes, run the eval below
+        if not results[mask].overall_score.isna().any():
+            continue
+
+        print(f"\n{'='*20}\n{tone} tone\nfor inference model {inf_model}\n{'='*20}")

         tone_rubrics = get_rubric(tone.upper())
-        dmt = process_tone_data(settings, d, tone, model_name, client, tone_rubrics)
-
-        # Update main dataframe
-        mask = (d.tone == tone)
-        d.loc[mask, dmt.columns] = dmt.values
+        temp_results = process_tone_data(settings, results[mask], tone, model_name, client, tone_rubrics)
+        results.loc[mask, temp_results.columns] = temp_results.values

     # Save results
-    write_dataset(d, settings.data_dir, "all-tones", "csv")
+    write_dataset(results, settings.data_dir, "all", "csv")

 def rewrite_judge(
     model_id: str,
@@ -163,12 +173,12 @@ def rewrite_judge(
     Returns:
         DataFrame with input, output, and scores
     """
-    d = pd.DataFrame({'input': queries, 'output': answers})
+    results = pd.DataFrame({'input': queries, 'output': answers})
     prompts = [rewrite_prompt(q, a) for q, a in zip(queries, answers)]
-    d['rewrite_score'] = batch_get_bedrock_completions(
+    results['rewrite_score'] = batch_get_bedrock_completions(
         model_id,
         bedrock_client,
         prompts,
         max_concurrent=len(prompts)
     )
-    return d
+    return results
```
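
The core change in this file: `judge` now iterates over every (tone, inference model) pair with `itertools.product` and skips pairs whose rows are already fully scored, so re-running the judge only fills in what is missing. A self-contained sketch of that pattern on toy data (column names follow the diff; the scoring call is a placeholder standing in for `process_tone_data`):

```python
from itertools import product

import pandas as pd

# Toy results frame shaped like wraval's dataset: one row per rewrite.
results = pd.DataFrame({
    "tone": ["witty", "witty", "formal", "formal"],
    "inference_model": ["model-a", "model-b", "model-a", "model-b"],
    "synthetic_data": ["q1", "q2", "q3", "q4"],
    "rewrite": ["a1", "a2", "a3", "a4"],
    "overall_score": [4.5, None, None, 3.0],  # two pairs still unscored
})

for tone, inf_model in product(results["tone"].unique(),
                               results["inference_model"].unique()):
    mask = (results.inference_model == inf_model) & (results.tone == tone)
    # Skip (tone, model) pairs that already have every score filled in.
    if not results[mask].overall_score.isna().any():
        continue
    # Placeholder for the real LLM-judge call on results[mask].
    results.loc[mask, "overall_score"] = 5.0

print(results)  # only the two unscored pairs were (re)judged
```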

src/wraval/actions/action_results.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -48,7 +48,7 @@ def show_results(settings: Dynaconf, tone: Optional[str] = None) -> None:
     print("=" * 50)

     # Group by model, inference model, and tone, calculate mean of overall_score
-    grouped = d.groupby(['model', 'tone'])['overall_score'].mean()
+    grouped = d.groupby(['inference_model', 'tone'])['overall_score'].mean()

     # Normalize scores to 0-100 scale
     normalized = normalize_scores(grouped)
```
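
With scores now keyed by the model that produced the rewrite, the summary groups on `inference_model` rather than the judge model. A minimal sketch of the grouping, assuming rubric scores on a 1-to-5 scale (`normalize_scores` is wraval's own helper; the 0-100 rescaling below is an assumed approximation of it):

```python
import pandas as pd

d = pd.DataFrame({
    "inference_model": ["model-a", "model-a", "model-b", "model-b"],
    "tone": ["witty", "formal", "witty", "formal"],
    "overall_score": [3.5, 4.0, 4.5, 4.25],
})

# Mean overall_score per (inference_model, tone) pair.
grouped = d.groupby(["inference_model", "tone"])["overall_score"].mean()

# Assumed stand-in for normalize_scores: map a 1-5 scale onto 0-100.
normalized = (grouped - 1) / 4 * 100
print(normalized.round(1))
```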

src/wraval/actions/data_utils.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,5 +1,5 @@
 import os
-from datetime import datetime
+from datetime import datetime, UTC
 import pandas as pd
 import boto3
 import tempfile
@@ -45,7 +45,7 @@ def write_dataset_local(


 def add_timestamp_to_file_prefix(file_prefix, format):
-    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    timestamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")
     return f"{file_prefix}-{timestamp}.{format.lower()}"
```
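
Timestamping with `datetime.now(UTC)` makes dataset filenames consistent across machines regardless of local timezone, which matters if the latest dataset is picked by filename order. Note that importing `UTC` from `datetime` requires Python 3.11+; older versions can use `timezone.utc` instead. A quick illustration:

```python
from datetime import datetime, UTC

# Naive local time: depends on the host's timezone setting.
local_stamp = datetime.now().strftime("%Y%m%d_%H%M%S")

# Timezone-aware UTC: identical on every host, so filename order
# reflects actual creation order.
utc_stamp = datetime.now(UTC).strftime("%Y%m%d_%H%M%S")

print(f"tones-{utc_stamp}.csv")  # e.g. tones-20250101_142530.csv
```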
src/wraval/actions/model_router.py

Lines changed: 2 additions & 1 deletion

```diff
@@ -44,14 +44,15 @@ def get_completion(self, queries: List[str]) -> List[str]:
 class SageMakerRouter(HuggingFaceModelRouter):
     def __init__(self, master_sys_prompt, settings):
         super().__init__(master_sys_prompt, settings)
+        self.model_name = settings.model

     def get_completion(self, queries: List[str]) -> List[str]:
         prompts = [
             format_prompt(text, self.master_sys_prompt, self.tokenizer, type="hf")
             for text in queries
         ]
         return [
-            invoke_sagemaker_endpoint({"inputs": prompt}) for prompt in tqdm(prompts)
+            invoke_sagemaker_endpoint({"inputs": prompt}, self.model_name) for prompt in tqdm(prompts)
         ]
```
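
The router now threads the configured model name through to `invoke_sagemaker_endpoint`, so a different SageMaker endpoint can be targeted per model. That helper is defined elsewhere in the repo; a plausible minimal sketch of it using boto3's sagemaker-runtime API (payload shape and response parsing are assumptions, not wraval's actual code):

```python
import json

import boto3

runtime = boto3.client("sagemaker-runtime")

def invoke_sagemaker_endpoint(payload: dict, endpoint_name: str) -> str:
    # Hypothetical sketch: POST a JSON payload to the named endpoint
    # and return the generated text.
    response = runtime.invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType="application/json",
        Body=json.dumps(payload),
    )
    body = json.loads(response["Body"].read())
    # Hugging Face TGI/DJL containers commonly return
    # [{"generated_text": ...}]; adjust parsing to the actual container.
    return body[0]["generated_text"]
```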