diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index ae36266e48..cb66b46954 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -20,6 +20,7 @@
 from transformers import StoppingCriteriaList
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.utils import strtobool
+from transformers import PreTrainedTokenizerBase
 
 from swift.utils import get_dist_setting, get_env_args, get_logger, use_torchacc
 from ..utils import Processor, ProcessorMixin
@@ -670,7 +671,8 @@ def _concat_context_list(
             system: Optional[str] = None,
             query: Optional[str] = None,
             response: Optional[str] = None,
-            round0: Optional[int] = None) -> None:
+            round0: Optional[int] = None,
+            tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
         """Concat context list and replace placeholder"""
         round1 = None
         if round0 is not None:
@@ -689,10 +691,12 @@
             if new_str is not None and old_str in context:
                 assert isinstance(new_str, str), f'new_str: {new_str}'
                 context = context.replace(old_str, new_str)
             if len(context) == 0:
                 continue
+            if tokenizer is not None and isinstance(context, list) and isinstance(context[0], int):
+                context = tokenizer.decode(context)
             res_context_list.append(context)
             res_context_type.append(ContextType.OTHER)
 
     def _simplify_context_list(self, context_list: List[Context], loss_scale_list: List[float],
                                inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
@@ -1040,15 +1044,15 @@ def _swift_encode(self, inputs: StdTemplateInputs):
             bos_token = all_tokens[:idx]
             sep_token = all_tokens[idx + 1:]
             if bos_token:
-                res_context_list.append(bos_token)
+                res_context_list.append(self.tokenizer.bos_token)
                 res_context_types.append(ContextType.OTHER)
 
         if self.template_meta.is_post_system or not system:
             prefix = template_meta.prefix
         else:
             prefix = template_meta.system_prefix
-        self._concat_context_list(prefix, res_context_list, res_context_types, system=system)
+        self._concat_context_list(prefix, res_context_list, res_context_types, system=system, tokenizer=self.tokenizer)
 
         n_round = len(inputs.messages) // 2
         for i, (query_message, response_message) in enumerate(zip(inputs.messages[::2], inputs.messages[1::2])):
             query_role, query = query_message['role'], query_message['content']
diff --git a/swift/trainers/rlhf_trainer/grpo_trainer.py b/swift/trainers/rlhf_trainer/grpo_trainer.py
index 4059093b12..fed2e8943d 100644
--- a/swift/trainers/rlhf_trainer/grpo_trainer.py
+++ b/swift/trainers/rlhf_trainer/grpo_trainer.py
@@ -1094,7 +1094,19 @@ def _apply_chat_template_to_messages_list(self, messages_list: InputsType):
             InferRequest.remove_response(messages)
             template_inputs, _ = StdTemplateInputs.from_dict({'messages': messages})
             res_context_list, _, _ = self.template._swift_encode(template_inputs)
-            prompts_text.append(''.join(res_context_list))
+            # Normalize every context entry to a string before joining.
+            processed_context = []
+            for context in res_context_list:
+                if isinstance(context, str):
+                    processed_context.append(context)
+                elif isinstance(context, list) and all(isinstance(x, int) for x in context):
+                    # Decode a list of token ids back into text.
+                    decoded_text = self.template.tokenizer.decode(context)
+                    processed_context.append(decoded_text)
+                else:
+                    # Anything else: fall back to its string representation.
+                    processed_context.append(str(context))
+            prompts_text.append(''.join(processed_context))
         return prompts_text
 
     @profiling_decorator