Skip to content

Commit 25bb516

Browse files
author
Ubuntu
committed
Update handler in support of v3
1 parent 9ba52e7 commit 25bb516

File tree

3 files changed

+114
-0
lines changed

3 files changed

+114
-0
lines changed

berkeley-function-call-leaderboard/bfcl/eval_checker/model_metadata.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,18 @@
497497
"Microsoft",
498498
"MIT",
499499
],
500+
"empower-dev/llama3-empower-functions-small-v1.1": [
501+
"Empower-Functions-Small-v1.1 (FC)",
502+
"https://huggingface.co/empower-dev/llama3-empower-functions-small-v1.1",
503+
"Empower.dev",
504+
"apache-2.0"
505+
],
506+
"empower-dev/llama3-empower-functions-large-v1.1": [
507+
"Empower-Functions-Large-v1.1 (FC)",
508+
"https://huggingface.co/empower-dev/llama3-empower-functions-large-v1.1",
509+
"Empower.dev",
510+
"apache-2.0"
511+
]
500512
}
501513

502514
INPUT_PRICE_PER_MILLION_TOKEN = {

berkeley-function-call-leaderboard/bfcl/model_handler/handler_map.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from bfcl.model_handler.oss_model.deepseek import DeepseekHandler
2+
from bfcl.model_handler.oss_model.empower import EmpowerHandler
23
from bfcl.model_handler.oss_model.gemma import GemmaHandler
34
from bfcl.model_handler.oss_model.glaive import GlaiveHandler
45
from bfcl.model_handler.oss_model.glm import GLMHandler
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
from bfcl.model_handler.oss_model.base_oss_handler import OSSHandler
2+
from bfcl.model_handler.model_style import ModelStyle
3+
import json
4+
from bfcl.model_handler.utils import (
5+
convert_to_tool,
6+
)
7+
from bfcl.model_handler.constant import (
8+
GORILLA_TO_OPENAPI,
9+
)
10+
11+
12+
class EmpowerHandler(OSSHandler):
    """Handler for the Empower llama3 function-calling models.

    Builds a llama3-style chat prompt in which the available function
    definitions are injected as JSON into the first user turn, tool results
    are relayed back as user turns tagged with ``<r>``, and plain user
    messages are tagged with ``<u>``.  Model output beginning with ``<f>``
    is decoded as a JSON list of function calls.
    """

    def __init__(self, model_name, temperature) -> None:
        super().__init__(model_name, temperature)

    def _preprocess_messages(self, messages):
        """Return a new message list with system turns dropped and runs of
        consecutive tool responses merged into a single JSON-array message.

        Each surviving message is shallow-copied so that later formatting
        steps can rewrite ``role``/``content`` without mutating the
        caller's message dicts (the original implementation aliased them,
        so formatting corrupted the conversation state in place).
        """
        # Drop system messages: the system instructions are injected into
        # the first user turn by _format_prompt instead.
        chat = [dict(msg) for msg in messages if msg["role"] != "system"]

        merged = []
        pending_tool = None  # consecutive tool outputs collected so far
        for msg in chat:
            if msg["role"] == "tool":
                decoded = json.loads(msg["content"])
                if pending_tool is None:
                    pending_tool = [decoded]
                else:
                    pending_tool.append(decoded)
            else:
                # Flush accumulated tool outputs before the next non-tool turn.
                if pending_tool is not None:
                    merged.append(
                        {"role": "tool", "content": json.dumps(pending_tool, indent=2)}
                    )
                    pending_tool = None
                merged.append(msg)
        # Flush trailing tool outputs (conversation ending on a tool turn).
        if pending_tool is not None:
            merged.append(
                {"role": "tool", "content": json.dumps(pending_tool, indent=2)}
            )

        return merged

    def _format_prompt(self, messages, functions):
        """Render the conversation and function specs into a llama3 prompt.

        The first message has the function definitions prepended; later tool
        messages are rewritten as ``<r>``-tagged user turns and plain user
        messages are ``<u>``-tagged.  Only the copies produced by
        _preprocess_messages are mutated, never the caller's messages.
        """
        formatted_prompt = "<|begin_of_text|>"

        for idx, message in enumerate(self._preprocess_messages(messages)):
            if idx == 0:
                tools = convert_to_tool(
                    functions, GORILLA_TO_OPENAPI, ModelStyle.OSSMODEL
                )
                message['content'] = "In this environment you have access to a set of functions defined in the JSON format you can use to address user's requests, use them if needed.\nFunctions:\n" \
                    + json.dumps(tools, indent=2) \
                    + "\n\n" \
                    + "User Message:\n" \
                    + message['content']
            else:
                if message['role'] == 'tool':
                    # Tool results are delivered back to the model as
                    # user turns tagged with <r>.
                    message['role'] = 'user'
                    message['content'] = '<r>' + message['content']
                elif message['role'] == 'user' and not message['content'].startswith('<r>') and not message['content'].startswith('<u>'):
                    message['content'] = '<u>' + message['content']

            formatted_prompt += f"<|start_header_id|>{message['role']}<|end_header_id|>\n\n{message['content']}<|eot_id|>"

        # Leave the prompt open for the assistant's reply.
        formatted_prompt += "<|start_header_id|>assistant<|end_header_id|>\n\n"

        return formatted_prompt

    def decode_ast(self, result, language="Python"):
        """Parse model output into ``[{function_name: arguments_dict}, ...]``.

        Output that does not start with the function tag ``<f>`` (e.g. a
        conversational reply) decodes to an empty list.  The ``language``
        parameter is kept for interface compatibility with other handlers.
        """
        if not result.startswith('<f>'):
            return []

        # Strip the leading <f> tag; the remainder is a JSON list of calls.
        invoked_functions = json.loads(result[3:])

        decoded_output = []
        for invoked_function in invoked_functions:
            name = invoked_function["name"]
            # "arguments" may be absent for zero-argument calls.
            params = invoked_function.get("arguments", {})
            decoded_output.append({name: params})

        return decoded_output

    def decode_execute(self, result):
        """Convert model output into executable call strings like
        ``name(arg=value)``, one per decoded function call."""
        execution_list = []

        for function_call in self.decode_ast(result):
            for fn_name, fn_args in function_call.items():
                rendered_args = ",".join(
                    f"{key}={repr(value)}" for key, value in fn_args.items()
                )
                execution_list.append(f"{fn_name}({rendered_args})")

        return execution_list

0 commit comments

Comments
 (0)