@@ -15,14 +15,14 @@ def run_inference(
1515 data_dir : str
1616) -> None :
1717 """Run inference on sentences using the specified model"""
18- d = load_latest_dataset (data_dir )
18+ results = load_latest_dataset (data_dir )
1919
20- if "rewrite" not in d .columns :
21- d ["rewrite" ] = None
22- if "inference_model" not in d .columns :
23- d ["inference_model" ] = None
20+ if "rewrite" not in results .columns :
21+ results ["rewrite" ] = None
22+ if "inference_model" not in results .columns :
23+ results ["inference_model" ] = None
2424
25- tones = d ["tone" ].unique ()
25+ tones = results ["tone" ].unique ()
2626 print (f"Found tones: { tones } " )
2727
2828 if settings .type != "all" :
@@ -39,17 +39,17 @@ def run_inference(
3939
4040 tone_prompt = get_prompt (Tone (tone ))
4141
42- queries = d [ d ["tone" ] == tone ]["synthetic_data" ].unique ()
42+ queries = results [ results ["tone" ] == tone ]["synthetic_data" ].unique ()
4343
4444 print (f"Processing { len (queries )} unique inputs for tone: { tone } " )
4545
4646 outputs = route_completion (settings , queries , tone_prompt )
4747
4848 cleaned_output = [o .strip ().strip ('"' ) for o in outputs ]
49- new = pd .DataFrame ({"synthetic_data" : queries , "tone" : tone })
50- new ["rewrite" ] = cleaned_output
51- new ["inference_model" ] = model_name
49+ new_results = pd .DataFrame ({"synthetic_data" : queries , "tone" : tone })
50+ new_results ["rewrite" ] = cleaned_output
51+ new_results ["inference_model" ] = model_name
5252
53- d = pd .concat ([d , new ], ignore_index = True )
53+ results = pd .concat ([results , new_results ], ignore_index = True )
5454
55- write_dataset (d , data_dir , "all" , "csv" )
55+ write_dataset (results , data_dir , "all" , "csv" )
0 commit comments