Skip to content

Commit 5cc33a0

Browse files
hjh0119 authored and Jintao-Huang committed
[grpo] fix apply template to tool call dataset (#5471)
* fix apply template
* use safe_decode
1 parent afd5319 commit 5cc33a0

File tree

1 file changed

+2
-15
lines changed

1 file changed

+2
-15
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 2 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1217,21 +1217,8 @@ def _apply_chat_template_to_messages_list(self, messages_list: InputsType):
12171217
for messages in messages_list:
12181218
InferRequest.remove_response(messages)
12191219
template_inputs, _ = StdTemplateInputs.from_dict({'messages': messages})
1220-
res_context_list, _, _ = self.template._swift_encode(template_inputs)
1221-
1222-
# check the type and convert
1223-
processed_context = []
1224-
for context in res_context_list:
1225-
if isinstance(context, str):
1226-
processed_context.append(context)
1227-
elif isinstance(context, list) and all(isinstance(x, int) for x in context):
1228-
# decode the token ID to text
1229-
decoded_text = self.template.tokenizer.decode(context)
1230-
processed_context.append(decoded_text)
1231-
else:
1232-
# other type value ,just add to process_context
1233-
processed_context.append(str(context))
1234-
prompts_text.append(''.join(processed_context))
1220+
res = self.template.encode(template_inputs)
1221+
prompts_text.append(self.template.safe_decode(res['input_ids']))
12351222
return prompts_text
12361223

12371224
@patch_profiling_decorator

0 commit comments

Comments
 (0)