| 1 | +import json |
| 2 | +import re |
| 3 | +from typing import Any |
| 4 | + |
| 5 | +from bfcl_eval.model_handler.local_inference.base_oss_handler import OSSHandler |
| 6 | +from bfcl_eval.model_handler.utils import convert_to_function_call |
| 7 | +from overrides import override |
| 8 | + |
| 9 | + |
| 10 | +class PelicanVLFCHandler(OSSHandler): |
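| | +    """BFCL handler for Pelican-VL models served locally in function-calling (FC) mode. |
| | + |
| | +    Prompts are rendered with ChatML-style <|im_start|>/<|im_end|> markers, tool |
| | +    definitions are placed in the system turn inside <tools></tools> tags, and tool |
| | +    calls are parsed back out of <tool_call></tool_call> blocks in the completion. |
| | +    """ |
| | + |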
| 11 | + def __init__( |
| 12 | + self, |
| 13 | + model_name, |
| 14 | + temperature, |
| 15 | + registry_name, |
| 16 | + is_fc_model, |
| 17 | + dtype="bfloat16", |
| 18 | + **kwargs, |
| 19 | + ) -> None: |
| 20 | + super().__init__(model_name, temperature, registry_name, is_fc_model, **kwargs) |
| 21 | + # Pelican FC handler expects FC behavior |
| 22 | + self.is_fc_model = True |
| 23 | +        # Pelican model names on Hugging Face may be the base name without the "-FC" suffix |
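| | +        # e.g. a registry name like "Pelican-VL-7B-FC" (hypothetical) would map to "Pelican-VL-7B" on the Hub |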
| 24 | + self.model_name_huggingface = model_name.replace("-FC", "") |
| 25 | + |
| 26 | + @override |
| 27 | + def decode_ast(self, result, language, has_tool_call_tag): |
| 28 | + # Model response is of the form: |
| 29 | + # "<tool_call>\n{\"name\": \"spotify.play\", \"arguments\": {\"artist\": \"Taylor Swift\", \"duration\": 20}}\n</tool_call>\n<tool_call>\n{\"name\": \"spotify.play\", \"arguments\": {\"artist\": \"Maroon 5\", \"duration\": 15}}\n</tool_call>" |
| 30 | + tool_calls = self._extract_tool_calls(result) |
| 31 | +        if not isinstance(tool_calls, list) or any(not isinstance(item, dict) for item in tool_calls): |
| 32 | + raise ValueError(f"Model did not return a list of function calls: {result}") |
| 33 | + return [ |
| 34 | + {call["name"]: {k: v for k, v in call["arguments"].items()}} |
| 35 | + for call in tool_calls |
| 36 | + ] |
| 37 | + |
| 38 | + @override |
| 39 | + def decode_execute(self, result, has_tool_call_tag): |
| 40 | + tool_calls = self._extract_tool_calls(result) |
| 41 | +        if not isinstance(tool_calls, list) or any(not isinstance(item, dict) for item in tool_calls): |
| 42 | + raise ValueError(f"Model did not return a list of function calls: {result}") |
| 43 | + decoded_result = [] |
| 44 | + for item in tool_calls: |
| 45 | +            if isinstance(item, str): |
| 46 | +                item = json.loads(item) |
| 47 | + decoded_result.append({item["name"]: item["arguments"]}) |
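| | +        # convert_to_function_call renders each {name: arguments} entry as an executable call string, e.g. "name(arg=value)" |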
| 48 | + return convert_to_function_call(decoded_result) |
| 49 | + |
| 50 | + @override |
| 51 | + def _format_prompt(self, messages, function): |
| 52 | + """ |
| 53 | + "chat_template": |
| 54 | + {% set image_count = namespace(value=0) %} |
| 55 | + {% set video_count = namespace(value=0) %} |
| 56 | + {% for message in messages %} |
| 57 | + {% if loop.first and message['role'] != 'system' %} |
| 58 | + <|im_start|> |
| 59 | + system\nYou are a helpful assistant. |
| 60 | + <|im_end|>\n |
| 61 | + {% endif %} |
| 62 | + <|im_start|> |
| 63 | + {{ message['role'] }}\n |
| 64 | + {% if message['content'] is string %} |
| 65 | + {{ message['content'] }} |
| 66 | + <|im_end|>\n |
| 67 | + {% else %} |
| 68 | + {% for content in message['content'] %} |
| 69 | + {% if content['type'] == 'image' or 'image' in content or 'image_url' in content %} |
| 70 | + {% set image_count.value = image_count.value + 1 %} |
| 71 | + {% if add_vision_id %} |
| 72 | + Picture {{ image_count.value }}: |
| 73 | + {% endif %} |
| 74 | + <|vision_start|><|image_pad|><|vision_end|> |
| 75 | + {% elif content['type'] == 'video' or 'video' in content %} |
| 76 | + {% set video_count.value = video_count.value + 1 %} |
| 77 | + {% if add_vision_id %} |
| 78 | + Video {{ video_count.value }}: |
| 79 | + {% endif %} |
| 80 | + <|vision_start|><|video_pad|><|vision_end|> |
| 81 | + {% elif 'text' in content %} |
| 82 | + {{ content['text'] }} |
| 83 | + {% endif %} |
| 84 | + {% endfor %} |
| 85 | +            <|im_end|>\n |
| 86 | + {% endif %} |
| 87 | + {% endfor %} |
| 88 | + {% if add_generation_prompt %} |
| 89 | + <|im_start|>assistant\n |
| 90 | + {% endif %}" |
| 91 | + """ |
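| | +        # The chat template quoted above is re-implemented below so that the BFCL tool definitions can be injected into the system turn. |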
| 92 | +        add_vision_id = False |
| 93 | +        add_generation_prompt = True |
| 94 | + formatted_prompt = "" |
| 95 | + image_count = 0 |
| 96 | + video_count = 0 |
| 97 | + first_system_processed = False |
| 98 | + |
| 99 | +        # ===== 1. Tool / function-calling section ===== |
| 100 | + if function and len(function) > 0: |
| 101 | + formatted_prompt += "<|im_start|>system\n" |
| 102 | + |
| 103 | +            # Check whether the first message is a system message |
| 104 | +            if messages and messages[0]['role'] == 'system': |
| 105 | +                formatted_prompt += messages[0]['content'] + "\n\n" |
| 106 | +                first_system_processed = True |
| 107 | + |
| 108 | +            # Add the function-calling instructions |
| 109 | +            formatted_prompt += "# Tools\n\nYou may call one or more functions to assist with the user query." |
| 110 | + formatted_prompt += "\n\nYou are provided with function signatures within <tools></tools> XML tags:" |
| 111 | + formatted_prompt += "\n<tools>" |
| 112 | + |
| 113 | + for func in function: |
| 114 | + formatted_prompt += f"\n{json.dumps(func)}" |
| 115 | + |
| 116 | + formatted_prompt += "\n</tools>" |
| 117 | +            formatted_prompt += "\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:" |
| 118 | + formatted_prompt += '\n<tool_call>\n{"name": <function-name>, "arguments": <args-json-object>}\n</tool_call>' |
| 119 | + formatted_prompt += "<|im_end|>\n" |
| 120 | + |
| 121 | +        # ===== 2. System message handling ===== |
| 122 | +        # Handle a system message that was not already emitted in the tools section |
| 123 | + if messages and messages[0]['role'] == 'system' and not first_system_processed: |
| 124 | + formatted_prompt += f"<|im_start|>system\n{messages[0]['content']}<|im_end|>\n" |
| 125 | + first_system_processed = True |
| 126 | + elif not function and (not messages or messages[0]['role'] != 'system'): |
| 127 | +            # Add the default system message |
| 128 | + formatted_prompt += "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n" |
| 129 | + |
| 130 | +        # ===== 3. Iterate over the conversation messages ===== |
| 131 | + for idx, message in enumerate(messages): |
| 132 | + role = message['role'] |
| 133 | + content = message['content'] |
| 134 | + |
| 135 | +            # Skip the system message that was already handled above |
| 136 | + if idx == 0 and role == 'system' and first_system_processed: |
| 137 | + continue |
| 138 | + |
| 139 | +            # Open the message block |
| 140 | + formatted_prompt += f"<|im_start|>{role}\n" |
| 141 | + |
| 142 | +            # Tool response messages |
| 143 | + if role == "tool": |
| 144 | + formatted_prompt += f"<tool_response>\n{content}\n</tool_response>" |
| 145 | + |
| 146 | +            # Assistant messages (may contain tool calls) |
| 147 | + elif role == "assistant": |
| 148 | +                # Emit tool calls |
| 149 | + if "tool_calls" in message: |
| 150 | + for call in message["tool_calls"]: |
| 151 | + func = call.get("function", call) |
| 152 | + name = func["name"] |
| 153 | + args = func["arguments"] |
| 154 | + |
| 155 | + if isinstance(args, dict): |
| 156 | + args = json.dumps(args, ensure_ascii=False) |
| 157 | + |
| 158 | + formatted_prompt += f'<tool_call>\n{{"name": "{name}", "arguments": {args}}}\n</tool_call>' |
| 159 | + |
| 160 | +                # Emit regular text content |
| 161 | + if content: |
| 162 | + formatted_prompt += content |
| 163 | + |
| 164 | +            # User messages (may contain multimodal content) |
| 165 | + elif role == "user": |
| 166 | + if isinstance(content, str): |
| 167 | + formatted_prompt += content |
| 168 | + elif isinstance(content, list): |
| 169 | + for content_part in content: |
| 170 | +                        # Image content |
| 171 | + if ('type' in content_part and content_part['type'] == 'image') or \ |
| 172 | + 'image' in content_part or \ |
| 173 | + 'image_url' in content_part: |
| 174 | + image_count += 1 |
| 175 | + if add_vision_id: |
| 176 | + formatted_prompt += f"Picture {image_count}: " |
| 177 | + formatted_prompt += "<|vision_start|><|image_pad|><|vision_end|>" |
| 178 | + |
| 179 | +                        # Video content |
| 180 | + elif ('type' in content_part and content_part['type'] == 'video') or \ |
| 181 | + 'video' in content_part: |
| 182 | + video_count += 1 |
| 183 | + if add_vision_id: |
| 184 | + formatted_prompt += f"Video {video_count}: " |
| 185 | + formatted_prompt += "<|vision_start|><|video_pad|><|vision_end|>" |
| 186 | + |
| 187 | +                        # Text content |
| 188 | + elif 'text' in content_part: |
| 189 | + formatted_prompt += content_part['text'] |
| 190 | + |
| 191 | +            # Close the message block |
| 192 | + formatted_prompt += "<|im_end|>\n" |
| 193 | + |
| 194 | +        # ===== 4. Add the generation prompt ===== |
| 195 | + if add_generation_prompt: |
| 196 | + formatted_prompt += "<|im_start|>assistant\n" |
| 197 | + |
| 198 | + # print("=================================start of formatted prompt=================================") |
| 199 | + # print(formatted_prompt) |
| 200 | + |
| 201 | + return formatted_prompt |
| 202 | + |
| 203 | + @override |
| 204 | + def _pre_query_processing_prompting(self, test_entry: dict) -> dict: |
| 205 | + functions: list = test_entry["function"] |
| 206 | + |
| 207 | +        # FC models use their own system prompt, so there is no need to add any extra message |
| 208 | + |
| 209 | + return {"message": [], "function": functions} |
| 210 | + |
| 211 | + @override |
| 212 | + def _parse_query_response_prompting(self, api_response: Any) -> dict: |
| 213 | + model_response = api_response.choices[0].text |
| 214 | + extracted_tool_calls = self._extract_tool_calls(model_response) |
| 215 | + |
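| | +        # Separate optional <think>...</think> reasoning from the final response text |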
| 216 | + reasoning_content = "" |
| 217 | + cleaned_response = model_response |
| 218 | + if "</think>" in model_response: |
| 219 | + parts = model_response.split("</think>") |
| 220 | + reasoning_content = parts[0].rstrip("\n").split("<think>")[-1].lstrip("\n") |
| 221 | + cleaned_response = parts[-1].lstrip("\n") |
| 222 | + |
| 223 | + if len(extracted_tool_calls) > 0: |
| 224 | + model_responses_message_for_chat_history = { |
| 225 | + "role": "assistant", |
| 226 | + "content": "", |
| 227 | + "tool_calls": extracted_tool_calls, |
| 228 | + } |
| 229 | + |
| 230 | + else: |
| 231 | + model_responses_message_for_chat_history = { |
| 232 | + "role": "assistant", |
| 233 | + "content": cleaned_response, |
| 234 | + } |
| 235 | + |
| 236 | + model_responses_message_for_chat_history["reasoning_content"] = reasoning_content |
| 237 | + |
| 238 | + return { |
| 239 | + "model_responses": cleaned_response, |
| 240 | + "reasoning_content": reasoning_content, |
| 241 | + "model_responses_message_for_chat_history": model_responses_message_for_chat_history, |
| 242 | + "input_token": api_response.usage.prompt_tokens, |
| 243 | + "output_token": api_response.usage.completion_tokens, |
| 244 | + } |
| 245 | + |
| 246 | + @override |
| 247 | + def _add_assistant_message_prompting( |
| 248 | + self, inference_data: dict, model_response_data: dict |
| 249 | + ) -> dict: |
| 250 | + inference_data["message"].append( |
| 251 | + model_response_data["model_responses_message_for_chat_history"], |
| 252 | + ) |
| 253 | + return inference_data |
| 254 | + |
| 255 | + @staticmethod |
| 256 | + def _extract_tool_calls(input_string): |
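| | +        """Parse every <tool_call>...</tool_call> block out of the raw completion. |
| | + |
| | +        For example, the block '<tool_call>\n{"name": "spotify.play", "arguments": {"artist": "Taylor Swift"}}\n</tool_call>' |
| | +        is returned as {"name": "spotify.play", "arguments": {"artist": "Taylor Swift"}}; |
| | +        blocks that are not valid JSON are skipped. |
| | +        """ |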
| 257 | + pattern = r"<tool_call>\n(.*?)\n</tool_call>" |
| 258 | + matches = re.findall(pattern, input_string, re.DOTALL) |
| 259 | + |
| 260 | + # Process matches into a list of dictionaries |
| 261 | + result = [] |
| 262 | +        for match in matches: |
| 263 | +            try: |
| 264 | +                result.append(json.loads(match)) |
| 265 | +            except json.JSONDecodeError: |
| 266 | +                # Skip tool-call blocks that are not valid JSON |
| 267 | +                continue |
| 268 | + return result |