Update template_meta.prefix bug #4813

Status: Open, wants to merge 1 commit into main

20 changes: 14 additions & 6 deletions swift/llm/template/base.py
@@ -20,6 +20,7 @@
from transformers import StoppingCriteriaList
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.utils import strtobool
from transformers import PreTrainedTokenizerBase

from swift.utils import get_dist_setting, get_env_args, get_logger, use_torchacc
from ..utils import Processor, ProcessorMixin
@@ -670,7 +671,9 @@ def _concat_context_list(
system: Optional[str] = None,
query: Optional[str] = None,
response: Optional[str] = None,
round0: Optional[int] = None) -> None:
round0: Optional[int] = None,
tokenizer: Optional[PreTrainedTokenizerBase] = None
) -> None:
"""Concat context list and replace placeholder"""
round1 = None
if round0 is not None:
@@ -689,10 +692,14 @@
if new_str is not None and old_str in context:
assert isinstance(new_str, str), f'new_str: {new_str}'
context = context.replace(old_str, new_str)
res_context_list.append(context)
res_context_type.append(ContextType.OTHER)
if len(context) == 0:
continue
res_context_list.append(context)
res_context_type.append(ContextType.OTHER)
if isinstance(context, list) and isinstance(context[0], int):
context = tokenizer.decode(context)
Collaborator: Decode is not valid here. The special tokens of a few models have no valid string outputs, or the outputs may be empty.
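
A minimal way to check this behaviour for a given tokenizer is sketched below (illustrative only, not part of the PR; the checkpoint name is an assumption):

# Illustrative sketch: inspect how a tokenizer round-trips its special tokens.
# 'Qwen/Qwen2-7B-Instruct' is only an example checkpoint, not one named in the PR.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2-7B-Instruct')
for tok_id in tokenizer.all_special_ids:
    text = tokenizer.decode([tok_id])
    re_encoded = tokenizer.encode(text, add_special_tokens=False)
    # For some tokenizers the decoded text is empty or does not re-encode to the
    # same id, so round-tripping token ids through decode()/encode() can be lossy.
    print(tok_id, repr(text), re_encoded)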

Collaborator: The res_context_list is a mix of strings and lists of integers.
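
For reference, a hypothetical example of such a mixed list (contents are assumed for illustration, not taken from the codebase):

# Illustrative only: res_context_list can hold plain-text contexts alongside
# lists of token ids produced earlier in the pipeline.
res_context_list = [
    '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n',  # str context
    [151644, 151645],                                                # token ids
    '<|im_start|>user\n{{QUERY}}<|im_end|>\n',                       # str context
]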

Contributor Author:

“The res_context_list is a mix of strings and list of integers” raises two problems (a minimal reproduction follows the list):

  1. If an element of res_context_list is a list of integers, prompts_text.append(''.join(res_context_list)) raises an exception, because str.join only accepts strings.
  2. Even if such an element were merged into the str prompt, encode would either raise an exception or produce an incorrect result, because the integers are already token ids, not text.
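
A minimal reproduction of point 1, with an assumed mixed list (values are illustrative):

# Assumed example: one str context and one list of token ids in the same list.
res_context_list = ['<|im_start|>user\nhello<|im_end|>\n', [151644, 151645]]

prompts_text = []
# str.join only accepts an iterable of strings, so a list-of-int element raises:
prompts_text.append(''.join(res_context_list))
# TypeError: sequence item 1: expected str instance, list found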

Contributor Author (@aacedar, Jul 3, 2025):

“Decode is not valid here. The special tokens of a few models have no valid string outputs, or the outputs may be empty”

If a model has no valid string output for its special tokens, that just means the prompt does not need those special tokens, which is also fine. Maybe we can modify the code to:

if isinstance(context, list) and isinstance(context[0], int):
    context = tokenizer.decode(context)
    if context:
        res_context_list.append(context)
        res_context_type.append(ContextType.OTHER)

res_context_list.append(context)
res_context_type.append(ContextType.OTHER)

def _simplify_context_list(self, context_list: List[Context], loss_scale_list: List[float],
inputs: StdTemplateInputs) -> Tuple[List[Context], List[float]]:
@@ -1040,15 +1047,16 @@ def _swift_encode(self, inputs: StdTemplateInputs):
bos_token = all_tokens[:idx]
sep_token = all_tokens[idx + 1:]
if bos_token:
res_context_list.append(bos_token)
res_context_list.append(self.tokenizer.bos_token)
# res_context_list.append(bos_token)
res_context_types.append(ContextType.OTHER)

if self.template_meta.is_post_system or not system:
prefix = template_meta.prefix
else:
prefix = template_meta.system_prefix
self._concat_context_list(prefix, res_context_list, res_context_types, system=system)

self._concat_context_list(prefix, res_context_list, res_context_types, system=system, tokenizer=self.tokenizer)
n_round = len(inputs.messages) // 2
for i, (query_message, response_message) in enumerate(zip(inputs.messages[::2], inputs.messages[1::2])):
query_role, query = query_message['role'], query_message['content']