@@ -142,6 +142,7 @@ def infer_cli(self) -> List[Dict[str, Any]]:
142142 data = infer_state .to_dict ()
143143 response = self .infer_single (data , request_config )
144144 infer_state .add_response (response )
145+ data ['messages' ].append ({'role' : 'assistant' , 'content' : response })
145146 data = {'response' : response , ** data }
146147 result_list .append (data )
147148 if self .jsonl_writer :
@@ -196,6 +197,7 @@ def infer_dataset(self) -> List[Dict[str, Any]]:
196197 print (f'[LABELS] { labels } ' )
197198 print ('[RESPONSE] ' , end = '' )
198199 response = self .infer_single (data , request_config )
200+ data ['messages' ].append ({'role' : 'assistant' , 'content' : response })
199201 data = {'response' : response , 'labels' : labels , ** data }
200202 result_list .append (data )
201203 if self .jsonl_writer :
@@ -218,6 +220,7 @@ def infer_dataset(self) -> List[Dict[str, Any]]:
218220 val_dataset , request_config , template = self .template , use_tqdm = True , ** self .infer_kwargs )
219221 for data , resp , labels in zip (val_dataset , resp_list , labels_list ):
220222 response = resp .choices [0 ].message .content
223+ data ['messages' ].append ({'role' : 'assistant' , 'content' : response })
221224 data = {'response' : response , 'labels' : labels , 'logprobs' : resp .choices [0 ].logprobs , ** data }
222225 result_list .append (data )
223226 if self .jsonl_writer :
0 commit comments