
Commit 08050df

Merge pull request #13 from amazon-science/agnostic-write-dataset
write_dataset that calls either s3 or local
2 parents b61c18b + 350b7ea

File tree

8 files changed (+51, -49 lines)

config/settings.toml
Lines changed: 2 additions & 1 deletion

@@ -1,6 +1,7 @@
 [default]
 region = 'us-east-1'
-data_dir = 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
+data_dir = "./data"
+# 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
 deploy_bucket_name = 'llm-finetune-us-east-1-{aws_account}'
 deploy_bucket_prefix = 'models'
 sagemaker_execution_role_arn = 'arn:aws:iam::{aws_account}:role/sagemaker-execution-role-us-east-1'
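
This flips the default data_dir from S3 to local disk, keeping the S3 URI as a comment for easy switching. A minimal sketch of how such a setting is typically consumed via Dynaconf (the constructor arguments below are assumptions, not shown in this commit):

    from dynaconf import Dynaconf

    # Hypothetical loader; the repo's actual settings path and options may differ.
    settings = Dynaconf(
        settings_files=["config/settings.toml"],
        environments=True,  # assumption: lets Dynaconf read the [default] section
    )

    print(settings.data_dir)  # "./data", or an s3:// URI if the comment is swapped back in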

pyproject.toml
Lines changed: 1 addition & 2 deletions

@@ -46,5 +46,4 @@ where = ["src"]
 
 [tool.setuptools.data-files]
 # This copies the config file into the installation (non-package dir)
-"config" = ["config/settings.toml"]
-
+"config" = ["config/settings.toml"]

src/wraval/actions/action_generate.py
Lines changed: 2 additions & 4 deletions

@@ -3,7 +3,7 @@
 # // SPDX-License-Identifier: Apache-2.0
 #
 import pandas as pd
-from .data_utils import write_dataset_local, write_dataset_to_s3
+from .data_utils import write_dataset
 from dynaconf import Dynaconf
 from .prompt_tones import get_all_tones, Tone
 import os
@@ -79,6 +79,4 @@ def generate_tone_data(
 
     combined = pd.concat(datasets, ignore_index=True)
 
-    write_dataset_local(combined, settings.data_dir, "all-tones")
-    if upload_s3:
-        write_dataset_to_s3(combined, settings.s3_bucket, "generate/all", "csv")
+    write_dataset(combined, settings.data_dir, "all", "csv")

src/wraval/actions/action_inference.py
Lines changed: 18 additions & 23 deletions

@@ -2,8 +2,9 @@
 # // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # // SPDX-License-Identifier: Apache-2.0
 #
+import pandas as pd
 from dynaconf import Dynaconf
-from .data_utils import write_dataset_local, write_dataset_to_s3, load_latest_dataset
+from .data_utils import write_dataset, load_latest_dataset
 from .prompt_tones import get_prompt, Tone
 from .model_router import route_completion
 
@@ -14,19 +15,14 @@ def run_inference(
     data_dir: str
 ) -> None:
     """Run inference on sentences using the specified model"""
-    try:
-        d = load_latest_dataset(data_dir)
-        print(f"Loaded dataset with {len(d)} rows")
-    except FileNotFoundError:
-        print("No dataset found. Please generate data first.")
-        return
-
-    if "rewrite" not in d.columns:
-        d["rewrite"] = None
-    if "inference_model" not in d.columns:
-        d["inference_model"] = None
-
-    tones = d["tone"].unique()
+    results = load_latest_dataset(data_dir)
+
+    if "rewrite" not in results.columns:
+        results["rewrite"] = None
+    if "inference_model" not in results.columns:
+        results["inference_model"] = None
+
+    tones = results["tone"].unique()
     print(f"Found tones: {tones}")
 
     if settings.type != "all":
@@ -43,18 +39,17 @@ def run_inference(
 
         tone_prompt = get_prompt(Tone(tone))
 
-        queries = d[d["tone"] == tone]["synthetic_data"].unique()
+        queries = results[results["tone"] == tone]["synthetic_data"].unique()
 
         print(f"Processing {len(queries)} unique inputs for tone: {tone}")
 
         outputs = route_completion(settings, queries, tone_prompt)
 
-        for query, output in zip(queries, outputs):
-            mask = (d["synthetic_data"] == query) & (d["tone"] == tone)
-            cleaned_output = output.strip().strip('"')
-            d.loc[mask, "rewrite"] = cleaned_output
-            d.loc[mask, "inference_model"] = model_name
+        cleaned_output = [o.strip().strip('"') for o in outputs]
+        new_results = pd.DataFrame({"synthetic_data": queries, "tone": tone})
+        new_results["rewrite"] = cleaned_output
+        new_results["inference_model"] = model_name
+
+        results = pd.concat([results, new_results], ignore_index=True)
 
-    write_dataset_local(d, "./data", "all-tones")
-    if upload_s3:
-        write_dataset_to_s3(d, settings.s3_bucket, "inference/all", "csv")
+    write_dataset(results, data_dir, "all", "csv")
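
Two things change semantically in this hunk. First, pd.DataFrame({"synthetic_data": queries, "tone": tone}) relies on pandas broadcasting the scalar tone across every row of the queries array. Second, rewrites are now appended as new rows via pd.concat instead of written in place under a mask, so the saved dataset accumulates one row per (input, tone) pair per run. A tiny self-contained illustration of the broadcast-and-append pattern (all values made up):

    import pandas as pd

    results = pd.DataFrame({"synthetic_data": ["fix this"], "tone": ["witty"], "rewrite": [None]})

    queries = ["fix this", "shorten that"]
    new_results = pd.DataFrame({"synthetic_data": queries, "tone": "witty"})  # scalar broadcast
    new_results["rewrite"] = ["Fixed!", "Short."]

    results = pd.concat([results, new_results], ignore_index=True)
    print(len(results))  # 3 rows: the original plus the two appended rewrites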

src/wraval/actions/action_llm_judge.py
Lines changed: 2 additions & 4 deletions

@@ -5,7 +5,7 @@
 import pandas as pd
 from typing import List, Dict, Optional
 from dynaconf import Dynaconf
-from .data_utils import write_dataset_local, write_dataset_to_s3, load_latest_dataset
+from .data_utils import write_dataset, load_latest_dataset
 from .prompts_judge import generate_input_prompt, generate_system_prompt, get_rubric, rewrite_prompt
 
 from .completion import batch_get_bedrock_completions
@@ -144,9 +144,7 @@ def judge(
         d.loc[mask, dmt.columns] = dmt.values
 
     # Save results
-    write_dataset_local(d, "./data", "all-tones")
-    if upload_s3:
-        write_dataset_to_s3(d, settings.s3_bucket, "inference/all", "csv")
+    write_dataset(d, settings.data_dir, "all-tones", "csv")
 
 def rewrite_judge(
     model_id: str,

src/wraval/actions/data_utils.py
Lines changed: 24 additions & 13 deletions

@@ -7,30 +7,41 @@
 from urllib.parse import urlparse
 
 
-def write_dataset_to_s3(
-    df: pd.DataFrame, bucket: str, key_prefix: str, format: str
-) -> str:
+def write_dataset(
+    df: pd.DataFrame, data_dir: str, file_prefix: str, format: str
+):
+    if is_s3_path(data_dir):
+        bucket, prefix = parse_s3_path(data_dir)
+        write_dataset_s3(df, bucket, prefix, file_prefix, format)
+    else:
+        write_dataset_local(df, data_dir, file_prefix, format)
+
+def write_dataset_s3(
+    df: pd.DataFrame, bucket: str, prefix: str, file_prefix: str, format: str
+):
     with tempfile.TemporaryDirectory() as temp_dir:
-        temp_file = os.path.join(temp_dir, "temp.jsonl")
-        df.to_json(temp_file, orient="records", lines=bool(format == "jsonl"))
+        temp_file = os.path.join(temp_dir, "temp.csv")
+        df.to_csv(temp_file, index=False)
         s3_client = boto3.client("s3")
-        key = add_timestamp_to_file_prefix(key_prefix, format)
-        print(f"Writing dataset to bucket {bucket} and key {key}.")
+        key = os.path.join(prefix,
+            add_timestamp_to_file_prefix(file_prefix, format)
+        )
+        print(f"Writing dataset to s3://{bucket}/{key}")
         s3_client.upload_file(temp_file, bucket, key)
-        return f"s3://{bucket}/{key}"
-
 
-def write_dataset_local(df: pd.DataFrame, data_dir: str, file_prefix: str) -> str:
+def write_dataset_local(
+    df: pd.DataFrame, data_dir: str, file_prefix: str, format: str
+) -> str:
     # Expand home directory and create if needed
     data_dir = os.path.expanduser(data_dir)
     os.makedirs(data_dir, exist_ok=True)
 
     output_path = os.path.join(
-        data_dir, add_timestamp_to_file_prefix(file_prefix, "csv")
+        data_dir,
+        add_timestamp_to_file_prefix(file_prefix, format)
     )
     df.to_csv(output_path, index=False)
-    print(f"Saved to {output_path}")
-    return output_path
+    print(f"Saved locally to {output_path}")
 
 
 def add_timestamp_to_file_prefix(file_prefix, format):
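
The point of this refactor: write_dataset inspects data_dir and routes to S3 or local disk, so action code no longer branches on upload_s3. The helpers is_s3_path and parse_s3_path are called here but defined outside this hunk; a plausible sketch consistent with the file's urlparse import (not necessarily the committed implementation):

    from urllib.parse import urlparse

    def is_s3_path(path: str) -> bool:
        # True for URIs of the form s3://bucket/prefix
        return urlparse(path).scheme == "s3"

    def parse_s3_path(path: str) -> tuple[str, str]:
        # "s3://bucket/some/prefix/" -> ("bucket", "some/prefix/")
        parsed = urlparse(path)
        return parsed.netloc, parsed.path.lstrip("/")

With that in place, both storage targets go through the same call (the bucket name here is hypothetical):

    write_dataset(df, "./data", "all", "csv")                      # timestamped CSV on disk
    write_dataset(df, "s3://my-bucket/eval/tones/", "all", "csv")  # timestamped CSV in S3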

src/wraval/actions/model_router.py
Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 )
 from .format import format_prompt
 from transformers import AutoTokenizer
-from tqdm import tqdm
+from tqdm.auto import tqdm
 from typing import List
 from dynaconf.base import LazySettings
 from abc import ABC, abstractmethod
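
The tqdm.auto entry point picks a frontend at import time: the widget-based bar inside Jupyter, the plain terminal bar elsewhere. It is a drop-in replacement, so nothing else in the file changes:

    from tqdm.auto import tqdm  # renders properly in notebooks and terminals alike

    for batch in tqdm(range(10), desc="batches"):
        pass  # placeholder loop body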

src/wraval/model_artifacts/code/inference.py
Lines changed: 1 addition & 1 deletion

@@ -27,4 +27,4 @@ def predict_fn(data, model_and_tokenizer, *args):
     model.config.pad_token_id = model.config.eos_token_id
     inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(model.device)
     output_sequences = model.generate(**inputs, max_new_tokens=1024)
-    return tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
+    return tokenizer.batch_decode(output_sequences)
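
Dropping skip_special_tokens=True means the decoded outputs now keep markers such as EOS and padding tokens, which downstream parsing can use to locate where the completion ends. A quick way to see the difference (gpt2 is just a convenient example checkpoint, not the model this endpoint serves):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    ids = tokenizer("Hello world")["input_ids"] + [tokenizer.eos_token_id]

    print(tokenizer.decode(ids, skip_special_tokens=True))  # Hello world
    print(tokenizer.decode(ids))                            # Hello world<|endoftext|>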
