Skip to content

Commit 350b7ea

Browse files
committed
resolve comments
1 parent ccd5edd commit 350b7ea

File tree

5 files changed

+19
-22
lines changed

src/wraval/actions/action_generate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# // SPDX-License-Identifier: Apache-2.0
44
#
55
import pandas as pd
6-
from .data_utils import write_dataset_local, write_dataset
6+
from .data_utils import write_dataset
77
from dynaconf import Dynaconf
88
from .prompt_tones import get_all_tones, Tone
99
import os

src/wraval/actions/action_inference.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ def run_inference(
1515
data_dir: str
1616
) -> None:
1717
"""Run inference on sentences using the specified model"""
18-
d = load_latest_dataset(data_dir)
18+
results = load_latest_dataset(data_dir)
1919

20-
if "rewrite" not in d.columns:
21-
d["rewrite"] = None
22-
if "inference_model" not in d.columns:
23-
d["inference_model"] = None
20+
if "rewrite" not in results.columns:
21+
results["rewrite"] = None
22+
if "inference_model" not in results.columns:
23+
results["inference_model"] = None
2424

25-
tones = d["tone"].unique()
25+
tones = results["tone"].unique()
2626
print(f"Found tones: {tones}")
2727

2828
if settings.type != "all":
@@ -39,17 +39,17 @@ def run_inference(
3939

4040
tone_prompt = get_prompt(Tone(tone))
4141

42-
queries = d[d["tone"] == tone]["synthetic_data"].unique()
42+
queries = results[results["tone"] == tone]["synthetic_data"].unique()
4343

4444
print(f"Processing {len(queries)} unique inputs for tone: {tone}")
4545

4646
outputs = route_completion(settings, queries, tone_prompt)
4747

4848
cleaned_output = [o.strip().strip('"') for o in outputs]
49-
new = pd.DataFrame({"synthetic_data" : queries, "tone" : tone})
50-
new["rewrite"] = cleaned_output
51-
new["inference_model"] = model_name
49+
new_results = pd.DataFrame({"synthetic_data" : queries, "tone" : tone})
50+
new_results["rewrite"] = cleaned_output
51+
new_results["inference_model"] = model_name
5252

53-
d = pd.concat([d, new], ignore_index=True)
53+
results = pd.concat([results, new_results], ignore_index=True)
5454

55-
write_dataset(d, data_dir, "all", "csv")
55+
write_dataset(results, data_dir, "all", "csv")

src/wraval/actions/completion.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,6 @@ def invoke_sagemaker_endpoint(
196196
Body=input_string.encode("utf-8"),
197197
ContentType="application/json",
198198
)
199-
import pdb; pdb.set_trace()
200199
json_output = response["Body"].readlines()
201200
plain_output = "\n".join(json.loads(json_output[0]))
202201
last_assistant = extract_last_assistant_response(plain_output)

src/wraval/actions/data_utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@
99

1010
def write_dataset(
1111
df: pd.DataFrame, data_dir: str, file_prefix: str, format: str
12-
) -> str:
12+
):
1313
if is_s3_path(data_dir):
1414
bucket, prefix = parse_s3_path(data_dir)
15-
return write_dataset_s3(df, bucket, prefix, file_prefix, format)
15+
write_dataset_s3(df, bucket, prefix, file_prefix, format)
1616
else:
17-
return write_dataset_local(df, data_dir, file_prefix, format)
17+
write_dataset_local(df, data_dir, file_prefix, format)
1818

1919
def write_dataset_s3(
2020
df: pd.DataFrame, bucket: str, prefix: str, file_prefix: str, format: str
21-
) -> str:
21+
):
2222
with tempfile.TemporaryDirectory() as temp_dir:
2323
temp_file = os.path.join(temp_dir, "temp.csv")
2424
df.to_csv(temp_file, index=False)
@@ -28,7 +28,6 @@ def write_dataset_s3(
2828
)
2929
print(f"Writing dataset to s3://{bucket}/{key}")
3030
s3_client.upload_file(temp_file, bucket, key)
31-
return f"s3://{bucket}/{key}"
3231

3332
def write_dataset_local(
3433
df: pd.DataFrame, data_dir: str, file_prefix: str, format: str
@@ -42,8 +41,7 @@ def write_dataset_local(
4241
add_timestamp_to_file_prefix(file_prefix, format)
4342
)
4443
df.to_csv(output_path, index=False)
45-
print(f"Saved to {output_path}")
46-
return output_path
44+
print(f"Saved locally to {output_path}")
4745

4846

4947
def add_timestamp_to_file_prefix(file_prefix, format):

src/wraval/model_artifacts/code/inference.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,4 @@ def predict_fn(data, model_and_tokenizer, *args):
2727
model.config.pad_token_id = model.config.eos_token_id
2828
inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(model.device)
2929
output_sequences = model.generate(**inputs, max_new_tokens=1024)
30-
return tokenizer.batch_decode(output_sequences, skip_special_tokens=True)
30+
return tokenizer.batch_decode(output_sequences)

0 commit comments

Comments
 (0)