Skip to content

Commit fa3d2d6

Browse files
authored
[grpo] fix apply template to tool call dataset (#5471)
* fix apply template * use safe_decode
1 parent eb6ddad commit fa3d2d6

File tree

1 file changed

+2
-15
lines changed

1 file changed

+2
-15
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,21 +1230,8 @@ def _apply_chat_template_to_messages_list(self, messages_list: InputsType):
12301230
for messages in messages_list:
12311231
InferRequest.remove_response(messages)
12321232
template_inputs, _ = StdTemplateInputs.from_dict({'messages': messages})
1233-
res_context_list, _, _ = self.template._swift_encode(template_inputs)
1234-
1235-
# check the type and convert
1236-
processed_context = []
1237-
for context in res_context_list:
1238-
if isinstance(context, str):
1239-
processed_context.append(context)
1240-
elif isinstance(context, list) and all(isinstance(x, int) for x in context):
1241-
# decode the token ID to text
1242-
decoded_text = self.template.tokenizer.decode(context)
1243-
processed_context.append(decoded_text)
1244-
else:
1245-
# other type value ,just add to process_context
1246-
processed_context.append(str(context))
1247-
prompts_text.append(''.join(processed_context))
1233+
res = self.template.encode(template_inputs)
1234+
prompts_text.append(self.template.safe_decode(res['input_ids']))
12481235
return prompts_text
12491236

12501237
@patch_profiling_decorator

0 commit comments

Comments
 (0)