@@ -71,10 +71,11 @@ def main():
7171 for i , data in enumerate (cache ):
7272 message , response = data [0 ], data [1 ]
7373 pred = response ['choices' ][0 ]['message' ]['content' ]
74- pred = pred .lower ()
75- pre = pred .split ("<" )[0 ].strip ()
76- pre = pre .split ("." )[0 ].strip ()
77- pre = pre .split ("\n " )[0 ].strip ()
74+ pre = pred .lower ()
75+ pre = pre .split ("<" )[0 ].strip () if "<" in pre else pre
76+ pre = pre .split ("." )[0 ].strip () if "." in pre else pre
77+ pre = pre .split ("\n " )[0 ].strip () if "\n " in pre else pre
78+ pre = pre .split ("'" )[1 ].strip () if "'" in pre else pre
7879
7980 total += 1
8081 if pre in option :
@@ -132,9 +133,11 @@ def main():
132133 for i , data in enumerate (cache ):
133134 message , response = data [0 ], data [1 ]
134135 pred = response ['choices' ][0 ]['message' ]['content' ]
135- pred = pred .lower ()
136- pre = pred .split ("</s>" )[0 ].strip ()
137- pre = pre .split ("." )[0 ].strip ()
136+ pre = pred .lower ()
137+ pre = pre .split ("</s>" )[0 ].strip () if "</s>" in pre else pre
138+ pre = pre .split ("." )[0 ].strip () if "." in pre else pre
139+ pre = pre .split ("\n " )[0 ].strip () if "\n " in pre else pre
140+ pre = pre .split ("'" )[1 ].strip () if "'" in pre else pre
138141
139142 cnt += 1
140143 if pre in option :
@@ -177,9 +180,11 @@ def main():
177180 for i , data in enumerate (cache ):
178181 message , response = data [0 ], data [1 ]
179182 pred = response ['choices' ][0 ]['message' ]['content' ]
180- pred = pred .lower ()
181- pre = pred .split ("</s>" )[0 ].strip ()
182- pre = pre .split ("." )[0 ].strip ()
183+ pre = pred .lower ()
184+ pre = pre .split ("</s>" )[0 ].strip () if "</s>" in pre else pre
185+ pre = pre .split ("." )[0 ].strip () if "." in pre else pre
186+ pre = pre .split ("\n " )[0 ].strip () if "\n " in pre else pre
187+ pre = pre .split ("'" )[1 ].strip () if "'" in pre else pre
183188
184189 cnt += 1
185190 if pre in option :
@@ -228,11 +233,11 @@ def main():
228233 for i , data in enumerate (cache ):
229234 message , response = data [0 ], data [1 ]
230235 pred = response ['choices' ][0 ]['message' ]['content' ]
231- pred = pred .lower ()
232- # pre = pred
233- pre = pred .split ("< " )[0 ].strip ()
234- pre = pre .split (". " )[0 ].strip ()
235- pre = pre .split ("\n " )[0 ].strip ()
236+ pre = pred .lower ()
237+ pre = pre . split ( "<" )[ 0 ]. strip () if "<" in pre else pre
238+ pre = pre .split (". " )[0 ].strip () if "." in pre else pre
239+ pre = pre .split ("\n " )[0 ].strip () if " \n " in pre else pre
240+ pre = pre .split ("' " )[1 ].strip () if "'" in pre else pre
236241
237242 total += 1
238243 if pre in option :
@@ -310,6 +315,7 @@ def main():
310315 # Save results
311316 save_path = os .path .join (base_dir , "scores.jsonl" )
312317 with open (save_path , "w" ) as file :
318+ print ("Saving..." )
313319 for item in result_list :
314320 json_str = json .dumps (item )
315321 file .write (json_str + "\n " )
0 commit comments