@@ -4,10 +4,11 @@ use std::{
44} ;
55
66use anyhow:: Result ;
7+ use serde_json;
78
89use crate :: {
910 client:: ChatClient ,
10- model:: { CompletionRequest , Message } ,
11+ model:: { CompletionRequest , Message , ToolFunction } ,
1112 tool:: { Tool as ToolTrait , ToolSet } ,
1213} ;
1314
@@ -36,7 +37,84 @@ impl ChatSession {
3637 self . tool_set . tools ( )
3738 }
3839
39- pub async fn chat ( & mut self ) -> Result < ( ) > {
40+ pub async fn analyze_tool_call ( & mut self , response : & Message ) {
41+ let mut tool_calls_func = Vec :: new ( ) ;
42+ if let Some ( tool_calls) = response. tool_calls . as_ref ( ) {
43+ for tool_call in tool_calls {
44+ if tool_call. _type == "function" {
45+ tool_calls_func. push ( tool_call. function . clone ( ) ) ;
46+ }
47+ }
48+ } else {
49+ // check if message contains tool call
50+ if response. content . contains ( "Tool:" ) {
51+ let lines: Vec < & str > = response. content . split ( '\n' ) . collect ( ) ;
52+ // simple parse tool call
53+ let mut tool_name = None ;
54+ let mut args_text = Vec :: new ( ) ;
55+ let mut parsing_args = false ;
56+
57+ for line in lines {
58+ if line. starts_with ( "Tool:" ) {
59+ tool_name = line. strip_prefix ( "Tool:" ) . map ( |s| s. trim ( ) . to_string ( ) ) ;
60+ parsing_args = false ;
61+ } else if line. starts_with ( "Inputs:" ) {
62+ parsing_args = true ;
63+ } else if parsing_args {
64+ args_text. push ( line. trim ( ) ) ;
65+ }
66+ }
67+ if let Some ( name) = tool_name {
68+ tool_calls_func. push ( ToolFunction {
69+ name,
70+ arguments : args_text. join ( "\n " ) ,
71+ } ) ;
72+ }
73+ }
74+ }
75+ // call tool
76+ for tool_call in tool_calls_func {
77+ println ! ( "tool call: {:?}" , tool_call) ;
78+ let tool = self . tool_set . get_tool ( & tool_call. name ) ;
79+ if let Some ( tool) = tool {
80+ // call tool
81+ let args = serde_json:: from_str :: < serde_json:: Value > ( & tool_call. arguments )
82+ . unwrap_or_default ( ) ;
83+ match tool. call ( args) . await {
84+ Ok ( result) => {
85+ if result. is_error . is_some_and ( |b| b) {
86+ self . messages
87+ . push ( Message :: user ( "tool call failed, mcp call error" ) ) ;
88+ } else {
89+ result. content . iter ( ) . for_each ( |content| {
90+ if let Some ( content_text) = content. as_text ( ) {
91+ let json_result = serde_json:: from_str :: < serde_json:: Value > (
92+ & content_text. text ,
93+ )
94+ . unwrap_or_default ( ) ;
95+ let pretty_result =
96+ serde_json:: to_string_pretty ( & json_result) . unwrap ( ) ;
97+ println ! ( "call tool result: {}" , pretty_result) ;
98+ self . messages . push ( Message :: user ( format ! (
99+ "call tool result: {}" ,
100+ pretty_result
101+ ) ) ) ;
102+ }
103+ } ) ;
104+ }
105+ }
106+ Err ( e) => {
107+ println ! ( "tool call failed: {}" , e) ;
108+ self . messages
109+ . push ( Message :: user ( format ! ( "tool call failed: {}" , e) ) ) ;
110+ }
111+ }
112+ } else {
113+ println ! ( "tool not found: {}" , tool_call. name) ;
114+ }
115+ }
116+ }
117+ pub async fn chat ( & mut self , support_tool : bool ) -> Result < ( ) > {
40118 println ! ( "welcome to use simple chat client, use 'exit' to quit" ) ;
41119
42120 loop {
@@ -56,20 +134,23 @@ impl ChatSession {
56134 }
57135
58136 self . messages . push ( Message :: user ( & input) ) ;
59-
60- // prepare tool list
61- let tools = self . tool_set . tools ( ) ;
62- let tool_definitions = if !tools. is_empty ( ) {
63- Some (
64- tools
65- . iter ( )
66- . map ( |tool| crate :: model:: Tool {
67- name : tool. name ( ) ,
68- description : tool. description ( ) ,
69- parameters : tool. parameters ( ) ,
70- } )
71- . collect ( ) ,
72- )
137+ let tool_definitions = if support_tool {
138+ // prepare tool list
139+ let tools = self . tool_set . tools ( ) ;
140+ if !tools. is_empty ( ) {
141+ Some (
142+ tools
143+ . iter ( )
144+ . map ( |tool| crate :: model:: Tool {
145+ name : tool. name ( ) ,
146+ description : tool. description ( ) ,
147+ parameters : tool. parameters ( ) ,
148+ } )
149+ . collect ( ) ,
150+ )
151+ } else {
152+ None
153+ }
73154 } else {
74155 None
75156 } ;
@@ -84,65 +165,11 @@ impl ChatSession {
84165
85166 // send request
86167 let response = self . client . complete ( request) . await ?;
87-
88- if let Some ( choice) = response. choices . first ( ) {
89- println ! ( "AI: {}" , choice. message. content) ;
90- self . messages . push ( choice. message . clone ( ) ) ;
91-
92- // check if message contains tool call
93- if choice. message . content . contains ( "Tool:" ) {
94- let lines: Vec < & str > = choice. message . content . split ( '\n' ) . collect ( ) ;
95-
96- // simple parse tool call
97- let mut tool_name = None ;
98- let mut args_text = Vec :: new ( ) ;
99- let mut parsing_args = false ;
100-
101- for line in lines {
102- if line. starts_with ( "Tool:" ) {
103- tool_name = line. strip_prefix ( "Tool:" ) . map ( |s| s. trim ( ) . to_string ( ) ) ;
104- parsing_args = false ;
105- } else if line. starts_with ( "Inputs:" ) {
106- parsing_args = true ;
107- } else if parsing_args {
108- args_text. push ( line. trim ( ) ) ;
109- }
110- }
111-
112- if let Some ( name) = tool_name {
113- if let Some ( tool) = self . tool_set . get_tool ( & name) {
114- println ! ( "calling tool: {}" , name) ;
115-
116- // simple handle args
117- let args_str = args_text. join ( "\n " ) ;
118- let args = match serde_json:: from_str ( & args_str) {
119- Ok ( v) => v,
120- Err ( _) => {
121- // try to handle args as string
122- serde_json:: Value :: String ( args_str)
123- }
124- } ;
125-
126- // call tool
127- match tool. call ( args) . await {
128- Ok ( result) => {
129- println ! ( "tool result: {}" , result) ;
130-
131- // add tool result to dialog
132- self . messages . push ( Message :: user ( result) ) ;
133- }
134- Err ( e) => {
135- println ! ( "tool call failed: {}" , e) ;
136- self . messages
137- . push ( Message :: user ( format ! ( "tool call failed: {}" , e) ) ) ;
138- }
139- }
140- } else {
141- println ! ( "tool not found: {}" , name) ;
142- }
143- }
144- }
145- }
168+ // get choice
169+ let choice = response. choices . first ( ) . unwrap ( ) ;
170+ println ! ( "AI > {}" , choice. message. content) ;
171+ // analyze tool call
172+ self . analyze_tool_call ( & choice. message ) . await ;
146173 }
147174
148175 Ok ( ( ) )
0 commit comments