Skip to content

Commit d6a26fc

Browse files
byeblackbyeblack
andauthored
fix(rig-integration): implement ToolEmbeddingDyn trait for McpToolAdaptor (#99)
- Added ToolEmbeddingDyn trait for McpToolAdaptor to enable dynamic loading of McpTool - Enhanced logging to record details Co-authored-by: byeblack <[email protected]>
1 parent 910d3b3 commit d6a26fc

File tree

4 files changed

+58
-28
lines changed

4 files changed

+58
-28
lines changed

examples/rig-integration/Cargo.toml

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,4 +23,11 @@ anyhow = "1.0"
2323
serde_json = "1"
2424
serde = { version = "1", features = ["derive"] }
2525
toml = "0.8"
26-
futures = "0.3"
26+
futures = "0.3"
27+
tracing = "0.1"
28+
tracing-subscriber = { version = "0.3", features = [
29+
"env-filter",
30+
"std",
31+
"fmt",
32+
] }
33+
tracing-appender = "0.2"

examples/rig-integration/src/chat.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@ use futures::StreamExt;
22
use rig::{
33
agent::Agent,
44
message::Message,
5-
providers::deepseek::DeepSeekCompletionModel,
6-
streaming::{StreamingChat, StreamingChoice},
5+
streaming::{StreamingChat, StreamingChoice, StreamingCompletionModel},
76
};
87
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
98

10-
pub async fn cli_chatbot(chatbot: Agent<DeepSeekCompletionModel>) -> anyhow::Result<()> {
9+
pub async fn cli_chatbot<M: StreamingCompletionModel>(chatbot: Agent<M>) -> anyhow::Result<()> {
1110
let mut chat_log = vec![];
1211

1312
let mut output = BufWriter::new(tokio::io::stdout());
@@ -27,6 +26,7 @@ pub async fn cli_chatbot(chatbot: Agent<DeepSeekCompletionModel>) -> anyhow::Res
2726
}
2827
match chatbot.stream_chat(input, chat_log.clone()).await {
2928
Ok(mut response) => {
29+
tracing::info!(%input);
3030
chat_log.push(Message::user(input));
3131
stream_output_agent_start(&mut output).await?;
3232
let mut message_buf = String::new();

examples/rig-integration/src/main.rs

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,30 @@
11
use rig::{
22
embeddings::EmbeddingsBuilder,
3-
providers::{
4-
cohere,
5-
deepseek::{self, DEEPSEEK_CHAT},
6-
},
3+
providers::{cohere, deepseek},
74
vector_store::in_memory_store::InMemoryVectorStore,
85
};
6+
use tracing_appender::rolling::{RollingFileAppender, Rotation};
97
pub mod chat;
108
pub mod config;
119
pub mod mcp_adaptor;
10+
1211
#[tokio::main]
1312
async fn main() -> anyhow::Result<()> {
13+
let file_appender = RollingFileAppender::new(
14+
Rotation::DAILY,
15+
"logs",
16+
format!("{}.log", env!("CARGO_CRATE_NAME")),
17+
);
18+
tracing_subscriber::fmt()
19+
.with_env_filter(
20+
tracing_subscriber::EnvFilter::from_default_env()
21+
.add_directive(tracing::Level::INFO.into()),
22+
)
23+
.with_writer(file_appender)
24+
.with_file(false)
25+
.with_ansi(false)
26+
.init();
27+
1428
let config = config::Config::retrieve("config.toml").await?;
1529
let openai_client = {
1630
if let Some(key) = config.deepseek_key {
@@ -27,7 +41,7 @@ async fn main() -> anyhow::Result<()> {
2741
}
2842
};
2943
let mcp_manager = config.mcp.create_manager().await?;
30-
eprintln!(
44+
tracing::info!(
3145
"MCP Manager created, {} servers started",
3246
mcp_manager.clients.len()
3347
);
@@ -39,26 +53,16 @@ async fn main() -> anyhow::Result<()> {
3953
.build()
4054
.await?;
4155
let store = InMemoryVectorStore::from_documents_with_id_f(embeddings, |f| {
42-
eprintln!("store tool {}", f.name);
56+
tracing::info!("store tool {}", f.name);
4357
f.name.clone()
4458
});
4559
let index = store.index(embedding_model);
4660
let dpsk = openai_client
47-
.agent(DEEPSEEK_CHAT)
48-
.context(
49-
r#"You are an assistant here to help the user to do some works.
50-
When you need to use tools, you should select which tool is most appropriate to meet the user's requirement.
51-
Follow these instructions closely.
52-
1. Consider the user's request carefully and identify the core elements of the request.
53-
2. Select which tool among those made available to you is appropriate given the context.
54-
3. This is very important: never perform the operation yourself and never give me the direct result.
55-
Always respond with the name of the tool that should be used and the appropriate inputs
56-
in the following format:
57-
Tool: <tool name>
58-
Inputs: <list of inputs>"#,
59-
)
61+
.agent(deepseek::DEEPSEEK_CHAT)
6062
.dynamic_tools(4, index, tool_set)
6163
.build();
64+
6265
chat::cli_chatbot(dpsk).await?;
66+
6367
Ok(())
6468
}

examples/rig-integration/src/mcp_adaptor.rs

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::collections::HashMap;
22

3-
use rig::tool::{ToolDyn as RigTool, ToolSet};
3+
use rig::tool::{ToolDyn as RigTool, ToolEmbeddingDyn, ToolSet};
44
use rmcp::{
55
RoleClient,
66
model::{CallToolRequestParam, CallToolResult, Tool as McpTool},
@@ -49,13 +49,31 @@ impl RigTool for McpToolAdaptor {
4949
.map_err(rig::tool::ToolError::JsonError)?,
5050
})
5151
.await
52+
.inspect(|result| tracing::info!(?result))
53+
.inspect_err(|error| tracing::error!(%error))
5254
.map_err(|e| rig::tool::ToolError::ToolCallError(Box::new(e)))?;
5355

5456
Ok(convert_mcp_call_tool_result_to_string(call_mcp_tool_result))
5557
})
5658
}
5759
}
5860

61+
impl ToolEmbeddingDyn for McpToolAdaptor {
62+
fn context(&self) -> serde_json::Result<serde_json::Value> {
63+
serde_json::to_value(self.tool.clone())
64+
}
65+
66+
fn embedding_docs(&self) -> Vec<String> {
67+
vec![
68+
self.tool
69+
.description
70+
.as_deref()
71+
.unwrap_or_default()
72+
.to_string(),
73+
]
74+
}
75+
}
76+
5977
pub struct McpManager {
6078
pub clients: HashMap<String, RunningService<RoleClient, ()>>,
6179
}
@@ -72,7 +90,7 @@ impl McpManager {
7290
for result in results {
7391
match result {
7492
Err(e) => {
75-
eprintln!("Failed to get tool set: {:?}", e);
93+
tracing::error!(error = %e, "Failed to get tool set");
7694
}
7795
Ok(tools) => {
7896
tool_set.add_tools(tools);
@@ -89,14 +107,15 @@ pub fn convert_mcp_call_tool_result_to_string(result: CallToolResult) -> String
89107

90108
pub async fn get_tool_set(server: ServerSink) -> anyhow::Result<ToolSet> {
91109
let tools = server.list_all_tools().await?;
92-
let mut tool_set = ToolSet::default();
110+
let mut tool_builder = ToolSet::builder();
93111
for tool in tools {
94-
eprintln!("get tool: {}", tool.name);
112+
tracing::info!("get tool: {}", tool.name);
95113
let adaptor = McpToolAdaptor {
96114
tool: tool.clone(),
97115
server: server.clone(),
98116
};
99-
tool_set.add_tool(adaptor);
117+
tool_builder = tool_builder.dynamic_tool(adaptor);
100118
}
119+
let tool_set = tool_builder.build();
101120
Ok(tool_set)
102121
}

0 commit comments

Comments
 (0)