Description
Env
python: 3.11.2
torch: 2.7.1+cu126 cuda: 12.6 is_cuda: True
transformers: 4.55.0
checkpoint: openbmb/MiniCPM4-0.5B
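
For reference, the version numbers above can be regenerated with a short snippet like this (a sketch, not part of the original report; field names simply mirror the list above):

```python
import platform

import torch
import transformers

print("python:", platform.python_version())
print("torch:", torch.__version__, "cuda:", torch.version.cuda, "is_cuda:", torch.cuda.is_available())
print("transformers:", transformers.__version__)
```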
Minimal reproduction script
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Minimal ChatML-style chat template that injects the tool signatures into the
# system turn and asks for <tool_call>-wrapped JSON.
TOOL_JINJA = r"""
{%- if tools %}
{{- '<|im_start|>system\n' -}}
{{- "# Tools\n\n" -}}
{{- "You may call one or more functions to assist with the user query.\n" -}}
{{- "You are provided with function signatures within <tools></tools> XML tags:\n<tools>" -}}
{%- for tool in tools %}
{{- "\n" + (tool | tojson) -}}
{%- endfor %}
{{- "\n</tools>\n\n" -}}
{{- "For each function call, return ONLY a json object within <tool_call></tool_call>:\n" -}}
{{- "<tool_call>\n{\"name\": \"...\", \"arguments\": {...}}\n</tool_call>" -}}
{{- '<|im_end|>\n' -}}
{%- endif %}
{%- for m in messages %}
{{- '<|im_start|>' + m['role'] + '\n' + m['content'] + '<|im_end|>\n' -}}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' -}}
{%- endif %}
""".strip("\n")

tools = [
    {"type": "function", "function": {"name": "SocialCommunication.InstantMessaging.getGroupMembers", "description": "查看指定群组的成员列表", "parameters": {"type": "object", "properties": {"group_name": {"type": "string", "description": "群组的名称"}}, "required": ["group_name"]}}},
    {"type": "function", "function": {"name": "Navigation.TrafficViolations.reportAccident", "description": "上报交通事故及违章信息", "parameters": {"type": "object", "properties": {"plate_number": {"type": "string", "description": "车牌号"}, "location": {"type": "string", "description": "事故地点"}}, "required": ["plate_number", "location"]}}},
    {"type": "function", "function": {"name": "SocialCommunication.InstantMessaging.sendSystemMessage", "description": "向指定号码或者联系人发送系统短信", "parameters": {"type": "object", "properties": {"receiver_number": {"type": "string", "description": "接收短信的手机号码"}, "receiver_name": {"type": "string", "description": "接收短信的用户名"}, "message_content": {"type": "string", "description": "要发送的短信文本内容"}, "attachment_path": {"type": "string", "description": "图片或者文件地址"}}, "required": ["receiver_number", "receiver_name", "message_content", "attachment_path"]}}},
]

samples = [
    [{"role": "user", "content": "EK:尊敬的车主您好,\n\n关于您名下的车辆(车牌号:京C98765),我们在最近的常规检查中发现该车辆的年度检验即将到期。为了确保您的行车安全并避免因逾期未检带来的罚款或其他不便,我们强烈建议您尽快安排时间进行车辆年检。若有任何疑问或需要预约服务,欢迎随时联系我们,我们将竭诚为您服务。\n\n祝您行车愉快,安全每一天!\n\n敬礼\nXX汽车服务中心\n2023年10月10日\nquery:我在长安街遇到了一个单车事故,车子损坏得很严重,车牌号是李华发给我的邮件里提到的,需要上报。"}],
    [{"role": "user", "content": "查询群组人员"}],
]

def has_tool_call(text: str) -> bool:
    return "<tool_call>" in text and "</tool_call>" in text

def run_one(model, tokenizer, messages, do_sample, temperature, top_p, max_new_tokens, seed):
    # Seed both CPU and CUDA RNGs so sampled runs are reproducible.
    if seed is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
    prompt = tokenizer.apply_chat_template(
        messages,
        tools=tools,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens.
    gen_ids = out[0][inputs["input_ids"].shape[-1]:]
    gen_text = tokenizer.decode(gen_ids, skip_special_tokens=False)
    return prompt, gen_text, has_tool_call(gen_text)

def main():
    checkpoint = "openbmb/MiniCPM4-0.5B"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
    tokenizer.chat_template = TOOL_JINJA
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=dtype,
        trust_remote_code=True,
    ).to(device)
    model.eval()
    settings = [
        dict(name="greedy", do_sample=False, temperature=1.0, top_p=1.0),
        dict(name="sample", do_sample=True, temperature=0.7, top_p=0.7),
    ]
    for cfg in settings:
        print(f"\n[DECODE MODE] {cfg['name']}", "=" * 80)
        for i, messages in enumerate(samples, 1):
            prompt, gen_text, ok = run_one(
                model=model,
                tokenizer=tokenizer,
                messages=messages,
                do_sample=cfg["do_sample"],
                temperature=cfg["temperature"],
                top_p=cfg["top_p"],
                max_new_tokens=256,
                seed=42,
            )
            print(f"\n--- Sample {i} | has_tool_call={ok} ---")
            # print("[PROMPT]")
            # print(prompt)
            # print("[GEN]")
            print(gen_text)

if __name__ == "__main__":
    main()
```
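
For context, this is what `apply_chat_template` produces for sample 2 with the template above. I rendered it by hand from the Jinja (the three tool JSON objects are elided with `…` for brevity), so treat it as illustrative rather than a captured log:

```
<|im_start|>system
# Tools

You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type": "function", "function": {"name": "SocialCommunication.InstantMessaging.getGroupMembers", …}}
{"type": "function", "function": {"name": "Navigation.TrafficViolations.reportAccident", …}}
{"type": "function", "function": {"name": "SocialCommunication.InstantMessaging.sendSystemMessage", …}}
</tools>

For each function call, return ONLY a json object within <tool_call></tool_call>:
<tool_call>
{"name": "...", "arguments": {...}}
</tool_call><|im_end|>
<|im_start|>user
查询群组人员<|im_end|>
<|im_start|>assistant
```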
Case Explanation

I'm testing the tool-calling ability of openbmb/MiniCPM4-0.5B and expect the model to emit a <tool_call> block containing the correct function call. For example, sample 1 should produce:
```
<tool_call>
{"name": "Navigation.TrafficViolations.reportAccident", "arguments": {"plate_number": "京C98765", "location": "长安街"}}
</tool_call>
```

But it outputs natural language without a function call:
```
[DECODE MODE] greedy ================================================================================

--- Sample 1 | has_tool_call=False ---
对不起,我无法提供关于李华发邮件中提到的车辆损坏事故的任何信息。请提供更多关于事故的详细信息,例如车牌号、事故地点和事故照片,以便我们能够尽快处理这个问题。<|im_end|>

--- Sample 2 | has_tool_call=False ---
请提供群组的名称,以便我为您查询成员列表。<|im_end|>

[DECODE MODE] sample ================================================================================

--- Sample 1 | has_tool_call=False ---
对不起,我无法帮助您完成这个任务。因为我没有收到关于李华发提供的车牌号和事故信息的邮件。如果您能提供更多信息,我将很乐意协助您完成上报交通事故及违章信息的任务。<|im_end|>

--- Sample 2 | has_tool_call=False ---
请提供群组的名称,以便我为您查询成员列表。<|im_end|>
```

In every case, under both greedy decoding and sampling, the model apologizes or asks for more details in plain Chinese (e.g., "Please provide the group name so I can look up the member list for you") instead of emitting a <tool_call>.
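
The has_tool_call check in the repro only tests for the presence of the tag pair. For anyone triaging this, a slightly stricter checker can also parse the payload and validate the function name against the registered tools. Below is a minimal sketch (the helper parse_tool_calls is my own, not part of the repro), assuming one JSON object per <tool_call> block as the template requests:

```python
import json
import re

# Non-greedy match for a JSON object between the tool-call tags.
TOOL_CALL_RE = re.compile(r"<tool_call>\s*(\{.*?\})\s*</tool_call>", re.DOTALL)

def parse_tool_calls(text: str, tools: list) -> list:
    """Extract each <tool_call> payload, parse it as JSON, and keep only
    calls whose name matches a registered tool and whose arguments are a
    JSON object. Returns an empty list when nothing valid is found."""
    known = {t["function"]["name"] for t in tools}
    calls = []
    for raw in TOOL_CALL_RE.findall(text):
        try:
            call = json.loads(raw)
        except json.JSONDecodeError:
            continue
        if call.get("name") in known and isinstance(call.get("arguments"), dict):
            calls.append(call)
    return calls
```

On the expected output for sample 1 this returns the single reportAccident call; on the generations shown above it returns an empty list, consistent with has_tool_call=False:

```python
expected = ('<tool_call>\n{"name": "Navigation.TrafficViolations.reportAccident", '
            '"arguments": {"plate_number": "京C98765", "location": "长安街"}}\n</tool_call>')
assert parse_tool_calls(expected, tools) == [
    {"name": "Navigation.TrafficViolations.reportAccident",
     "arguments": {"plate_number": "京C98765", "location": "长安街"}}
]
```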