Skip to content

Commit b059291

Browse files
authored
[grpo] update apply template (#4833)
* Update bos_token bug * Update grpo_trainer.py for special token encode error for deepseek model * Update grpo_trainer.py for special token encode error for deepseek model * Update grpo_trainer.py
1 parent e9f9e08 commit b059291

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

swift/trainers/rlhf_trainer/grpo_trainer.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1094,8 +1094,20 @@ def _apply_chat_template_to_messages_list(self, messages_list: InputsType):
10941094
InferRequest.remove_response(messages)
10951095
template_inputs, _ = StdTemplateInputs.from_dict({'messages': messages})
10961096
res_context_list, _, _ = self.template._swift_encode(template_inputs)
1097-
prompts_text.append(''.join(elem for elem in res_context_list if isinstance(elem, str)))
10981097

1098+
# check the type and convert
1099+
processed_context = []
1100+
for context in res_context_list:
1101+
if isinstance(context, str):
1102+
processed_context.append(context)
1103+
elif isinstance(context, list) and all(isinstance(x, int) for x in context):
1104+
# decode the list of token IDs back to text
1105+
decoded_text = self.template.tokenizer.decode(context)
1106+
processed_context.append(decoded_text)
1107+
else:
1108+
# any other type: convert it to str and append it to processed_context
1109+
processed_context.append(str(context))
1110+
prompts_text.append(''.join(processed_context))
10991111
return prompts_text
11001112

11011113
@profiling_decorator

0 commit comments

Comments
 (0)