Skip to content

Commit 6de444c

Browse files
authored
feat(examples): Add func call support (#168)
Some works about simple-chat. 1. fix some bugs. 2. add func call optional. Signed-off-by: jokemanfire <[email protected]>
1 parent 787cc01 commit 6de444c

File tree

8 files changed

+160
-108
lines changed

8 files changed

+160
-108
lines changed

examples/simple-chat-client/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ After configuring the config file, you can run the example:
99
```bash
1010
./simple_chat --help # show help info
1111
./simple_chat config > config.toml # output default config to file
12-
./simple_chat chat --config my_config.toml # start chat with specified config
13-
./simple_chat chat --config my_config.toml --model gpt-4o-mini # start chat with specified model
12+
./simple_chat --config my_config.toml chat # start chat with specified config
13+
./simple_chat --config my_config.toml --model gpt-4o-mini chat # start chat with specified model
1414
```
1515

examples/simple-chat-client/src/bin/simple_chat.rs

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -95,33 +95,42 @@ async fn main() -> Result<()> {
9595
.unwrap_or_else(|| "gpt-4o-mini".to_string()),
9696
);
9797

98-
// build system prompt
99-
let mut system_prompt =
100-
"you are a assistant, you can help user to complete various tasks. you have the following tools to use:\n".to_string();
101-
102-
// add tool info to system prompt
103-
for tool in session.get_tools() {
104-
system_prompt.push_str(&format!(
105-
"\ntool name: {}\ndescription: {}\nparameters: {}\n",
106-
tool.name(),
107-
tool.description(),
108-
serde_json::to_string_pretty(&tool.parameters())
109-
.expect("failed to serialize tool parameters")
110-
));
111-
}
98+
let support_tool = config.support_tool.unwrap_or(true);
99+
let mut system_prompt;
100+
// if not support tool call, add tool call format guidance
101+
if !support_tool {
102+
// build system prompt
103+
system_prompt =
104+
"you are a assistant, you can help user to complete various tasks. you have the following tools to use:\n".to_string();
105+
106+
// add tool info to system prompt
107+
for tool in session.get_tools() {
108+
system_prompt.push_str(&format!(
109+
"\ntool name: {}\ndescription: {}\nparameters: {}\n",
110+
tool.name(),
111+
tool.description(),
112+
serde_json::to_string_pretty(&tool.parameters())
113+
.expect("failed to serialize tool parameters")
114+
));
115+
}
112116

113-
// add tool call format guidance
114-
system_prompt.push_str(
115-
"\nif you need to call tool, please use the following format:\n\
116-
Tool: <tool name>\n\
117-
Inputs: <inputs>\n",
118-
);
117+
// add tool call format guidance
118+
system_prompt.push_str(
119+
"\nif you need to call tool, please use the following format:\n\
120+
Tool: <tool name>\n\
121+
Inputs: <inputs>\n",
122+
);
123+
println!("system prompt: {}", system_prompt);
124+
} else {
125+
system_prompt =
126+
"you are a assistant, you can help user to complete various tasks.".to_string();
127+
}
119128

120129
// add system prompt
121130
session.add_system_prompt(system_prompt);
122131

123132
// start chat
124-
session.chat().await?;
133+
session.chat(support_tool).await?;
125134
}
126135
}
127136

examples/simple-chat-client/src/chat.rs

Lines changed: 102 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,11 @@ use std::{
44
};
55

66
use anyhow::Result;
7+
use serde_json;
78

89
use crate::{
910
client::ChatClient,
10-
model::{CompletionRequest, Message},
11+
model::{CompletionRequest, Message, ToolFunction},
1112
tool::{Tool as ToolTrait, ToolSet},
1213
};
1314

@@ -36,7 +37,84 @@ impl ChatSession {
3637
self.tool_set.tools()
3738
}
3839

39-
pub async fn chat(&mut self) -> Result<()> {
40+
pub async fn analyze_tool_call(&mut self, response: &Message) {
41+
let mut tool_calls_func = Vec::new();
42+
if let Some(tool_calls) = response.tool_calls.as_ref() {
43+
for tool_call in tool_calls {
44+
if tool_call._type == "function" {
45+
tool_calls_func.push(tool_call.function.clone());
46+
}
47+
}
48+
} else {
49+
// check if message contains tool call
50+
if response.content.contains("Tool:") {
51+
let lines: Vec<&str> = response.content.split('\n').collect();
52+
// simple parse tool call
53+
let mut tool_name = None;
54+
let mut args_text = Vec::new();
55+
let mut parsing_args = false;
56+
57+
for line in lines {
58+
if line.starts_with("Tool:") {
59+
tool_name = line.strip_prefix("Tool:").map(|s| s.trim().to_string());
60+
parsing_args = false;
61+
} else if line.starts_with("Inputs:") {
62+
parsing_args = true;
63+
} else if parsing_args {
64+
args_text.push(line.trim());
65+
}
66+
}
67+
if let Some(name) = tool_name {
68+
tool_calls_func.push(ToolFunction {
69+
name,
70+
arguments: args_text.join("\n"),
71+
});
72+
}
73+
}
74+
}
75+
// call tool
76+
for tool_call in tool_calls_func {
77+
println!("tool call: {:?}", tool_call);
78+
let tool = self.tool_set.get_tool(&tool_call.name);
79+
if let Some(tool) = tool {
80+
// call tool
81+
let args = serde_json::from_str::<serde_json::Value>(&tool_call.arguments)
82+
.unwrap_or_default();
83+
match tool.call(args).await {
84+
Ok(result) => {
85+
if result.is_error.is_some_and(|b| b) {
86+
self.messages
87+
.push(Message::user("tool call failed, mcp call error"));
88+
} else {
89+
result.content.iter().for_each(|content| {
90+
if let Some(content_text) = content.as_text() {
91+
let json_result = serde_json::from_str::<serde_json::Value>(
92+
&content_text.text,
93+
)
94+
.unwrap_or_default();
95+
let pretty_result =
96+
serde_json::to_string_pretty(&json_result).unwrap();
97+
println!("call tool result: {}", pretty_result);
98+
self.messages.push(Message::user(format!(
99+
"call tool result: {}",
100+
pretty_result
101+
)));
102+
}
103+
});
104+
}
105+
}
106+
Err(e) => {
107+
println!("tool call failed: {}", e);
108+
self.messages
109+
.push(Message::user(format!("tool call failed: {}", e)));
110+
}
111+
}
112+
} else {
113+
println!("tool not found: {}", tool_call.name);
114+
}
115+
}
116+
}
117+
pub async fn chat(&mut self, support_tool: bool) -> Result<()> {
40118
println!("welcome to use simple chat client, use 'exit' to quit");
41119

42120
loop {
@@ -56,20 +134,23 @@ impl ChatSession {
56134
}
57135

58136
self.messages.push(Message::user(&input));
59-
60-
// prepare tool list
61-
let tools = self.tool_set.tools();
62-
let tool_definitions = if !tools.is_empty() {
63-
Some(
64-
tools
65-
.iter()
66-
.map(|tool| crate::model::Tool {
67-
name: tool.name(),
68-
description: tool.description(),
69-
parameters: tool.parameters(),
70-
})
71-
.collect(),
72-
)
137+
let tool_definitions = if support_tool {
138+
// prepare tool list
139+
let tools = self.tool_set.tools();
140+
if !tools.is_empty() {
141+
Some(
142+
tools
143+
.iter()
144+
.map(|tool| crate::model::Tool {
145+
name: tool.name(),
146+
description: tool.description(),
147+
parameters: tool.parameters(),
148+
})
149+
.collect(),
150+
)
151+
} else {
152+
None
153+
}
73154
} else {
74155
None
75156
};
@@ -84,65 +165,11 @@ impl ChatSession {
84165

85166
// send request
86167
let response = self.client.complete(request).await?;
87-
88-
if let Some(choice) = response.choices.first() {
89-
println!("AI: {}", choice.message.content);
90-
self.messages.push(choice.message.clone());
91-
92-
// check if message contains tool call
93-
if choice.message.content.contains("Tool:") {
94-
let lines: Vec<&str> = choice.message.content.split('\n').collect();
95-
96-
// simple parse tool call
97-
let mut tool_name = None;
98-
let mut args_text = Vec::new();
99-
let mut parsing_args = false;
100-
101-
for line in lines {
102-
if line.starts_with("Tool:") {
103-
tool_name = line.strip_prefix("Tool:").map(|s| s.trim().to_string());
104-
parsing_args = false;
105-
} else if line.starts_with("Inputs:") {
106-
parsing_args = true;
107-
} else if parsing_args {
108-
args_text.push(line.trim());
109-
}
110-
}
111-
112-
if let Some(name) = tool_name {
113-
if let Some(tool) = self.tool_set.get_tool(&name) {
114-
println!("calling tool: {}", name);
115-
116-
// simple handle args
117-
let args_str = args_text.join("\n");
118-
let args = match serde_json::from_str(&args_str) {
119-
Ok(v) => v,
120-
Err(_) => {
121-
// try to handle args as string
122-
serde_json::Value::String(args_str)
123-
}
124-
};
125-
126-
// call tool
127-
match tool.call(args).await {
128-
Ok(result) => {
129-
println!("tool result: {}", result);
130-
131-
// add tool result to dialog
132-
self.messages.push(Message::user(result));
133-
}
134-
Err(e) => {
135-
println!("tool call failed: {}", e);
136-
self.messages
137-
.push(Message::user(format!("tool call failed: {}", e)));
138-
}
139-
}
140-
} else {
141-
println!("tool not found: {}", name);
142-
}
143-
}
144-
}
145-
}
168+
// get choice
169+
let choice = response.choices.first().unwrap();
170+
println!("AI > {}", choice.message.content);
171+
// analyze tool call
172+
self.analyze_tool_call(&choice.message).await;
146173
}
147174

148175
Ok(())

examples/simple-chat-client/src/client.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,11 @@ impl ChatClient for OpenAIClient {
5858
println!("API error: {}", error_text);
5959
return Err(anyhow::anyhow!("API Error: {}", error_text));
6060
}
61-
62-
let completion: CompletionResponse = response.json().await?;
61+
let text_data = response.text().await?;
62+
println!("Received response: {}", text_data);
63+
let completion: CompletionResponse = serde_json::from_str(&text_data)
64+
.map_err(anyhow::Error::from)
65+
.unwrap();
6366
Ok(completion)
6467
}
6568
}

examples/simple-chat-client/src/config.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ pub struct Config {
1111
pub mcp: Option<McpConfig>,
1212
pub model_name: Option<String>,
1313
pub proxy: Option<bool>,
14+
pub support_tool: Option<bool>,
1415
}
1516

1617
#[derive(Debug, Serialize, Deserialize)]

examples/simple-chat-client/src/config.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ openai_key = "key"
22
chat_url = "url"
33
model_name = "model_name"
44
proxy = false
5+
support_tool = true # if support tool call
56

67
[mcp]
78
[[mcp.server]]

examples/simple-chat-client/src/model.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,27 +4,32 @@ use serde::{Deserialize, Serialize};
44
pub struct Message {
55
pub role: String,
66
pub content: String,
7+
#[serde(skip_serializing_if = "Option::is_none")]
8+
pub tool_calls: Option<Vec<ToolCall>>,
79
}
810

911
impl Message {
1012
pub fn system(content: impl ToString) -> Self {
1113
Self {
1214
role: "system".to_string(),
1315
content: content.to_string(),
16+
tool_calls: None,
1417
}
1518
}
1619

1720
pub fn user(content: impl ToString) -> Self {
1821
Self {
1922
role: "user".to_string(),
2023
content: content.to_string(),
24+
tool_calls: None,
2125
}
2226
}
2327

2428
pub fn assistant(content: impl ToString) -> Self {
2529
Self {
2630
role: "assistant".to_string(),
2731
content: content.to_string(),
32+
tool_calls: None,
2833
}
2934
}
3035
}
@@ -62,10 +67,17 @@ pub struct Choice {
6267
pub finish_reason: String,
6368
}
6469

65-
#[derive(Debug, Serialize, Deserialize)]
70+
#[derive(Debug, Serialize, Deserialize, Clone)]
6671
pub struct ToolCall {
72+
pub id: String,
73+
#[serde(rename = "type")]
74+
pub _type: String,
75+
pub function: ToolFunction,
76+
}
77+
#[derive(Debug, Serialize, Deserialize, Clone)]
78+
pub struct ToolFunction {
6779
pub name: String,
68-
pub arguments: serde_json::Value,
80+
pub arguments: String,
6981
}
7082

7183
#[derive(Debug, Serialize, Deserialize)]

0 commit comments

Comments
 (0)