Skip to content

Commit ca96792

Browse files
committed
warnings - typer - inference-fix
1 parent f1b2848 commit ca96792

File tree

11 files changed

+191
-132
lines changed

11 files changed

+191
-132
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,3 +11,4 @@ src/wraval/custom_prompts/*
1111
src/wraval/testing.py
1212
src/wraval/model_artifacts/*
1313
!src/wraval/model_artifacts/code/
14+
build/*

config/settings.toml

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,10 +1,12 @@
11
[default]
22
region = 'us-east-1'
3-
data_dir = 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
4-
# "./data"
3+
data_dir = "./data"
4+
# 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
55
deploy_bucket_name = 'llm-finetune-us-east-1-{aws_account}'
66
deploy_bucket_prefix = 'models'
77
sagemaker_execution_role_arn = 'arn:aws:iam::{aws_account}:role/sagemaker-execution-role-us-east-1'
8+
endpoint_type = 'bedrock'
9+
model = 'anthropic.claude-3-haiku-20240307-v1:0'
810

911
[haiku-3]
1012
model = 'anthropic.claude-3-haiku-20240307-v1:0'

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,8 @@ dependencies = [
2828
"numpy",
2929
"requests",
3030
"accelerate",
31-
"torchvision"
31+
"torchvision",
32+
"typer"
3233
]
3334

3435
[project.scripts]

src/wraval/actions/action_examples.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from typing import Optional
99

1010

11-
def show_examples(settings: Dynaconf, tone: Optional[str] = None, n_examples: int = 3) -> None:
11+
def get_examples(settings: Dynaconf, tone: Optional[str] = None, n_examples: int = 3) -> None:
1212
"""
1313
Load the latest dataset and display examples grouped by tone and model.
1414

src/wraval/actions/action_inference.py

Lines changed: 15 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -17,10 +17,13 @@ def run_inference(
1717
"""Run inference on sentences using the specified model"""
1818
results = load_latest_dataset(data_dir)
1919

20+
no_rewrite = False
21+
2022
if "rewrite" not in results.columns:
21-
results["rewrite"] = None
22-
if "inference_model" not in results.columns:
23-
results["inference_model"] = None
23+
if "inference_model" not in results.columns:
24+
no_rewrite = True
25+
results["rewrite"] = None
26+
results["inference_model"] = None
2427

2528
tones = results["tone"].unique()
2629
print(f"Found tones: {tones}")
@@ -46,10 +49,14 @@ def run_inference(
4649
outputs = route_completion(settings, queries, tone_prompt)
4750

4851
cleaned_output = [o.strip().strip('"') for o in outputs]
49-
new_results = pd.DataFrame({"synthetic_data" : queries, "tone" : tone})
50-
new_results["rewrite"] = cleaned_output
51-
new_results["inference_model"] = model_name
52-
53-
results = pd.concat([results, new_results], ignore_index=True)
52+
if no_rewrite:
53+
mask = results["tone"] == tone
54+
results.loc[mask, "rewrite"] = cleaned_output
55+
results.loc[mask, "inference_model"] = model_name
56+
else:
57+
new_results = results[results["tone"] == tone]
58+
new_results["rewrite"] = cleaned_output
59+
new_results["inference_model"] = model_name
60+
results = pd.concat([results, new_results], ignore_index=True)
5461

5562
write_dataset(results, data_dir, "all", "csv")

src/wraval/actions/action_llm_judge.py

Lines changed: 25 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -3,16 +3,31 @@
33
# // SPDX-License-Identifier: Apache-2.0
44
#
55
import pandas as pd
6-
from typing import List, Dict, Optional
6+
from typing import List, Dict, Any, Optional
77
from itertools import product
88
from dynaconf import Dynaconf
99
from .data_utils import write_dataset, load_latest_dataset
10-
from .prompts_judge import generate_input_prompt, generate_system_prompt, get_rubric, rewrite_prompt
11-
1210
from .completion import batch_get_bedrock_completions
1311
import re
1412
import boto3
1513

14+
# Import prompt functions based on settings
15+
def get_prompt_functions(settings: Dynaconf):
16+
"""Get the appropriate prompt functions based on settings."""
17+
if settings.custom_prompts:
18+
from wraval.custom_prompts.prompts_judge import (
19+
generate_input_prompt,
20+
generate_system_prompt,
21+
get_rubric
22+
)
23+
else:
24+
from .prompts_judge import (
25+
generate_input_prompt,
26+
generate_system_prompt,
27+
get_rubric
28+
)
29+
return generate_input_prompt, generate_system_prompt, get_rubric
30+
1631
def extract_score(text: str) -> Optional[int]:
1732
"""Extract score from text using regex pattern.
1833
@@ -60,9 +75,8 @@ def process_tone_data(
6075
Returns:
6176
Processed DataFrame with scores
6277
"""
63-
64-
if settings.custom_prompts == True:
65-
from wraval.custom_prompts.prompts_judge import generate_input_prompt, generate_system_prompt
78+
# Get the appropriate prompt functions
79+
generate_input_prompt, generate_system_prompt, _ = get_prompt_functions(settings)
6680

6781
temp_results = results.copy()
6882
rubrics = list(tone_rubrics.keys())
@@ -118,9 +132,6 @@ def judge(
118132
endpoint_type: Type of endpoint to use
119133
"""
120134

121-
if settings.custom_prompts == True:
122-
from wraval.custom_prompts.prompts_judge import get_rubric
123-
124135
try:
125136
results = load_latest_dataset(settings.data_dir)
126137
print(f"Loaded dataset with {len(results)} rows")
@@ -139,11 +150,15 @@ def judge(
139150
if settings.type != "all":
140151
tones = [settings.type]
141152

153+
# Get the appropriate prompt functions
154+
_, _, get_rubric = get_prompt_functions(settings)
155+
142156
# Process each tone-model combination that needs scoring
143157
for tone, inf_model in product(tones, inf_models):
144158
mask = (results.inference_model == inf_model) & (results.tone == tone)
145159
# check if any score is missing for this inference model and this tone
146-
# If yes, run the eval below
160+
if 'overall_score' not in results.columns:
161+
results['overall_score'] = None
147162
if not results[mask].overall_score.isna().any():
148163
continue
149164

src/wraval/actions/action_results.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ def normalize_scores(d: pd.DataFrame) -> pd.DataFrame:
2121
return 100 * (d - 1) / 2
2222

2323

24-
def show_results(settings: Dynaconf, tone: Optional[str] = None) -> None:
24+
def get_results(settings: Dynaconf, tone: Optional[str] = None) -> None:
2525
"""
2626
Load the latest dataset and display normalized results table grouped by tone.
2727

src/wraval/actions/data_utils.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -122,4 +122,5 @@ def load_latest_dataset(data_dir: str) -> pd.DataFrame:
122122
raise FileNotFoundError(f"No CSV files found in {data_dir}")
123123

124124
file_path = sorted(files, reverse=True)[0]
125+
print(f'Loading {file_path}')
125126
return pd.read_csv(os.path.join(data_dir, file_path))

src/wraval/actions/format.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,9 @@ def format_prompt(usr_prompt, prompt=None, tokenizer=None, type = 'bedrock'):
3939
messages = []
4040
if prompt.examples:
4141
for k,v in prompt.examples[0].items():
42-
messages.extend([{"role": k, "content": v}])
42+
# Format each message content as a list of text blocks
43+
messages.extend([{"role": k, "content": [{"text": v}]}])
44+
# Format user prompt as a list of text blocks
4345
usr_prompt = [{"role": "user", "content": [{"text": usr_prompt}]}]
4446
p = messages + usr_prompt
4547
else:

src/wraval/aws_config.py

Lines changed: 25 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,25 @@
1+
#
2+
# // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
# // SPDX-License-Identifier: Apache-2.0
4+
#
5+
import os
6+
import logging
7+
import warnings
8+
9+
# Suppress Pydantic warning
10+
warnings.filterwarnings("ignore", message="Field name \"json\" in \"MonitoringDatasetFormat\" shadows an attribute in parent \"Base\"")
11+
12+
# Configure logging before any AWS imports
13+
logging.getLogger('sagemaker').setLevel(logging.ERROR)
14+
logging.getLogger('sagemaker.config').setLevel(logging.ERROR) # Specifically target the config module
15+
logging.getLogger('boto3').setLevel(logging.ERROR)
16+
logging.getLogger('botocore').setLevel(logging.ERROR)
17+
logging.getLogger('urllib3').setLevel(logging.ERROR)
18+
19+
# Suppress AWS credential messages
20+
os.environ['SAGEMAKER_SUPPRESS_DEFAULTS'] = 'true'
21+
os.environ['AWS_SDK_LOAD_CONFIG'] = '0' # Suppress AWS SDK config loading messages
22+
23+
# Now import AWS modules
24+
import boto3
25+
import sagemaker

0 commit comments

Comments
 (0)