@@ -77,8 +77,7 @@ def loading_data(
7777 )
7878 answer = "true" if data_point ["answer" ] else "false"
7979 if is_train :
80- prompt += f" { answer } "
81- labels = None
80+ labels = answer
8281 else :
8382 labels = [self .labels2id_ [answer ]]
8483 ret .append (InputData (inputs = prompt , labels = labels ))
@@ -110,8 +109,7 @@ def loading_data(
110109 prompt += f" ({ label } ) { text } "
111110 prompt += "\n Answer:"
112111 if is_train :
113- prompt += " " + data_point ["answerKey" ]
114- labels = None
112+ labels = data_point ["answerKey" ]
115113 else :
116114 labels = [self .labels2id_ [data_point ["answerKey" ]]]
117115 ret .append (InputData (inputs = prompt , labels = labels ))
@@ -138,8 +136,7 @@ def loading_data(
138136 prompt += "\n Correct solution:"
139137 answer = self .labels_ [data_point ["label" ]]
140138 if is_train :
141- prompt += f" { answer } "
142- labels = None
139+ labels = answer
143140 else :
144141 labels = [data_point ["label" ]]
145142 ret .append (InputData (inputs = prompt , labels = labels ))
@@ -168,8 +165,7 @@ def loading_data(
168165 prompt += "\n Answer:"
169166 label = int (data_point ["label" ]) - 1
170167 if is_train :
171- prompt += f" { self .labels_ [label ]} "
172- labels = None
168+ labels = self .labels_ [label ]
173169 else :
174170 labels = [label ]
175171 ret .append (InputData (inputs = prompt , labels = labels ))
@@ -200,8 +196,7 @@ def loading_data(
200196 prompt += "\n Answer:"
201197 label = int (data_point ["label" ])
202198 if is_train :
203- prompt += f" { self .labels_ [label ]} "
204- labels = None
199+ labels = self .labels_ [label ]
205200 else :
206201 labels = [label ]
207202 ret .append (InputData (inputs = prompt , labels = labels ))
@@ -230,8 +225,7 @@ def loading_data(
230225 prompt += "\n Answer:"
231226 label = int (data_point ["answer" ]) - 1
232227 if is_train :
233- prompt += f" { self .labels_ [label ]} "
234- labels = None
228+ labels = self .labels_ [label ]
235229 else :
236230 labels = [label ]
237231 ret .append (InputData (inputs = prompt , labels = labels ))
@@ -261,8 +255,7 @@ def loading_data(
261255 prompt += f" ({ label } ) { text } "
262256 prompt += "\n Answer:"
263257 if is_train :
264- prompt += " " + data_point ["answerKey" ]
265- labels = None
258+ labels = data_point ["answerKey" ]
266259 else :
267260 labels = [self .labels2id_ [data_point ["answerKey" ]]]
268261 ret .append (InputData (inputs = prompt , labels = labels ))
@@ -295,13 +288,10 @@ def loading_data(
295288 prompt += f"({ label } ) { text } \n "
296289 answer = data_point ["final_decision" ]
297290 assert answer in self .labels2id_
291+ prompt += "Answer:"
298292 if is_train :
299- prompt += f"Long Answer:\n { data_point ['long_answer' ]} \n "
300- prompt += "Answer:"
301- prompt += f" { answer } "
302- labels = None
293+ labels = answer
303294 else :
304- prompt += "Answer:"
305295 labels = [self .labels2id_ [answer ]]
306296 ret .append (InputData (inputs = prompt , labels = labels ))
307297
0 commit comments