diff --git a/swift/llm/template/base.py b/swift/llm/template/base.py
index 017674ce6b..5598200ac0 100644
--- a/swift/llm/template/base.py
+++ b/swift/llm/template/base.py
@@ -998,6 +998,8 @@ def _jinja_encode(self, inputs: StdTemplateInputs):
kwargs = {}
if inputs.tools:
kwargs['tools'] = inputs.tools
+ if 'thinking_budget' in inputs.extra_kwargs:
+ kwargs['thinking_budget'] = inputs.extra_kwargs.get('thinking_budget', 0)
text = self.tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=add_generation_prompt, **kwargs)
answer_len = 1 if self.is_training else 0
diff --git a/swift/llm/template/template/seed.py b/swift/llm/template/template/seed.py
index 701d547d63..3c9255ff31 100644
--- a/swift/llm/template/template/seed.py
+++ b/swift/llm/template/template/seed.py
@@ -94,7 +94,7 @@ def insert_budget_markers(text: str, tokenizer, interval: int, total_budget: int
return '\n'.join(result)
else:
return ('The current thinking budget is 0, so I will '
- 'directly start answering the question.\n\n')
+ 'directly start answering the question.\n')
def _prepare_system(self, inputs):
budget = self.get_thinking_budget(inputs)
@@ -143,7 +143,7 @@ def _swift_prepare_inputs(self, inputs: StdTemplateInputs):
message['content'] = (
'The current thinking budget is 0, '
'so I will directly start answering the question.'
- '\n\n') + message['content']
+ '\n') + message['content']
def _simplify_context_list(self, context_list, loss_scale_list, inputs):
res, res_loss_scale = super()._simplify_context_list(context_list, loss_scale_list, inputs)
@@ -154,7 +154,6 @@ def _simplify_context_list(self, context_list, loss_scale_list, inputs):
return res, res_loss_scale
def _jinja_encode(self, inputs: StdTemplateInputs):
- self._prepare_system(inputs)
return super()._jinja_encode(inputs)
diff --git a/swift/plugin/agent_template/__init__.py b/swift/plugin/agent_template/__init__.py
index 86d18e7c6e..797bef64c1 100644
--- a/swift/plugin/agent_template/__init__.py
+++ b/swift/plugin/agent_template/__init__.py
@@ -9,6 +9,7 @@
from .qwen import QwenEnAgentTemplate, QwenEnParallelAgentTemplate, QwenZhAgentTemplate, QwenZhParallelAgentTemplate
from .qwen3_coder import Qwen3CoderAgentTemplate
from .react import ReactEnAgentTemplate, ReactZnAgentTemplate
+from .seed_oss import SeedAgentTemplate
from .toolbench import ToolBenchAgentTemplate
agent_templates = {
@@ -31,6 +32,7 @@
'llama4': Llama4AgentTemplate,
# ref: https://huggingface.co/deepseek-ai/DeepSeek-V3.1
'deepseek_v3_1': DeepSeekV31AgentTemplate,
+ 'seed_oss': SeedAgentTemplate,
# extra
'react_grpo': ReactGRPOAgentTemplate,
'mistral': MistralAgentTemplate
diff --git a/swift/plugin/agent_template/seed_oss.py b/swift/plugin/agent_template/seed_oss.py
new file mode 100644
index 0000000000..97d6891012
--- /dev/null
+++ b/swift/plugin/agent_template/seed_oss.py
@@ -0,0 +1,157 @@
+import re
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
+
+import json
+
+from .base import BaseAgentTemplate
+
+if TYPE_CHECKING:
+ from swift.llm.infer import Function
+ from swift.llm.template import Prompt
+
+
+class SeedAgentTemplate(BaseAgentTemplate):
+ TOOL_CALL_START = '<seed:tool_call>'
+ TOOL_CALL_END = '</seed:tool_call>'
+ FUNCTION_TAG = 'function'
+ PARAMETER_TAG = 'parameter'
+
+ _PY_TYPE_MAPPING = {
+ 'string': 'str',
+ 'number': 'int',
+ 'integer': 'int',
+ 'boolean': 'bool',
+ 'array': 'list',
+ }
+
+ @staticmethod
+ def _py_type(t: str) -> str:
+ return SeedAgentTemplate._PY_TYPE_MAPPING.get(t, 'Any')
+
+ def get_toolcall(self, response: str) -> List['Function']:
+ from swift.llm.infer import Function
+
+ res_list = re.findall(rf'{self.TOOL_CALL_START}(.+?){self.TOOL_CALL_END}', response, re.DOTALL)
+ if not res_list:
+ return super().get_toolcall(response)
+
+ functions = []
+ for res in res_list:
+ func_name_match = re.search(rf'<{self.FUNCTION_TAG}=([^>]+)>', res)
+ if not func_name_match:
+ continue
+
+ func_name = func_name_match.group(1)
+ param_matches = re.findall(rf'<{self.PARAMETER_TAG}=([^>]+)>(.*?)</{self.PARAMETER_TAG}>', res, re.DOTALL)
+ arguments = {name: value for name, value in param_matches}
+ functions.append(Function(name=func_name, arguments=arguments))
+
+ return functions
+
+ def _get_tool_responses(self, tool_messages: List[dict]) -> str:
+ responses = [f"<seed:bos>tool\n{tool_message['content']}<seed:eos>" for tool_message in tool_messages]
+ return ''.join(responses) + '<seed:bos>assistant\n'
+
+ def _format_tool_responses(
+ self,
+ assistant_content: str,
+ tool_messages: List[dict],
+ ) -> Tuple[str, 'Prompt']:
+ with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
+ if with_action:
+ return super()._format_tool_responses(assistant_content, tool_messages)
+
+ formatted_tool_responses = self._get_tool_responses(tool_messages)
+ return assistant_content, ['<seed:eos>', formatted_tool_responses]
+
+ def _build_tool_def_string(self, tool: dict) -> str:
+ """Helper to build a single tool definition string."""
+ func = tool.get('function', {})
+ func_name = func.get('name')
+
+ if not func_name:
+ return ''
+
+ parameters = func.get('parameters', {})
+ properties = parameters.get('properties', {})
+ params = [
+ f"{name}: {self._py_type(spec.get('type', 'any'))}" for name, spec in properties.items()
+ if isinstance(spec, dict)
+ ]
+ param_str = ','.join(params)
+
+ docstring_parts = [' """', f' {func.get("description", "").strip()}']
+
+ if properties:
+ docstring_parts.append('\n Args:')
+ required_params = parameters.get('required', [])
+ for name, spec in properties.items():
+ if isinstance(spec, dict):
+ req_tag = '[必填]' if name in required_params else '[选填]'
+ desc = spec.get('description', '')
+ type_str = self._py_type(spec.get('type', 'any'))
+ docstring_parts.append(f' - {name} ({type_str}) {req_tag}: {desc}')
+
+ returns_props = func.get('returns', {}).get('properties', {})
+ if returns_props:
+ docstring_parts.append('\n Returns:')
+ for name, spec in returns_props.items():
+ desc = spec.get('description', '')
+ type_str = self._py_type(spec.get('type', 'any'))
+ docstring_parts.append(f' - {name} ({type_str}): {desc}')
+
+ docstring_parts.append('\n """')
+ docstring = '\n'.join(docstring_parts)
+
+ return f'Function:\ndef {func_name}({param_str}):\n{docstring}'
+
+ def _format_tools(self, tools: List[Union[str, dict]], system: Optional[str] = None, user_message=None) -> str:
+ if not tools:
+ return system or ''
+
+ tool_defs = [
+ tool_def for tool in tools if (wrapped_tool := self.wrap_tool(tool)).get('type') == 'function' and
+ (tool_def := self._build_tool_def_string(wrapped_tool)) != ''
+ ]
+ tool_defs_joined = '\n\n'.join(tool_defs)
+
+ tool_call_format_instruction = (
+ '工具调用请遵循如下格式:\n'
+ f'{self.TOOL_CALL_START}\n'
+ f'<{self.FUNCTION_TAG}=example_function_name>\n'
+ f'<{self.PARAMETER_TAG}=example_parameter_1>value_1</{self.PARAMETER_TAG}>\n'
+ f'<{self.PARAMETER_TAG}=example_parameter_2>This is the value for the second parameter\n'
+ 'that can span\n'
+ f'multiple lines</{self.PARAMETER_TAG}>\n'
+ f'</{self.FUNCTION_TAG}>\n'
+ f'{self.TOOL_CALL_END}')
+
+ split_token = '<seed:bos>system'
+
+ if system and split_token in system:
+ parts = system.split(split_token, 1)
+ return f'{parts[0]}\n\n{tool_defs_joined}\n{tool_call_format_instruction}\n{split_token}{parts[1]}'
+ else:
+ doubao_prompt = ('You are Doubao, a helpful AI assistant. '
+ 'You may call one or more functions to assist with the user query.')
+ return (f'{doubao_prompt}\n\n{tool_defs_joined}\n{tool_call_format_instruction}\n'
+ f'{split_token}\n{system or ""}')
+
+ def _format_tool_calls(self, tool_call_messages: List[dict]) -> str:
+ formatted_calls = []
+ for message in tool_call_messages:
+ tool_call = self._parse_tool_call(message['content'])
+ func_name = tool_call['name']
+ arguments = tool_call.get('arguments', {})
+
+ call_parts = [f'<{self.FUNCTION_TAG}={func_name}>']
+ for arg_name, arg_value in arguments.items():
+ arg_value_str = arg_value if isinstance(arg_value, str) else json.dumps(arg_value, ensure_ascii=False)
+ call_parts.append(f'<{self.PARAMETER_TAG}={arg_name}>{arg_value_str}</{self.PARAMETER_TAG}>')
+
+ call_parts.append(f'</{self.FUNCTION_TAG}>')
+ call_parts_joined = '\n'.join(call_parts)
+
+ full_call = f'{self.TOOL_CALL_START}\n{call_parts_joined}\n{self.TOOL_CALL_END}'
+ formatted_calls.append(full_call)
+ return '\n'.join(formatted_calls)
diff --git a/swift/plugin/loss_scale/config/ignore_empty_think.json b/swift/plugin/loss_scale/config/ignore_empty_think.json
index 16d08e82ba..70a399affb 100644
--- a/swift/plugin/loss_scale/config/ignore_empty_think.json
+++ b/swift/plugin/loss_scale/config/ignore_empty_think.json
@@ -1,4 +1,4 @@
{
"<think>\\s*</think>\\s*": [0.0],
- "<seed:think>The current thinking budget is 0, so I will directly start answering the question.\n\n</seed:think>\\s*": [0.0]
+ "<seed:think>The current thinking budget is 0, so I will directly start answering the question.\n</seed:think>\\s*": [0.0]
}
diff --git a/tests/test_align/test_template/test_agent.py b/tests/test_align/test_template/test_agent.py
index 0a4b8bca36..9f3d347606 100644
--- a/tests/test_align/test_template/test_agent.py
+++ b/tests/test_align/test_template/test_agent.py
@@ -442,6 +442,75 @@ def test_deepseek_v3_1():
assert encoded['input_ids'][-122:] == encoded2['input_ids'][1:]
+def test_seed_oss():
+ agent_template = agent_templates['seed_oss']()
+
+ engine = PtEngine('ByteDance-Seed/Seed-OSS-36B-Instruct', load_model=False, download_model=False)
+
+ template = engine.default_template
+ template.agent_template = agent_template
+
+ dataset = load_dataset('AI-ModelScope/function-calling-chatml')[0]
+ data = dataset[6]
+ # To test multiple tool calls and responses, we duplicate some messages.
+ data['messages'].insert(1, data['messages'][1])
+ data['messages'].insert(3, data['messages'][3])
+
+ # An incomplete tool-function definition would cause the seed template to raise an error.
+ data['tools'] = [('{\n'
+ ' "name": "convert_temperature",\n'
+ ' "description": "Convert temperature from one unit to another",\n'
+ ' "parameters": {\n'
+ ' "type": "object",\n'
+ ' "properties": {\n'
+ ' "temperature": {\n'
+ ' "type": "number",\n'
+ ' "description": "The temperature value"\n'
+ ' },\n'
+ ' "from_unit": {\n'
+ ' "type": "string",\n'
+ ' "description": "The unit to convert from"\n'
+ ' },\n'
+ ' "to_unit": {\n'
+ ' "type": "string",\n'
+ ' "description": "The unit to convert to"\n'
+ ' }\n'
+ ' },\n'
+ ' "required": [\n'
+ ' "temperature",\n'
+ ' "from_unit",\n'
+ ' "to_unit"\n'
+ ' ]\n'
+ ' }\n'
+ '}'),
+ ('{\n'
+ ' "name": "get_current_date",\n'
+ ' "description": "Get the current date",\n'
+ ' "parameters": {\n'
+ ' "type": "object",\n'
+ ' "properties": {\n'
+ ' "date": {\n'
+ ' "type": "number",\n'
+ ' "description": "The date value"}}}\n'
+ '}')]
+
+ data['thinking_budget'] = 0
+
+ template.template_backend = 'swift'
+ template.set_mode('train')
+ encoded = template.encode(data)
+ print(f'input_ids: {template.safe_decode(encoded["input_ids"])}')
+ print(f'labels: {template.safe_decode(encoded["labels"])}')
+ import re
+ expected_input_ids = re.sub(
+ r'<seed:think>.*?</seed:think>', '', template.safe_decode(encoded['input_ids']), flags=re.DOTALL)
+ template.template_backend = 'jinja'
+ encoded2 = template.encode(data)
+ print(f'input_ids: {template.safe_decode(encoded2["input_ids"])}')
+ print(f'labels: {template.safe_decode(encoded2["labels"])}')
+ assert template.safe_decode(encoded2['input_ids']) == expected_input_ids
+
+
if __name__ == '__main__':
from swift.plugin import agent_templates
from swift.llm import PtEngine, InferRequest, RequestConfig, load_dataset
@@ -460,4 +529,5 @@ def test_deepseek_v3_1():
# test_hunyuan()
# test_glm4_5()
# test_qwen3_coder()
- test_deepseek_v3_1()
+ # test_deepseek_v3_1()
+ test_seed_oss()