
Commit b30591a

llm judge

1 parent 30bb13f · commit b30591a

6 files changed (+20, -19 lines)

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -7,3 +7,4 @@ src/wraval.egg-info/
 **__pycache__/
 prompts/*
 .idea
+src/wraval/custom_prompts/*

config/settings.toml

Lines changed: 2 additions & 1 deletion
@@ -1,6 +1,7 @@
 [default]
 region = 'us-east-1'
-data_dir = 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
+data_dir = "./data"
+# 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
 
 [haiku-3]
 model = 'anthropic.claude-3-haiku-20240307-v1:0'
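
Note on this change: data_dir now defaults to a local ./data directory, with the old S3 prefix kept as a comment. A minimal sketch of how such a [default] section could be read, assuming the project loads its config with Dynaconf (the settings_files path and environments flag here are assumptions, not taken from the repo):

# Hypothetical loading of config/settings.toml with Dynaconf.
# The file path and environments=True are assumptions for illustration.
from dynaconf import Dynaconf

settings = Dynaconf(
    settings_files=["config/settings.toml"],
    environments=True,  # treats [default] as the base environment
)

print(settings.region)    # 'us-east-1'
print(settings.data_dir)  # './data' after this commit (previously an S3 URI)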

src/wraval/actions/action_llm_judge.py

Lines changed: 12 additions & 13 deletions
@@ -39,6 +39,7 @@ def validate_dataset(d: pd.DataFrame) -> bool:
     return True
 
 def process_tone_data(
+    settings: Dynaconf,
     d: pd.DataFrame,
     tone: str,
     model_name: str,

@@ -61,22 +62,21 @@ def process_tone_data(
     rubrics = list(tone_rubrics.keys())
 
     # Generate prompts
-    prompts = []
+    user_prompts = []
+    sys_prompts = []
+
     for q, a in zip(dmt["synthetic_data"], dmt["rewrite"]):
         for rubric in rubrics:
-            prompts.append((
-                generate_system_prompt(tone_rubrics[rubric]),
-                generate_input_prompt(q, a, tone)
-            ))
+            user_prompts.append(generate_input_prompt(q, a, tone))
+            sys_prompts.append(generate_system_prompt(tone_rubrics[rubric]))
 
     # Get completions
-    sys_prompts, user_prompts = zip(*prompts)
+    # import pdb
+    # pdb.set_trace()
     completions = batch_get_bedrock_completions(
-        model_name,
-        client,
+        settings,
         user_prompts,
-        sys_prompts,
-        max_concurrent=len(user_prompts)
+        sys_prompts
     )
 
     rubrics = [r.lower() for r in rubrics]

@@ -99,7 +99,6 @@ def judge(
     client: boto3.client,
     model_name: str,
     upload_s3: bool,
-    data_dir: str,
     endpoint_type: str = "bedrock"
 ) -> None:
     """Judge rewrites using specified model and rubrics.

@@ -113,7 +112,7 @@ def judge(
         endpoint_type: Type of endpoint to use
     """
     try:
-        d = load_latest_dataset(data_dir)
+        d = load_latest_dataset(settings.data_dir)
         print(f"Loaded dataset with {len(d)} rows")
     except FileNotFoundError:
         print("No dataset found. Please generate data first.")

@@ -129,7 +128,7 @@ def judge(
         print(f"\n{'='*20}\n{tone}\n{'='*20}")
 
         tone_rubrics = get_rubric(tone.upper())
-        dmt = process_tone_data(d, tone, model_name, client, tone_rubrics)
+        dmt = process_tone_data(settings, d, tone, model_name, client, tone_rubrics)
 
         # Update main dataframe
         mask = (d.tone == tone)
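
The refactor above replaces a single list of (system, user) prompt tuples with two parallel lists, and passes settings instead of model_name and client to the batch completion helper. A small self-contained sketch of the pairing logic, with invented sample data and stand-ins for the real prompt helpers:

# Illustrative only: the real code uses generate_input_prompt / generate_system_prompt
# and reads samples from the dataset; the data and rubric texts here are made up.
samples = [("make this witty", "a dry rewrite"), ("shorten this", "a terse rewrite")]
rubrics = {"HUMOR": "Rate how funny the rewrite is (1-5).",
           "FLUENCY": "Rate how fluent the rewrite is (1-5)."}

user_prompts, sys_prompts = [], []
for q, a in samples:
    for name, rubric_text in rubrics.items():
        user_prompts.append(f"Original: {q}\nRewrite: {a}")  # stand-in for generate_input_prompt
        sys_prompts.append(rubric_text)                        # stand-in for generate_system_prompt

# user_prompts[i] and sys_prompts[i] always describe the same (sample, rubric) pair,
# so they can be passed side by side to batch_get_bedrock_completions(settings, user_prompts, sys_prompts).
assert len(user_prompts) == len(sys_prompts) == len(samples) * len(rubrics)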

src/wraval/actions/completion.py

Lines changed: 2 additions & 2 deletions
@@ -52,8 +52,8 @@ def get_bedrock_completion(settings, prompt, system_prompt=None):
     )
 
     if isinstance(system_prompt, str) and len(system_prompt) > 0:
-        # converse_api_params.update({"system": [{"text": system_prompt}]})
-        converse_api_params["messages"] = [{"role": "assistant", "content": [{"text": system_prompt}]}] + converse_api_params["messages"]
+        converse_api_params.update({"system": [{"text": system_prompt}]})
+        # converse_api_params["messages"] = [{"role": "assistant", "content": [{"text": system_prompt}]}] + converse_api_params["messages"]
 
     response = bedrock_client.converse(**converse_api_params)
     return response['output']['message']['content'][0]['text']
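
This hunk stops injecting the system prompt as a leading assistant message and instead uses the Converse API's dedicated top-level system field. A minimal sketch of that call shape; the model ID, region, and prompt text are placeholders, not the project's settings:

# Standalone example of a Bedrock Converse call with a top-level system prompt.
import boto3

bedrock_client = boto3.client("bedrock-runtime", region_name="us-east-1")

converse_api_params = {
    "modelId": "anthropic.claude-3-haiku-20240307-v1:0",
    "messages": [{"role": "user", "content": [{"text": "Rewrite this in a witty tone: the meeting ran long."}]}],
    "system": [{"text": "You rewrite text in the requested tone and reply with only the rewrite."}],
    "inferenceConfig": {"maxTokens": 256, "temperature": 0.0},
}

response = bedrock_client.converse(**converse_api_params)
print(response["output"]["message"]["content"][0]["text"])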

src/wraval/actions/data_utils.py

Lines changed: 1 addition & 0 deletions
@@ -97,6 +97,7 @@ def load_latest_dataset(data_dir: str) -> pd.DataFrame:
         bucket, prefix = parse_s3_path(data_dir)
         return load_latest_dataset_from_s3(bucket, prefix)
     else:
+
         # Local file handling
         data_dir = os.path.expanduser(data_dir)

src/wraval/main.py

Lines changed: 2 additions & 3 deletions
@@ -86,7 +86,7 @@ def handle_inference(args, settings):
 
 
 def handle_judge(args, settings):
-    if args.endpoint_type == "bedrock":
+    if settings.endpoint_type == "bedrock":
         judge_model = settings.model
         client = boto3.client(
             service_name="bedrock-runtime", region_name=settings.region

@@ -100,8 +100,7 @@ def handle_judge(args, settings):
         client,
         judge_model,
         args.upload_s3,
-        settings.data_dir,
-        args.endpoint_type,
+        settings.endpoint_type,
     )
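
With this commit, handle_judge reads the endpoint type from settings rather than a CLI flag, and judge no longer takes data_dir because it reads settings.data_dir itself. A runnable sketch of the resulting division of labour, using stand-in objects for the parsed args and settings (everything below is illustrative, not the project's real parser or judge):

# Stand-ins: the real args come from argparse and the real settings from Dynaconf.
from types import SimpleNamespace

settings = SimpleNamespace(
    endpoint_type="bedrock",
    data_dir="./data",
    region="us-east-1",
    model="anthropic.claude-3-haiku-20240307-v1:0",
)
args = SimpleNamespace(upload_s3=False)  # run-specific flags stay on the CLI

def judge_stub(settings, client, model_name, upload_s3, endpoint_type="bedrock"):
    # The real judge() now loads its dataset from settings.data_dir directly.
    print(f"judging with {model_name} via {endpoint_type}, data from {settings.data_dir}")

judge_stub(settings, client=None, model_name=settings.model,
           upload_s3=args.upload_s3, endpoint_type=settings.endpoint_type)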
