Skip to content

Commit b4d48a6

Browse files
authored
chore: support multiple source search in chat (#7906)
* chore: support multiple source search in chat * chore: clippy
1 parent c986e41 commit b4d48a6

File tree

22 files changed

+642
-248
lines changed

22 files changed

+642
-248
lines changed

frontend/rust-lib/flowy-ai/src/ai_manager.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,6 @@ impl AIManager {
217217
let summary = select_chat_summary(&mut conn, chat_id).unwrap_or_default();
218218

219219
let model = self.get_active_model(&chat_id.to_string()).await;
220-
trace!("[AI Plugin] notify open chat: {}", chat_id);
221220
self
222221
.local_ai
223222
.open_chat(&workspace_id, chat_id, &model.name, rag_ids, summary)
@@ -240,7 +239,7 @@ impl AIManager {
240239
.await
241240
{
242241
Ok(settings) => {
243-
local_ai.set_rag_ids(&chat_id, &settings.rag_ids);
242+
local_ai.set_rag_ids(&chat_id, &settings.rag_ids).await;
244243
let rag_ids = settings
245244
.rag_ids
246245
.into_iter()
@@ -712,7 +711,7 @@ impl AIManager {
712711

713712
let user_service = self.user_service.clone();
714713
let external_service = self.external_service.clone();
715-
self.local_ai.set_rag_ids(chat_id, &rag_ids);
714+
self.local_ai.set_rag_ids(chat_id, &rag_ids).await;
716715

717716
let rag_ids = rag_ids
718717
.into_iter()

frontend/rust-lib/flowy-ai/src/embeddings/store.rs

Lines changed: 86 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::embeddings::document_indexer::split_text_into_chunks;
22
use crate::embeddings::embedder::{Embedder, OllamaEmbedder};
33
use crate::embeddings::indexer::{EmbeddingModel, IndexerProvider};
4+
use crate::local_ai::chat::retriever::MultipleSourceRetrieverStore;
45
use async_trait::async_trait;
56
use flowy_ai_pub::cloud::CollabType;
67
use flowy_ai_pub::entities::{RAG_IDS, SOURCE_ID};
@@ -9,10 +10,8 @@ use flowy_sqlite_vec::db::VectorSqliteDB;
910
use flowy_sqlite_vec::entities::{EmbeddedContent, SqliteEmbeddedDocument};
1011
use futures::stream::{self, StreamExt};
1112
use langchain_rust::llm::client::OllamaClient;
12-
use langchain_rust::{
13-
schemas::Document,
14-
vectorstore::{VecStoreOptions, VectorStore},
15-
};
13+
use langchain_rust::schemas::Document;
14+
use langchain_rust::vectorstore::{VecStoreOptions, VectorStore};
1615
use ollama_rs::generation::embeddings::request::{EmbeddingsInput, GenerateEmbeddingsRequest};
1716
use serde_json::Value;
1817
use std::collections::HashMap;
@@ -86,6 +85,80 @@ impl SqliteVectorStore {
8685
}
8786
}
8887

88+
#[async_trait]
89+
impl MultipleSourceRetrieverStore for SqliteVectorStore {
90+
fn retriever_name(&self) -> &'static str {
91+
"Sqlite Multiple Source Retriever"
92+
}
93+
94+
async fn read_documents(
95+
&self,
96+
workspace_id: &Uuid,
97+
query: &str,
98+
limit: usize,
99+
rag_ids: &[String],
100+
score_threshold: f32,
101+
_full_search: bool,
102+
) -> FlowyResult<Vec<Document>> {
103+
let vector_db = match self.vector_db.upgrade() {
104+
Some(db) => db,
105+
None => return Err(FlowyError::internal().with_context("Vector database not initialized")),
106+
};
107+
108+
// Create embedder and generate embedding for query
109+
let embedder = self.create_embedder()?;
110+
let request = GenerateEmbeddingsRequest::new(
111+
embedder.model().name().to_string(),
112+
EmbeddingsInput::Single(query.to_string()),
113+
);
114+
115+
let embedding = embedder.embed(request).await?.embeddings;
116+
if embedding.is_empty() {
117+
return Ok(Vec::new());
118+
}
119+
120+
debug_assert!(embedding.len() == 1);
121+
let query_embedding = embedding.first().unwrap();
122+
123+
// Perform similarity search in the database
124+
let results = vector_db
125+
.search_with_score(
126+
&workspace_id.to_string(),
127+
rag_ids,
128+
query_embedding,
129+
limit as i32,
130+
score_threshold,
131+
)
132+
.await?;
133+
134+
trace!(
135+
"[VectorStore] Found {} results for query:{}, rag_ids: {:?}, score_threshold: {}",
136+
results.len(),
137+
query,
138+
rag_ids,
139+
score_threshold
140+
);
141+
142+
// Convert results to Documents
143+
let documents = results
144+
.into_iter()
145+
.map(|result| {
146+
let mut metadata = HashMap::new();
147+
148+
if let Some(map) = result.metadata.as_ref().and_then(|v| v.as_object()) {
149+
for (key, value) in map {
150+
metadata.insert(key.clone(), value.clone());
151+
}
152+
}
153+
154+
Document::new(result.content).with_metadata(metadata)
155+
})
156+
.collect();
157+
158+
Ok(documents)
159+
}
160+
}
161+
89162
#[async_trait]
90163
impl VectorStore for SqliteVectorStore {
91164
type Options = VecStoreOptions<Value>;
@@ -215,74 +288,23 @@ impl VectorStore for SqliteVectorStore {
215288

216289
// Return empty result if workspace_id is missing
217290
let workspace_id = match workspace_id {
218-
Some(id) => id.to_string(),
291+
Some(id) => id,
219292
None => {
220293
warn!("[VectorStore] Missing workspace_id in filters. Returning empty result.");
221294
return Ok(Vec::new());
222295
},
223296
};
224297

225-
// Get the vector database
226-
let vector_db = match self.vector_db.upgrade() {
227-
Some(db) => db,
228-
None => return Err("Vector database not initialized".into()),
229-
};
230-
231-
// Create embedder and generate embedding for query
232-
let embedder = self.create_embedder()?;
233-
let request = GenerateEmbeddingsRequest::new(
234-
embedder.model().name().to_string(),
235-
EmbeddingsInput::Single(query.to_string()),
236-
);
237-
238-
let embedding = match embedder.embed(request).await {
239-
Ok(result) => result.embeddings,
240-
Err(e) => return Err(Box::new(e)),
241-
};
242-
243-
if embedding.is_empty() {
244-
return Ok(Vec::new());
245-
}
246-
247-
let score_threshold = opt.score_threshold.unwrap_or(0.4);
248-
debug_assert!(embedding.len() == 1);
249-
let query_embedding = embedding.first().unwrap();
250-
251-
// Perform similarity search in the database
252-
let results = vector_db
253-
.search_with_score(
298+
self
299+
.read_documents(
254300
&workspace_id,
301+
query,
302+
limit,
255303
&rag_ids,
256-
query_embedding,
257-
limit as i32,
258-
score_threshold,
304+
opt.score_threshold.unwrap_or(0.4),
305+
true,
259306
)
260-
.await?;
261-
262-
trace!(
263-
"[VectorStore] Found {} results for query:{}, rag_ids: {:?}, score_threshold: {}",
264-
results.len(),
265-
query,
266-
rag_ids,
267-
score_threshold
268-
);
269-
270-
// Convert results to Documents
271-
let documents = results
272-
.into_iter()
273-
.map(|result| {
274-
let mut metadata = HashMap::new();
275-
276-
if let Some(map) = result.metadata.as_ref().and_then(|v| v.as_object()) {
277-
for (key, value) in map {
278-
metadata.insert(key.clone(), value.clone());
279-
}
280-
}
281-
282-
Document::new(result.content).with_metadata(metadata)
283-
})
284-
.collect();
285-
286-
Ok(documents)
307+
.await
308+
.map_err(|err| Box::new(err) as Box<dyn Error>)
287309
}
288310
}

frontend/rust-lib/flowy-ai/src/local_ai/chat/chains/context_question_chain.rs

Lines changed: 48 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::local_ai::chat::llm::LLMOllama;
22
use crate::SqliteVectorStore;
33
use flowy_error::{FlowyError, FlowyResult};
4+
use flowy_sqlite_vec::entities::EmbeddedContent;
45
use langchain_rust::language_models::llm::LLM;
56
use langchain_rust::prompt::TemplateFormat;
67
use langchain_rust::prompt::{PromptFromatter, PromptTemplate};
@@ -10,6 +11,7 @@ use ollama_rs::generation::parameters::{FormatType, JsonStructure};
1011
use schemars::JsonSchema;
1112
use serde::{Deserialize, Serialize};
1213
use serde_json::json;
14+
use std::fmt::Debug;
1315
use tracing::trace;
1416
use uuid::Uuid;
1517

@@ -60,13 +62,13 @@ pub struct ContextQuestion {
6062
pub object_id: String,
6163
}
6264

63-
pub struct RelatedQuestionChain {
65+
pub struct ContextRelatedQuestionChain {
6466
workspace_id: Uuid,
6567
llm: LLMOllama,
6668
store: SqliteVectorStore,
6769
}
6870

69-
impl RelatedQuestionChain {
71+
impl ContextRelatedQuestionChain {
7072
pub fn new(workspace_id: Uuid, ollama: LLMOllama, store: SqliteVectorStore) -> Self {
7173
let format = FormatType::StructuredJson(JsonStructure::new::<ContextQuestionsResponse>());
7274
Self {
@@ -76,25 +78,16 @@ impl RelatedQuestionChain {
7678
}
7779
}
7880

79-
pub async fn generate_questions(&self, rag_ids: &[String]) -> FlowyResult<Vec<ContextQuestion>> {
80-
trace!(
81-
"[embedding] Generating context related questions for RAG IDs: {:?}",
82-
rag_ids
83-
);
84-
85-
let context = self
86-
.store
87-
.select_all_embedded_content(&self.workspace_id.to_string(), rag_ids, 3)
88-
.await?;
89-
90-
trace!(
91-
"[embedding] Generating related questions base on: {:?}",
92-
context,
93-
);
94-
95-
let context_str = json!(context).to_string();
81+
pub async fn generate_questions_from_context<T>(
82+
&self,
83+
rag_ids: &[T],
84+
context: &str,
85+
) -> FlowyResult<Vec<ContextQuestion>>
86+
where
87+
T: AsRef<str>,
88+
{
9689
let input_variables = prompt_args! {
97-
"context" => context_str,
90+
"context" => context,
9891
};
9992

10093
let template = PromptTemplate::new(
@@ -116,8 +109,42 @@ impl RelatedQuestionChain {
116109
// filter out questions that are not in the rag_ids
117110
parsed_result
118111
.questions
119-
.retain(|v| rag_ids.contains(&v.object_id));
112+
.retain(|v| rag_ids.iter().any(|id| id.as_ref() == v.object_id));
120113

121114
Ok(parsed_result.questions)
122115
}
116+
117+
pub async fn generate_questions<T>(
118+
&self,
119+
rag_ids: &[T],
120+
) -> FlowyResult<(String, Vec<ContextQuestion>)>
121+
where
122+
T: AsRef<str> + Debug,
123+
{
124+
trace!(
125+
"[embedding] Generating context related questions for RAG IDs: {:?}",
126+
rag_ids
127+
);
128+
129+
let rag_ids_str: Vec<String> = rag_ids.iter().map(|id| id.as_ref().to_string()).collect();
130+
let context = self
131+
.store
132+
.select_all_embedded_content(&self.workspace_id.to_string(), &rag_ids_str, 3)
133+
.await?;
134+
135+
trace!(
136+
"[embedding] Generating related questions base on: {:?}",
137+
context,
138+
);
139+
140+
let context_str = embedded_documents_to_context_str(context);
141+
self
142+
.generate_questions_from_context(rag_ids, &context_str)
143+
.await
144+
.map(|questions| (context_str, questions))
145+
}
146+
}
147+
148+
pub fn embedded_documents_to_context_str(documents: Vec<EmbeddedContent>) -> String {
149+
json!(documents).to_string()
123150
}

0 commit comments

Comments
 (0)