@@ -39,6 +39,7 @@ def validate_dataset(d: pd.DataFrame) -> bool:
3939 return True
4040
4141def process_tone_data (
42+ settings : Dynaconf ,
4243 d : pd .DataFrame ,
4344 tone : str ,
4445 model_name : str ,
@@ -61,22 +62,21 @@ def process_tone_data(
6162 rubrics = list (tone_rubrics .keys ())
6263
6364 # Generate prompts
64- prompts = []
65+ user_prompts = []
66+ sys_prompts = []
67+
6568 for q , a in zip (dmt ["synthetic_data" ], dmt ["rewrite" ]):
6669 for rubric in rubrics :
67- prompts .append ((
68- generate_system_prompt (tone_rubrics [rubric ]),
69- generate_input_prompt (q , a , tone )
70- ))
70+ user_prompts .append (generate_input_prompt (q , a , tone ))
71+ sys_prompts .append (generate_system_prompt (tone_rubrics [rubric ]))
7172
7273 # Get completions
73- sys_prompts , user_prompts = zip (* prompts )
74+ # import pdb
75+ # pdb.set_trace()
7476 completions = batch_get_bedrock_completions (
75- model_name ,
76- client ,
77+ settings ,
7778 user_prompts ,
78- sys_prompts ,
79- max_concurrent = len (user_prompts )
79+ sys_prompts
8080 )
8181
8282 rubrics = [r .lower () for r in rubrics ]
@@ -99,7 +99,6 @@ def judge(
9999 client : boto3 .client ,
100100 model_name : str ,
101101 upload_s3 : bool ,
102- data_dir : str ,
103102 endpoint_type : str = "bedrock"
104103) -> None :
105104 """Judge rewrites using specified model and rubrics.
@@ -113,7 +112,7 @@ def judge(
113112 endpoint_type: Type of endpoint to use
114113 """
115114 try :
116- d = load_latest_dataset (data_dir )
115+ d = load_latest_dataset (settings . data_dir )
117116 print (f"Loaded dataset with { len (d )} rows" )
118117 except FileNotFoundError :
119118 print ("No dataset found. Please generate data first." )
@@ -129,7 +128,7 @@ def judge(
129128 print (f"\n { '=' * 20 } \n { tone } \n { '=' * 20 } " )
130129
131130 tone_rubrics = get_rubric (tone .upper ())
132- dmt = process_tone_data (d , tone , model_name , client , tone_rubrics )
131+ dmt = process_tone_data (settings , d , tone , model_name , client , tone_rubrics )
133132
134133 # Update main dataframe
135134 mask = (d .tone == tone )
0 commit comments