Skip to content

Commit 10b9573

Browse files
author
Arjun Balaji
committed
RAG-MCP fully working with accumulation based on recommendations instead of history
1 parent fab02d4 commit 10b9573

File tree

1 file changed

+114
-99
lines changed

1 file changed

+114
-99
lines changed

crates/chat-cli/src/cli/chat/conversation.rs

Lines changed: 114 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,7 @@ impl ConversationState {
332332
for tool in tool_list {
333333
let Tool::ToolSpecification(spec) = tool;
334334
tool_docs.push_str(&format!(
335-
"TOOL_START\nname:{}\nserver:{}\ndescription:{}\nschema:{}\nTOOL_END\n\n",
335+
"TOOL_START|||name:{}|||server:{}|||description:{}|||schema:{}|||TOOL_END\n\n",
336336
spec.name,
337337
server_name,
338338
spec.description,
@@ -384,49 +384,41 @@ impl ConversationState {
384384
return enhanced_tools;
385385
}
386386

387-
let lines: Vec<&str> = rag_context.lines().collect();
388-
let mut tool_name = String::new();
389-
let mut server_name = String::new();
390-
let mut description = String::new();
391-
let mut schema_json = String::new();
392-
393-
for line in lines {
387+
// Split by lines and process each tool entry
388+
for line in rag_context.lines() {
394389
let line = line.trim();
390+
if line.is_empty() {
391+
continue;
392+
}
395393

396-
if line.starts_with("name:") {
397-
// If we encounter a new name and we already have tool data, save the current tool
398-
if !tool_name.is_empty() {
399-
self.add_tool_to_map(
400-
&mut enhanced_tools,
401-
&tool_name,
402-
&server_name,
403-
&description,
404-
&schema_json,
405-
);
394+
// Split by ||| and process each tool
395+
let parts: Vec<&str> = line.split("|||").collect();
396+
if parts.len() >= 4 {
397+
let mut tool_name = "";
398+
let mut server_name = "";
399+
let mut description = "";
400+
let mut schema_json = "";
401+
402+
// Parse each part
403+
for part in &parts {
404+
let part = part.trim();
405+
if let Some(name) = part.strip_prefix("name:") {
406+
tool_name = name.trim();
407+
} else if let Some(server) = part.strip_prefix("server:") {
408+
server_name = server.trim();
409+
} else if let Some(desc) = part.strip_prefix("description:") {
410+
description = desc.trim();
411+
} else if let Some(schema) = part.strip_prefix("schema:") {
412+
schema_json = schema.trim();
413+
}
406414
}
407415

408-
// Start new tool
409-
tool_name = line.strip_prefix("name:").unwrap_or("").trim().to_string();
410-
server_name.clear();
411-
description.clear();
412-
schema_json.clear();
413-
} else if line.starts_with("server:") {
414-
server_name = line.strip_prefix("server:").unwrap_or("").trim().to_string();
415-
} else if line.starts_with("description:") {
416-
description = line.strip_prefix("description:").unwrap_or("").trim().to_string();
417-
} else if line.starts_with("schema:") {
418-
schema_json = line.strip_prefix("schema:").unwrap_or("").trim().to_string();
416+
if !tool_name.is_empty() && !server_name.is_empty() {
417+
self.add_tool_to_map(&mut enhanced_tools, tool_name, server_name, description, schema_json);
418+
}
419419
}
420420
}
421421

422-
self.add_tool_to_map(
423-
&mut enhanced_tools,
424-
&tool_name,
425-
&server_name,
426-
&description,
427-
&schema_json,
428-
);
429-
430422
enhanced_tools
431423
}
432424

@@ -522,28 +514,30 @@ impl ConversationState {
522514
.or_insert(vec![tool]);
523515
acc
524516
});
525-
let knowledge_store = KnowledgeStore::get_async_instance().await;
526-
if let Ok(mut store) = knowledge_store.try_lock() {
527-
let rag_mcp_uuid = KnowledgeStore::initialize_reserved_knowledge_bases();
528-
match self.get_mcp_tool_descriptions(rag_mcp_uuid).await {
529-
Ok(docs_file_path) => {
530-
match store
531-
.sync_mcp_tools_knowledge_base(&docs_file_path, Some(rag_mcp_uuid))
532-
.await
533-
{
534-
Ok(_) => {
535-
// Success - clean up temp file
536-
// std::fs::remove_file(&docs_file_path).ok();
537-
},
538-
Err(e) => {
539-
warn!("Failed to sync MCP tools knowledge base: {}", e);
540-
std::fs::remove_file(&docs_file_path).ok();
541-
},
542-
}
543-
},
544-
Err(e) => {
545-
warn!("Failed to create MCP tools documentation: {}", e);
546-
},
517+
if needs_update {
518+
let knowledge_store = KnowledgeStore::get_async_instance().await;
519+
if let Ok(mut store) = knowledge_store.try_lock() {
520+
let rag_mcp_uuid = KnowledgeStore::initialize_reserved_knowledge_bases();
521+
match self.get_mcp_tool_descriptions(rag_mcp_uuid).await {
522+
Ok(docs_file_path) => {
523+
match store
524+
.sync_mcp_tools_knowledge_base(&docs_file_path, Some(rag_mcp_uuid))
525+
.await
526+
{
527+
Ok(_) => {
528+
// Success - clean up temp file
529+
// std::fs::remove_file(&docs_file_path).ok();
530+
},
531+
Err(e) => {
532+
warn!("Failed to sync MCP tools knowledge base: {}", e);
533+
std::fs::remove_file(&docs_file_path).ok();
534+
},
535+
}
536+
},
537+
Err(e) => {
538+
warn!("Failed to create MCP tools documentation: {}", e);
539+
},
540+
}
547541
}
548542
}
549543
self.tool_manager.has_new_stuff.store(false, Ordering::Release);
@@ -561,6 +555,7 @@ impl ConversationState {
561555
run_perprompt_hooks: bool,
562556
output: &mut impl Write,
563557
) -> Result<BackendConversationState<'_>, ChatError> {
558+
let needs_update = self.tool_manager.has_new_stuff.load(Ordering::Acquire);
564559
self.update_state(false).await;
565560
self.enforce_conversation_invariants();
566561

@@ -583,48 +578,15 @@ impl ConversationState {
583578

584579
let (context_messages, dropped_context_files) = self.context_messages(os, agent_spawn_context).await;
585580

586-
// add top k mcp tools to description consisiting of just native
587-
// ensures not to remove MCP tool descriptions that have been previously used
588-
if self.enhanced_tools_cache.is_none() {
589-
let mut filtered_tools = self.tools.clone();
590-
filtered_tools.retain(|origin, _| matches!(origin, ToolOrigin::Native));
591-
self.enhanced_tools_cache = Some(filtered_tools);
581+
// Initialize or update enhanced tools cache
582+
if self.enhanced_tools_cache.is_none() || needs_update {
583+
self.initialize_enhanced_tools_cache();
592584
}
593585

594-
let mut mcp_tool_hashmap = self.enhanced_tools_cache.clone().unwrap();
595-
let enhanced_tools = match self.next_message.as_ref().and_then(|msg| msg.prompt()) {
596-
Some(user_query) => {
597-
// gets back top k relevant tools in string format
598-
let mcp_tools_string = self
599-
.build_rag_enhanced_context(os, &user_query.to_string())
600-
.await
601-
.unwrap_or_default();
602-
603-
// converts the strings into usable self.tools hashmap format
604-
// TODO: add optimization --> only persist the tool description if it was actually used.
605-
let rag_mcp_tools = self.convert_rag_to_tools(&mcp_tools_string);
606-
607-
if let Some((origin, new_tools)) = rag_mcp_tools.into_iter().next() {
608-
let existing_tools = mcp_tool_hashmap.entry(origin).or_insert_with(Vec::new);
609-
610-
// Only add tools that don't already exist
611-
for new_tool in new_tools {
612-
let Tool::ToolSpecification(new_spec) = &new_tool;
613-
let already_exists = existing_tools.iter().any(|existing_tool| {
614-
matches!(existing_tool, Tool::ToolSpecification(existing_spec) if existing_spec.name == new_spec.name)
615-
});
616-
617-
if !already_exists {
618-
existing_tools.push(new_tool);
619-
}
620-
}
621-
}
622-
623-
mcp_tool_hashmap
624-
},
625-
None => mcp_tool_hashmap,
626-
};
627-
self.enhanced_tools_cache = Some(enhanced_tools);
586+
// Enhance tools with RAG if we have a user query
587+
if let Some(user_query) = self.next_message.as_ref().and_then(|msg| msg.prompt()) {
588+
self.enhance_tools_with_rag(os, &user_query.to_string()).await?;
589+
}
628590

629591
Ok(BackendConversationState {
630592
conversation_id: self.conversation_id.as_str(),
@@ -639,6 +601,59 @@ impl ConversationState {
639601
})
640602
}
641603

604+
fn initialize_enhanced_tools_cache(&mut self) {
605+
let mut filtered_tools = HashMap::new();
606+
607+
// Add all native tools
608+
for (origin, tools) in &self.tools {
609+
if matches!(origin, ToolOrigin::Native) {
610+
filtered_tools.insert(origin.clone(), tools.clone());
611+
}
612+
}
613+
614+
self.enhanced_tools_cache = Some(filtered_tools);
615+
}
616+
617+
// Helper method to enhance tools with RAG, modifying the cache in place
618+
async fn enhance_tools_with_rag(&mut self, os: &Os, user_query: &str) -> Result<(), ChatError> {
619+
// Get RAG MCP tools in string form
620+
let mcp_tools_string = self
621+
.build_rag_enhanced_context(os, user_query)
622+
.await
623+
.unwrap_or_default();
624+
625+
// convert string to usable hashmap form
626+
if !mcp_tools_string.is_empty() {
627+
let rag_mcp_tools = self.convert_rag_to_tools(&mcp_tools_string);
628+
629+
// Build a global set of already cached tool names for deduplication
630+
let mut existing_tool_names: std::collections::HashSet<String> = std::collections::HashSet::new();
631+
if let Some(ref cache) = self.enhanced_tools_cache {
632+
for tools in cache.values() {
633+
for tool in tools {
634+
let Tool::ToolSpecification(spec) = tool;
635+
existing_tool_names.insert(spec.name.clone());
636+
}
637+
}
638+
}
639+
640+
let cache = self.enhanced_tools_cache.as_mut().unwrap();
641+
642+
for (origin, new_tools) in rag_mcp_tools {
643+
let existing_tools = cache.entry(origin).or_insert_with(Vec::new);
644+
645+
for new_tool in new_tools {
646+
let Tool::ToolSpecification(new_spec) = &new_tool;
647+
// Only add if we haven't seen this tool name anywhere in the cache
648+
if existing_tool_names.insert(new_spec.name.clone()) {
649+
existing_tools.push(new_tool);
650+
}
651+
}
652+
}
653+
}
654+
Ok(())
655+
}
656+
642657
/// Returns a [FigConversationState] capable of replacing the history of the current
643658
/// conversation with a summary generated by the model.
644659
///

0 commit comments

Comments
 (0)