Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions swift/llm/template/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -998,6 +998,8 @@ def _jinja_encode(self, inputs: StdTemplateInputs):
kwargs = {}
if inputs.tools:
kwargs['tools'] = inputs.tools
if 'thinking_budget' in inputs.extra_kwargs:
kwargs['thinking_budget'] = inputs.extra_kwargs.get('thinking_budget', 0)
text = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=add_generation_prompt, **kwargs)
answer_len = 1 if self.is_training else 0
Expand Down
5 changes: 2 additions & 3 deletions swift/llm/template/template/seed.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def insert_budget_markers(text: str, tokenizer, interval: int, total_budget: int
return '\n'.join(result)
else:
return ('<seed:cot_budget_reflect>The current thinking budget is 0, so I will '
'directly start answering the question.</seed:cot_budget_reflect>\n\n')
'directly start answering the question.</seed:cot_budget_reflect>\n')

def _prepare_system(self, inputs):
budget = self.get_thinking_budget(inputs)
Expand Down Expand Up @@ -143,7 +143,7 @@ def _swift_prepare_inputs(self, inputs: StdTemplateInputs):
message['content'] = (
'<seed:think><seed:cot_budget_reflect>The current thinking budget is 0, '
'so I will directly start answering the question.'
'</seed:cot_budget_reflect>\n\n</seed:think>') + message['content']
'</seed:cot_budget_reflect>\n</seed:think>') + message['content']

def _simplify_context_list(self, context_list, loss_scale_list, inputs):
res, res_loss_scale = super()._simplify_context_list(context_list, loss_scale_list, inputs)
Expand All @@ -154,7 +154,6 @@ def _simplify_context_list(self, context_list, loss_scale_list, inputs):
return res, res_loss_scale

def _jinja_encode(self, inputs: StdTemplateInputs):
self._prepare_system(inputs)
return super()._jinja_encode(inputs)


Expand Down
2 changes: 2 additions & 0 deletions swift/plugin/agent_template/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .qwen import QwenEnAgentTemplate, QwenEnParallelAgentTemplate, QwenZhAgentTemplate, QwenZhParallelAgentTemplate
from .qwen3_coder import Qwen3CoderAgentTemplate
from .react import ReactEnAgentTemplate, ReactZnAgentTemplate
from .seed_oss import SeedAgentTemplate
from .toolbench import ToolBenchAgentTemplate

agent_templates = {
Expand All @@ -31,6 +32,7 @@
'llama4': Llama4AgentTemplate,
# ref: https://huggingface.co/deepseek-ai/DeepSeek-V3.1
'deepseek_v3_1': DeepSeekV31AgentTemplate,
'seed_oss': SeedAgentTemplate,
# extra
'react_grpo': ReactGRPOAgentTemplate,
'mistral': MistralAgentTemplate
Expand Down
157 changes: 157 additions & 0 deletions swift/plugin/agent_template/seed_oss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import re
from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union

import json

from .base import BaseAgentTemplate

if TYPE_CHECKING:
from swift.llm.infer import Function
from swift.llm.template import Prompt


class SeedAgentTemplate(BaseAgentTemplate):
# Markers that open and close a single tool call in the model's text; they are
# matched in `get_toolcall` and emitted in `_format_tool_calls`.
TOOL_CALL_START = '<seed:tool_call>'
TOOL_CALL_END = '</seed:tool_call>'
# Tag names used as `<function=NAME>` and `<parameter=NAME>` inside a tool call.
FUNCTION_TAG = 'function'
PARAMETER_TAG = 'parameter'

# JSON-schema type name -> Python type name shown in generated tool signatures.
# `_py_type` returns 'Any' for any name that is missing here. Note that
# 'number' is rendered as 'int', exactly like 'integer'.
_PY_TYPE_MAPPING = {
'string': 'str',
'number': 'int',
'integer': 'int',
'boolean': 'bool',
'array': 'list',
}

@staticmethod
def _py_type(t: str) -> str:
return SeedAgentTemplate._PY_TYPE_MAPPING.get(t, 'Any')

def get_toolcall(self, response: str) -> List['Function']:
    """Parse the tool calls contained in ``response``.

    Every span between ``TOOL_CALL_START`` and ``TOOL_CALL_END`` is scanned for
    a ``<function=NAME>`` tag and for ``<parameter=NAME>value</parameter>``
    pairs. Parameter values are kept as the raw matched text.

    Args:
        response: Text to scan.

    Returns:
        One ``Function`` per tool-call span that contains a function tag. If
        ``response`` contains no tool-call span at all, the result of the
        parent class's ``get_toolcall`` is returned instead.
    """
    from swift.llm.infer import Function

    call_bodies = re.findall(rf'{self.TOOL_CALL_START}(.+?){self.TOOL_CALL_END}', response, re.DOTALL)
    if not call_bodies:
        return super().get_toolcall(response)

    name_pattern = rf'<{self.FUNCTION_TAG}=([^>]+)>'
    param_pattern = rf'<{self.PARAMETER_TAG}=([^>]+)>(.*?)</{self.PARAMETER_TAG}>'

    parsed_calls = []
    for body in call_bodies:
        name_match = re.search(name_pattern, body)
        if name_match is None:
            # A span without a function tag is dropped.
            continue
        # findall yields (name, value) pairs; a repeated name keeps its last value.
        arguments = dict(re.findall(param_pattern, body, re.DOTALL))
        parsed_calls.append(Function(name=name_match.group(1), arguments=arguments))
    return parsed_calls

def _get_tool_responses(self, tool_messages: List[dict]) -> str:
responses = [f"<seed:bos>tool\n{tool_message['content']}<seed:eos>" for tool_message in tool_messages]
return ''.join(responses) + '<seed:bos>assistant\n'

def _format_tool_responses(
    self,
    assistant_content: str,
    tool_messages: List[dict],
) -> Tuple[str, 'Prompt']:
    """Format tool results that follow an assistant message.

    Args:
        assistant_content: Text of the assistant message that preceded the
            tool messages.
        tool_messages: The tool result messages.

    Returns:
        When ``assistant_content`` contains both ``self.keyword.action`` and
        ``self.keyword.action_input``, whatever the parent class's
        ``_format_tool_responses`` returns. Otherwise the unchanged
        ``assistant_content`` and a two-item list: ``'<seed:eos>'`` followed
        by the output of ``_get_tool_responses``.
    """
    keyword = self.keyword
    if keyword.action in assistant_content and keyword.action_input in assistant_content:
        # Both keywords present: delegate to the parent implementation.
        return super()._format_tool_responses(assistant_content, tool_messages)

    return assistant_content, ['<seed:eos>', self._get_tool_responses(tool_messages)]

def _build_tool_def_string(self, tool: dict) -> str:
"""Helper to build a single tool definition string."""
func = tool.get('function', {})
func_name = func.get('name')

if not func_name:
return ''

parameters = func.get('parameters', {})
properties = parameters.get('properties', {})
params = [
f"{name}: {self._py_type(spec.get('type', 'any'))}" for name, spec in properties.items()
if isinstance(spec, dict)
]
param_str = ','.join(params)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

For better readability and to follow PEP 8 conventions, it's good practice to add a space after the comma when joining strings.

Suggested change
param_str = ','.join(params)
param_str = ', '.join(params)


docstring_parts = [' """', f' {func.get("description", "").strip()}']

if properties:
docstring_parts.append('\n Args:')
required_params = parameters.get('required', [])
for name, spec in properties.items():
if isinstance(spec, dict):
req_tag = '[必填]' if name in required_params else '[选填]'
desc = spec.get('description', '')
type_str = self._py_type(spec.get('type', 'any'))
docstring_parts.append(f' - {name} ({type_str}) {req_tag}: {desc}')

returns_props = func.get('returns', {}).get('properties', {})
if returns_props:
docstring_parts.append('\n Returns:')
for name, spec in returns_props.items():
desc = spec.get('description', '')
type_str = self._py_type(spec.get('type', 'any'))
docstring_parts.append(f' - {name} ({type_str}): {desc}')

docstring_parts.append('\n """')
docstring = '\n'.join(docstring_parts)

return f'Function:\ndef {func_name}({param_str}):\n{docstring}'

def _format_tools(self, tools: List[Union[str, dict]], system: Optional[str] = None, user_message=None) -> str:
    """Build the system text that advertises ``tools`` and the tool-call format.

    Args:
        tools: Tool descriptions. Each is passed through ``self.wrap_tool``;
            only results whose ``'type'`` is ``'function'`` and that
            ``_build_tool_def_string`` renders to a non-empty string are kept.
        system: Optional existing system text.
        user_message: Not used by this method.

    Returns:
        ``system`` (or ``''``) when ``tools`` is empty. Otherwise the tool
        definitions and the call-format instruction, inserted before the
        first ``<seed:eos><seed:bos>system`` marker found in ``system``, or,
        when there is no such marker, placed after a default intro and
        followed by that marker and ``system``.
    """
    if not tools:
        return system or ''

    rendered_defs = []
    for tool in tools:
        wrapped = self.wrap_tool(tool)
        if wrapped.get('type') != 'function':
            continue
        rendered = self._build_tool_def_string(wrapped)
        if rendered != '':
            rendered_defs.append(rendered)
    tools_section = '\n\n'.join(rendered_defs)

    usage_hint = (
        '工具调用请遵循如下格式:\n'
        f'{self.TOOL_CALL_START}\n'
        f'<{self.FUNCTION_TAG}=example_function_name>\n'
        f'<{self.PARAMETER_TAG}=example_parameter_1>value_1</{self.PARAMETER_TAG}>\n'
        f'<{self.PARAMETER_TAG}=example_parameter_2>This is the value for the second parameter\n'
        'that can span\n'
        f'multiple lines</{self.PARAMETER_TAG}>\n'
        f'</{self.FUNCTION_TAG}>\n'
        f'{self.TOOL_CALL_END}')

    split_token = '<seed:eos><seed:bos>system'

    if system and split_token in system:
        # Only the first marker is used as the insertion point.
        head, tail = system.split(split_token, 1)
        return f'{head}\n\n{tools_section}\n{usage_hint}\n{split_token}{tail}'

    default_intro = ('You are Doubao, a helpful AI assistant. '
                     'You may call one or more functions to assist with the user query.')
    return (f'{default_intro}\n\n{tools_section}\n{usage_hint}\n'
            f'{split_token}\n{system or ""}')

def _format_tool_calls(self, tool_call_messages: List[dict]) -> str:
    """Serialize tool-call messages into ``<seed:tool_call>`` blocks.

    Args:
        tool_call_messages: Messages whose ``'content'`` is handed to
            ``self._parse_tool_call``; the parsed result must provide
            ``'name'`` and may provide ``'arguments'``.

    Returns:
        One block per message, joined by newlines. Each block lists a
        ``<function=NAME>`` tag, one ``<parameter=KEY>VALUE</parameter>`` line
        per argument, and the closing tags. String values are written as-is;
        other values are JSON-encoded with ``ensure_ascii=False``.
    """
    blocks = []
    for message in tool_call_messages:
        call = self._parse_tool_call(message['content'])
        func_name = call['name']

        lines = [self.TOOL_CALL_START, f'<{self.FUNCTION_TAG}={func_name}>']
        for key, value in call.get('arguments', {}).items():
            if isinstance(value, str):
                text = value
            else:
                text = json.dumps(value, ensure_ascii=False)
            lines.append(f'<{self.PARAMETER_TAG}={key}>{text}</{self.PARAMETER_TAG}>')
        lines.append(f'</{self.FUNCTION_TAG}>')
        lines.append(self.TOOL_CALL_END)

        blocks.append('\n'.join(lines))
    return '\n'.join(blocks)
2 changes: 1 addition & 1 deletion swift/plugin/loss_scale/config/ignore_empty_think.json
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"<think>\\s*</think>\\s*": [0.0],
"<seed:think><seed:cot_budget_reflect>The current thinking budget is 0, so I will directly start answering the question.</seed:cot_budget_reflect>\n\n</seed:think>\\s*": [0.0]
"<seed:think><seed:cot_budget_reflect>The current thinking budget is 0, so I will directly start answering the question.</seed:cot_budget_reflect>\n</seed:think>\\s*": [0.0]
}
72 changes: 71 additions & 1 deletion tests/test_align/test_template/test_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,75 @@ def test_deepseek_v3_1():
assert encoded['input_ids'][-122:] == encoded2['input_ids'][1:]


def test_seed_oss():
agent_template = agent_templates['seed_oss']()

engine = PtEngine('ByteDance-Seed/Seed-OSS-36B-Instruct', load_model=False, download_model=False)

template = engine.default_template
template.agent_template = agent_template

dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
data = dataset[6]
# To test multiple tool calls and responses, we duplicate some messages.
data['messages'].insert(1, data['messages'][1])
data['messages'].insert(3, data['messages'][3])

# Incomplete tool function will cause seed template to throw an error.
data['tools'] = [('{\n'
' "name": "convert_temperature",\n'
' "description": "Convert temperature from one unit to another",\n'
' "parameters": {\n'
' "type": "object",\n'
' "properties": {\n'
' "temperature": {\n'
' "type": "number",\n'
' "description": "The temperature value"\n'
' },\n'
' "from_unit": {\n'
' "type": "string",\n'
' "description": "The unit to convert from"\n'
' },\n'
' "to_unit": {\n'
' "type": "string",\n'
' "description": "The unit to convert to"\n'
' }\n'
' },\n'
' "required": [\n'
' "temperature",\n'
' "from_unit",\n'
' "to_unit"\n'
' ]\n'
' }\n'
'}'),
('{\n'
' "name": "get_current_date",\n'
' "description": "Get the current date",\n'
' "parameters": {\n'
' "type": "object",\n'
' "properties": {\n'
' "date": {\n'
' "type": "number",\n'
' "description": "The date value"}}}\n'
'}')]

data['thinking_budget'] = 0

template.template_backend = 'swift'
template.set_mode('train')
encoded = template.encode(data)
print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
print(f'labels: {template.safe_decode(encoded["labels"])}')
import re
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

According to PEP 8, imports should usually be at the top of the file, not inside a function. Please move import re to the top of the file with other imports.

expected_input_ids = re.sub(
r'<seed:think>.*?</seed:think>', '', template.safe_decode(encoded['input_ids']), flags=re.DOTALL)
template.template_backend = 'jinja'
encoded2 = template.encode(data)
print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}')
print(f'labels: {template.safe_decode(encoded2["labels"])}')
assert template.safe_decode(encoded2['input_ids']) == expected_input_ids


if __name__ == '__main__':
from swift.plugin import agent_templates
from swift.llm import PtEngine, InferRequest, RequestConfig, load_dataset
Expand All @@ -460,4 +529,5 @@ def test_deepseek_v3_1():
# test_hunyuan()
# test_glm4_5()
# test_qwen3_coder()
test_deepseek_v3_1()
# test_deepseek_v3_1()
test_seed_oss()
Loading