@@ -147,7 +147,6 @@ def tokenize_prompt(
         ignore_index: the ignore index used when calculating the loss during training
         max_length: the maximum context length
     """
-
     messages = data_point["messages"]
     template = deepcopy(conversation_template)
     template.messages = []
@@ -167,7 +166,6 @@ def tokenize_prompt(
     if len(template.messages) % 2 != 1:
         # Exclude the answer if provided; keep only the prompt.
         template.messages = template.messages[:-1]
-
     # Prepare data
     prompt = template.get_prompt(length=len(template.messages), add_generation_prompt=True)
     tokenized = tokenizer([prompt], add_special_tokens=False)["input_ids"][0]
@@ -185,12 +183,21 @@ def tokenize_prompt(
     )
 
     # `inputs_decode` can be used to check whether the tokenization is correct.
-    return dict(
-        input_ids=tokenized,
-        inputs_decode=prompt,
-        seq_length=len(tokenized),
-        seq_category=data_point["category"] if "category" in data_point else "None",
-    )
+    if "gt_answer" in data_point:
+        return dict(
+            input_ids=tokenized,
+            inputs_decode=prompt,
+            seq_length=len(tokenized),
+            seq_category=data_point["category"] if "category" in data_point else "None",
+            gt_answer=data_point["gt_answer"],
+        )
+    else:
+        return dict(
+            input_ids=tokenized,
+            inputs_decode=prompt,
+            seq_length=len(tokenized),
+            seq_category=data_point["category"] if "category" in data_point else "None",
+        )
 
 
 def apply_rlhf_data_format(template: Conversation, tokenizer: Any):
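
Below is a self-contained sketch of what the new conditional return means for the function's output. `build_output` is a hypothetical stand-in that mirrors only the dict-building logic from the diff; the data points and token ids are fabricated:

```python
# Hypothetical stand-in for the tail of tokenize_prompt; only the
# dict-building logic mirrors the diff above.
def build_output(data_point: dict, tokenized: list, prompt: str) -> dict:
    output = dict(
        input_ids=tokenized,
        inputs_decode=prompt,
        seq_length=len(tokenized),
        seq_category=data_point["category"] if "category" in data_point else "None",
    )
    # The diff adds a `gt_answer` key only when the raw data point carries
    # a ground-truth answer, so consumers can branch on the key's presence
    # instead of a sentinel value.
    if "gt_answer" in data_point:
        output["gt_answer"] = data_point["gt_answer"]
    return output

print(sorted(build_output({"gt_answer": "4", "category": "math"}, [1, 2, 3], "Q: 2+2?")))
print(sorted(build_output({}, [1, 2], "Hi")))
# ['gt_answer', 'input_ids', 'inputs_decode', 'seq_category', 'seq_length']
# ['input_ids', 'inputs_decode', 'seq_category', 'seq_length']
```

Note that the diff duplicates the `dict(...)` literal across both branches; folding the optional key in after building the common dict, as in the sketch, would avoid the repetition without changing behavior.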