Skip to content

Commit ccd5edd

Browse files
committed
Add write_dataset, which dispatches to the S3 or local writer based on data_dir
1 parent b61c18b commit ccd5edd

File tree

8 files changed

+41
-36
lines changed

8 files changed

+41
-36
lines changed

config/settings.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[default]
22
region = 'us-east-1'
3-
data_dir = 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
3+
data_dir = "./data"
4+
# To store datasets in S3 instead, set data_dir to: 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
45
deploy_bucket_name = 'llm-finetune-us-east-1-{aws_account}'
56
deploy_bucket_prefix = 'models'
67
sagemaker_execution_role_arn = 'arn:aws:iam::{aws_account}:role/sagemaker-execution-role-us-east-1'

pyproject.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,4 @@ where = ["src"]
4646

4747
[tool.setuptools.data-files]
4848
# This copies the config file into the installation (non-package dir)
49-
"config" = ["config/settings.toml"]
50-
49+
"config" = ["config/settings.toml"]

src/wraval/actions/action_generate.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# // SPDX-License-Identifier: Apache-2.0
44
#
55
import pandas as pd
6-
from .data_utils import write_dataset_local, write_dataset_to_s3
6+
from .data_utils import write_dataset_local, write_dataset
77
from dynaconf import Dynaconf
88
from .prompt_tones import get_all_tones, Tone
99
import os
@@ -79,6 +79,4 @@ def generate_tone_data(
7979

8080
combined = pd.concat(datasets, ignore_index=True)
8181

82-
write_dataset_local(combined, settings.data_dir, "all-tones")
83-
if upload_s3:
84-
write_dataset_to_s3(combined, settings.s3_bucket, "generate/all", "csv")
82+
write_dataset(combined, settings.data_dir, "all", "csv")

src/wraval/actions/action_inference.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
# // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
33
# // SPDX-License-Identifier: Apache-2.0
44
#
5+
import pandas as pd
56
from dynaconf import Dynaconf
6-
from .data_utils import write_dataset_local, write_dataset_to_s3, load_latest_dataset
7+
from .data_utils import write_dataset, load_latest_dataset
78
from .prompt_tones import get_prompt, Tone
89
from .model_router import route_completion
910

@@ -14,12 +15,7 @@ def run_inference(
1415
data_dir: str
1516
) -> None:
1617
"""Run inference on sentences using the specified model"""
17-
try:
18-
d = load_latest_dataset(data_dir)
19-
print(f"Loaded dataset with {len(d)} rows")
20-
except FileNotFoundError:
21-
print("No dataset found. Please generate data first.")
22-
return
18+
d = load_latest_dataset(data_dir)
2319

2420
if "rewrite" not in d.columns:
2521
d["rewrite"] = None
@@ -49,12 +45,11 @@ def run_inference(
4945

5046
outputs = route_completion(settings, queries, tone_prompt)
5147

52-
for query, output in zip(queries, outputs):
53-
mask = (d["synthetic_data"] == query) & (d["tone"] == tone)
54-
cleaned_output = output.strip().strip('"')
55-
d.loc[mask, "rewrite"] = cleaned_output
56-
d.loc[mask, "inference_model"] = model_name
48+
cleaned_output = [o.strip().strip('"') for o in outputs]
49+
new = pd.DataFrame({"synthetic_data" : queries, "tone" : tone})
50+
new["rewrite"] = cleaned_output
51+
new["inference_model"] = model_name
5752

58-
write_dataset_local(d, "./data", "all-tones")
59-
if upload_s3:
60-
write_dataset_to_s3(d, settings.s3_bucket, "inference/all", "csv")
53+
d = pd.concat([d, new], ignore_index=True)
54+
55+
write_dataset(d, data_dir, "all", "csv")

src/wraval/actions/action_llm_judge.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pandas as pd
66
from typing import List, Dict, Optional
77
from dynaconf import Dynaconf
8-
from .data_utils import write_dataset_local, write_dataset_to_s3, load_latest_dataset
8+
from .data_utils import write_dataset, load_latest_dataset
99
from .prompts_judge import generate_input_prompt, generate_system_prompt, get_rubric, rewrite_prompt
1010

1111
from .completion import batch_get_bedrock_completions
@@ -144,9 +144,7 @@ def judge(
144144
d.loc[mask, dmt.columns] = dmt.values
145145

146146
# Save results
147-
write_dataset_local(d, "./data", "all-tones")
148-
if upload_s3:
149-
write_dataset_to_s3(d, settings.s3_bucket, "inference/all", "csv")
147+
write_dataset(d, settings.data_dir, "all-tones", "csv")
150148

151149
def rewrite_judge(
152150
model_id: str,

src/wraval/actions/completion.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ def invoke_sagemaker_endpoint(
196196
Body=input_string.encode("utf-8"),
197197
ContentType="application/json",
198198
)
199+
import pdb; pdb.set_trace()
199200
json_output = response["Body"].readlines()
200201
plain_output = "\n".join(json.loads(json_output[0]))
201202
last_assistant = extract_last_assistant_response(plain_output)

src/wraval/actions/data_utils.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,26 +7,39 @@
77
from urllib.parse import urlparse
88

99

10-
def write_dataset_to_s3(
11-
df: pd.DataFrame, bucket: str, key_prefix: str, format: str
10+
def write_dataset(
11+
df: pd.DataFrame, data_dir: str, file_prefix: str, format: str
12+
) -> str:
13+
if is_s3_path(data_dir):
14+
bucket, prefix = parse_s3_path(data_dir)
15+
return write_dataset_s3(df, bucket, prefix, file_prefix, format)
16+
else:
17+
return write_dataset_local(df, data_dir, file_prefix, format)
18+
19+
def write_dataset_s3(
20+
df: pd.DataFrame, bucket: str, prefix: str, file_prefix: str, format: str
1221
) -> str:
1322
with tempfile.TemporaryDirectory() as temp_dir:
14-
temp_file = os.path.join(temp_dir, "temp.jsonl")
15-
df.to_json(temp_file, orient="records", lines=bool(format == "jsonl"))
23+
temp_file = os.path.join(temp_dir, "temp.csv")
24+
df.to_csv(temp_file, index=False)
1625
s3_client = boto3.client("s3")
17-
key = add_timestamp_to_file_prefix(key_prefix, format)
18-
print(f"Writing dataset to bucket {bucket} and key {key}.")
26+
key = os.path.join(prefix,
27+
add_timestamp_to_file_prefix(file_prefix, format)
28+
)
29+
print(f"Writing dataset to s3://{bucket}/{key}")
1930
s3_client.upload_file(temp_file, bucket, key)
2031
return f"s3://{bucket}/{key}"
2132

22-
23-
def write_dataset_local(df: pd.DataFrame, data_dir: str, file_prefix: str) -> str:
33+
def write_dataset_local(
34+
df: pd.DataFrame, data_dir: str, file_prefix: str, format: str
35+
) -> str:
2436
# Expand home directory and create if needed
2537
data_dir = os.path.expanduser(data_dir)
2638
os.makedirs(data_dir, exist_ok=True)
2739

2840
output_path = os.path.join(
29-
data_dir, add_timestamp_to_file_prefix(file_prefix, "csv")
41+
data_dir,
42+
add_timestamp_to_file_prefix(file_prefix, format)
3043
)
3144
df.to_csv(output_path, index=False)
3245
print(f"Saved to {output_path}")

src/wraval/actions/model_router.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
)
66
from .format import format_prompt
77
from transformers import AutoTokenizer
8-
from tqdm import tqdm
8+
from tqdm.auto import tqdm
99
from typing import List
1010
from dynaconf.base import LazySettings
1111
from abc import ABC, abstractmethod

0 commit comments

Comments
 (0)