Skip to content

Aacedar patch 3 #4832

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from transformers import StoppingCriteriaList
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import strtobool
from transformers import PreTrainedTokenizerBase

from swift.utils import get_dist_setting, get_env_args, get_logger, use_torchacc
from ..utils import Processor, ProcessorMixin
Expand Down Expand Up @@ -670,7 +671,9 @@ def _concat_context_list(
system: Optional[str] = None,
query: Optional[str] = None,
response: Optional[str] = None,
round0: Optional[int] = None) -> None:
round0: Optional[int] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None
) -> None:
"""Concat context list and replace placeholder"""
round1 = None
if round0 is not None:
Expand All @@ -689,10 +692,14 @@ def _concat_context_list(
if new_str is not None and old_str in context:
assert isinstance(new_str, str), f'new_str: {new_str}'
context = context.replace(old_str, new_str)
res_context_list.append(context)
res_context_type.append(ContextType.OTHER)
if len(context) == 0:
continue
res_context_list.append(context)
res_context_type.append(ContextType.OTHER)
if isinstance(context, list) and isinstance(context[0], int):
context = tokenizer.decode(context)
res_context_list.append(context)
res_context_type.append(ContextType.OTHER)

def _simplify_context_list(self, context_list: List[Context], loss_scale_list: List[float],
inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
Expand Down Expand Up @@ -1040,15 +1047,16 @@ def _swift_encode(self, inputs: StdTemplateInputs):
bos_token = all_tokens[:idx]
sep_token = all_tokens[idx + 1:]
if bos_token:
res_context_list.append(bos_token)
res_context_list.append(self.tokenizer.bos_token)
# res_context_list.append(bos_token)
res_context_types.append(ContextType.OTHER)

if self.template_meta.is_post_system or not system:
prefix = template_meta.prefix
else:
prefix = template_meta.system_prefix
self._concat_context_list(prefix, res_context_list, res_context_types, system=system)

self._concat_context_list(prefix, res_context_list, res_context_types, system=system, tokenizer=self.tokenizer)
n_round = len(inputs.messages) // 2
for i, (query_message, response_message) in enumerate(zip(inputs.messages[::2], inputs.messages[1::2])):
query_role, query = query_message['role'], query_message['content']
Expand Down
14 changes: 13 additions & 1 deletion swift/trainers/rlhf_trainer/grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,7 +1094,19 @@ def _apply_chat_template_to_messages_list(self, messages_list: InputsType):
InferRequest.remove_response(messages)
template_inputs, _ = StdTemplateInputs.from_dict({'messages': messages})
res_context_list, _, _ = self.template._swift_encode(template_inputs)
prompts_text.append(''.join(res_context_list))
# Type checking and conversion
processed_context = []
for context in res_context_list:
if isinstance(context, str):
processed_context.append(context)
elif isinstance(context, list) and all(isinstance(x, int) for x in context):
# Decode the list of token IDs into a string
decoded_text = self.template.tokenizer.decode(context)
processed_context.append(decoded_text)
else:
# Any other type: convert it to a string
processed_context.append(str(context))
prompts_text.append(''.join(processed_context))
return prompts_text

@profiling_decorator
Expand Down