|
1 |
| -use std::sync::Arc; |
| 1 | +use std::{process::exit, sync::Arc}; |
2 | 2 |
|
3 | 3 | use anyhow::Result;
|
| 4 | +use clap::{Parser, Subcommand}; |
4 | 5 | use simple_chat_client::{
|
5 | 6 | chat::ChatSession,
|
6 | 7 | client::OpenAIClient,
|
7 | 8 | config::Config,
|
8 | 9 | tool::{Tool, ToolSet, get_mcp_tools},
|
9 | 10 | };
|
10 | 11 |
|
11 |
| -//default config path |
12 |
| -const DEFAULT_CONFIG_PATH: &str = "/etc/simple-chat-client/config.toml"; |
| 12 | +#[derive(Parser)] |
| 13 | +#[command(author, version, about = "Simple Chat Client")] |
| 14 | +struct Cli { |
| 15 | + /// Config file path |
| 16 | + #[arg(short, long, value_name = "FILE")] |
| 17 | + config: Option<String>, |
| 18 | + |
| 19 | + #[command(subcommand)] |
| 20 | + command: Commands, |
| 21 | +} |
| 22 | + |
| 23 | +#[derive(Subcommand)] |
| 24 | +enum Commands { |
| 25 | + /// Output default config template |
| 26 | + Config, |
| 27 | + |
| 28 | + /// Start chat |
| 29 | + Chat { |
| 30 | + /// Specify the model name |
| 31 | + #[arg(short, long)] |
| 32 | + model: Option<String>, |
| 33 | + }, |
| 34 | +} |
13 | 35 |
|
14 | 36 | #[tokio::main]
|
15 | 37 | async fn main() -> Result<()> {
|
16 |
| - // load config |
17 |
| - let config = Config::load(DEFAULT_CONFIG_PATH).await?; |
18 |
| - |
19 |
| - // create openai client |
20 |
| - let api_key = config |
21 |
| - .openai_key |
22 |
| - .clone() |
23 |
| - .unwrap_or_else(|| std::env::var("OPENAI_API_KEY").expect("need set api key")); |
24 |
| - let url = config.chat_url.clone(); |
25 |
| - println!("url is {:?}", url); |
26 |
| - let openai_client = Arc::new(OpenAIClient::new(api_key, url, config.proxy)); |
27 |
| - |
28 |
| - // create tool set |
29 |
| - let mut tool_set = ToolSet::default(); |
30 |
| - |
31 |
| - // load mcp |
32 |
| - if config.mcp.is_some() { |
33 |
| - let mcp_clients = config.create_mcp_clients().await?; |
34 |
| - |
35 |
| - for (name, client) in mcp_clients { |
36 |
| - println!("loading mcp tools: {}", name); |
37 |
| - let server = client.peer().clone(); |
38 |
| - let tools = get_mcp_tools(server).await?; |
39 |
| - |
40 |
| - for tool in tools { |
41 |
| - println!("adding tool: {}", tool.name()); |
42 |
| - tool_set.add_tool(tool); |
43 |
| - } |
| 38 | + let cli = Cli::parse(); |
| 39 | + |
| 40 | + match cli.command { |
| 41 | + Commands::Config => { |
| 42 | + println!("{}", include_str!("../config.toml")); |
| 43 | + return Ok(()); |
44 | 44 | }
|
45 |
| - } |
| 45 | + Commands::Chat { model } => { |
| 46 | + // load config |
| 47 | + let config_path = cli.config; |
| 48 | + let mut config = match config_path { |
| 49 | + Some(path) => Config::load(&path).await?, |
| 50 | + None => { |
| 51 | + println!("No config file provided, using default config"); |
| 52 | + exit(-1); |
| 53 | + } |
| 54 | + }; |
46 | 55 |
|
47 |
| - // create chat session |
48 |
| - let mut session = ChatSession::new( |
49 |
| - openai_client, |
50 |
| - tool_set, |
51 |
| - config |
52 |
| - .model_name |
53 |
| - .unwrap_or_else(|| "gpt-4o-mini".to_string()), |
54 |
| - ); |
55 |
| - |
56 |
| - // build system prompt with tool info |
57 |
| - let mut system_prompt = |
58 |
| - "you are a assistant, you can help user to complete various tasks. you have the following tools to use:\n".to_string(); |
59 |
| - |
60 |
| - // add tool info to system prompt |
61 |
| - for tool in session.get_tools() { |
62 |
| - system_prompt.push_str(&format!( |
63 |
| - "\ntool name: {}\ndescription: {}\nparameters: {}\n", |
64 |
| - tool.name(), |
65 |
| - tool.description(), |
66 |
| - serde_json::to_string_pretty(&tool.parameters()).unwrap_or_default() |
67 |
| - )); |
68 |
| - } |
| 56 | + // if command line specify model, override config file setting |
| 57 | + if let Some(model_name) = model { |
| 58 | + config.model_name = Some(model_name); |
| 59 | + } |
| 60 | + |
| 61 | + // create openai client |
| 62 | + let api_key = config |
| 63 | + .openai_key |
| 64 | + .clone() |
| 65 | + .unwrap_or_else(|| std::env::var("OPENAI_API_KEY").expect("need set api key")); |
| 66 | + let url = config.chat_url.clone(); |
| 67 | + println!("use api address: {:?}", url); |
| 68 | + let openai_client = Arc::new(OpenAIClient::new(api_key, url, config.proxy)); |
69 | 69 |
|
70 |
| - // add tool call format guidance |
71 |
| - system_prompt.push_str( |
72 |
| - "\nif you need to call tool, please use the following format:\n\ |
73 |
| - Tool: <tool name>\n\ |
74 |
| - Inputs: <inputs>\n", |
75 |
| - ); |
| 70 | + // create tool set |
| 71 | + let mut tool_set = ToolSet::default(); |
76 | 72 |
|
77 |
| - // add system prompt |
78 |
| - session.add_system_prompt(system_prompt); |
| 73 | + // load MCP |
| 74 | + if config.mcp.is_some() { |
| 75 | + let mcp_clients = config.create_mcp_clients().await?; |
79 | 76 |
|
80 |
| - // start chat |
81 |
| - session.chat().await?; |
| 77 | + for (name, client) in mcp_clients { |
| 78 | + println!("load MCP tool: {}", name); |
| 79 | + let server = client.peer().clone(); |
| 80 | + let tools = get_mcp_tools(server).await?; |
| 81 | + |
| 82 | + for tool in tools { |
| 83 | + println!("add tool: {}", tool.name()); |
| 84 | + tool_set.add_tool(tool); |
| 85 | + } |
| 86 | + } |
| 87 | + } |
| 88 | + |
| 89 | + // create chat session |
| 90 | + let mut session = ChatSession::new( |
| 91 | + openai_client, |
| 92 | + tool_set, |
| 93 | + config |
| 94 | + .model_name |
| 95 | + .unwrap_or_else(|| "gpt-4o-mini".to_string()), |
| 96 | + ); |
| 97 | + |
| 98 | + // build system prompt |
| 99 | + let mut system_prompt = |
| 100 | + "you are a assistant, you can help user to complete various tasks. you have the following tools to use:\n".to_string(); |
| 101 | + |
| 102 | + // add tool info to system prompt |
| 103 | + for tool in session.get_tools() { |
| 104 | + system_prompt.push_str(&format!( |
| 105 | + "\ntool name: {}\ndescription: {}\nparameters: {}\n", |
| 106 | + tool.name(), |
| 107 | + tool.description(), |
| 108 | + serde_json::to_string_pretty(&tool.parameters()) |
| 109 | + .expect("failed to serialize tool parameters") |
| 110 | + )); |
| 111 | + } |
| 112 | + |
| 113 | + // add tool call format guidance |
| 114 | + system_prompt.push_str( |
| 115 | + "\nif you need to call tool, please use the following format:\n\ |
| 116 | + Tool: <tool name>\n\ |
| 117 | + Inputs: <inputs>\n", |
| 118 | + ); |
| 119 | + |
| 120 | + // add system prompt |
| 121 | + session.add_system_prompt(system_prompt); |
| 122 | + |
| 123 | + // start chat |
| 124 | + session.chat().await?; |
| 125 | + } |
| 126 | + } |
82 | 127 |
|
83 | 128 | Ok(())
|
84 | 129 | }
|
0 commit comments