# // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# // SPDX-License-Identifier: Apache-2.0
#
5+ import pandas as pd
56from dynaconf import Dynaconf
6- from .data_utils import write_dataset_local , write_dataset_to_s3 , load_latest_dataset
7+ from .data_utils import write_dataset , load_latest_dataset
78from .prompt_tones import get_prompt , Tone
89from .model_router import route_completion
910
@@ -14,19 +15,14 @@ def run_inference(
1415 data_dir : str
1516) -> None :
1617 """Run inference on sentences using the specified model"""
17- try :
18- d = load_latest_dataset (data_dir )
19- print (f"Loaded dataset with { len (d )} rows" )
20- except FileNotFoundError :
21- print ("No dataset found. Please generate data first." )
22- return
23-
24- if "rewrite" not in d .columns :
25- d ["rewrite" ] = None
26- if "inference_model" not in d .columns :
27- d ["inference_model" ] = None
28-
29- tones = d ["tone" ].unique ()
18+ results = load_latest_dataset (data_dir )
19+
20+ if "rewrite" not in results .columns :
21+ results ["rewrite" ] = None
22+ if "inference_model" not in results .columns :
23+ results ["inference_model" ] = None
24+
25+ tones = results ["tone" ].unique ()
3026 print (f"Found tones: { tones } " )
3127
3228 if settings .type != "all" :
@@ -43,18 +39,17 @@ def run_inference(
4339
4440 tone_prompt = get_prompt (Tone (tone ))
4541
46- queries = d [ d ["tone" ] == tone ]["synthetic_data" ].unique ()
42+ queries = results [ results ["tone" ] == tone ]["synthetic_data" ].unique ()
4743
4844 print (f"Processing { len (queries )} unique inputs for tone: { tone } " )
4945
5046 outputs = route_completion (settings , queries , tone_prompt )
5147
52- for query , output in zip (queries , outputs ):
53- mask = (d ["synthetic_data" ] == query ) & (d ["tone" ] == tone )
54- cleaned_output = output .strip ().strip ('"' )
55- d .loc [mask , "rewrite" ] = cleaned_output
56- d .loc [mask , "inference_model" ] = model_name
48+ cleaned_output = [o .strip ().strip ('"' ) for o in outputs ]
49+ new_results = pd .DataFrame ({"synthetic_data" : queries , "tone" : tone })
50+ new_results ["rewrite" ] = cleaned_output
51+ new_results ["inference_model" ] = model_name
52+
53+ results = pd .concat ([results , new_results ], ignore_index = True )
5754
58- write_dataset_local (d , "./data" , "all-tones" )
59- if upload_s3 :
60- write_dataset_to_s3 (d , settings .s3_bucket , "inference/all" , "csv" )
55+ write_dataset (results , data_dir , "all" , "csv" )
0 commit comments