Commit 1df0d42

Merge pull request #704 from apecloud/feature/improve_knowledge_pipeline
feat: refactor knowledge_pipeline
2 parents: 3cb3dd9 + e8b125f

File tree: 1 file changed, +195 -118 lines


aperag/pipeline/knowledge_pipeline.py

Lines changed: 195 additions & 118 deletions
@@ -16,6 +16,7 @@
 import json
 import logging
 import random
+from typing import List, Tuple, Optional

 from langchain_core.prompts import PromptTemplate

@@ -112,147 +113,223 @@ async def filter_by_keywords(self, message, candidates):
             result.append(item)
         return result

-    async def run(self, message, gen_references=False, message_id=""):
-        log_prefix = f"{message_id}|{message}"
-        logger.info("[%s] start processing", log_prefix)
+    async def _run_standard_rag(self, query_with_history: str, vector: List[float], log_prefix: str) -> Tuple[str, List[DocumentWithScore]]:
+        """
+        Executes the standard RAG pipeline: vector search, rerank, keyword filtering, context packing.
+        Returns the packed context string and the list of candidate documents.
+        """
+        logger.info("[%s] Running standard RAG pipeline", log_prefix)
+        results = await async_run(self.context_manager.query, query_with_history,
+                                  score_threshold=self.score_threshold, topk=self.topk * 6, vector=vector)
+        logger.info("[%s] Found %d relevant documents in vector db", log_prefix, len(results))
+
+        if self.bot_context != "":
+            bot_context_result = DocumentWithScore(
+                text=self.bot_context,  # type: ignore
+                score=0,  # Use score 0 to easily identify and filter later if needed
+                metadata={}  # Add empty metadata
+            )
+            results.append(bot_context_result)
+
+        if len(results) > 1:
+            results = await rerank(query_with_history, results)  # Use query_with_history for reranking
+            logger.info("[%s] Reranked %d candidates", log_prefix, len(results))
+        else:
+            logger.info("[%s] No need to rerank (candidates <= 1)", log_prefix)

-        history = []
-        tot_history_querys = ''
-        messages = await self.history.messages
-        history_querys = [json.loads(message.content)["query"] for message in messages if message.additional_kwargs["role"] == "human"]
+        candidates = results[:self.topk]
+
+        if self.enable_keyword_recall:
+            candidates = await self.filter_by_keywords(query_with_history.split('\n')[-1], candidates)  # Use original message for keywords
+            logger.info("[%s] Filtered candidates by keyword, %d remaining", log_prefix, len(candidates))
+        else:
+            logger.info("[%s] Keyword filtering disabled", log_prefix)
+
+        context = ""
+        if len(candidates) > 0:
+            # 500 is the estimated length of the prompt and memory overhead
+            context_allowance = max(self.context_window - 500, 0)
+            context = get_packed_answer(candidates, context_allowance)
+            logger.info("[%s] Packed context generated (length: %d)", log_prefix, len(context))
+        else:
+            logger.info("[%s] No candidates found after filtering", log_prefix)

-        if self.memory:
-            tot_history_querys = '\n'.join(history_querys[-self.memory_limit_count:])+'\n'
+        return context, candidates

+    async def _run_light_rag(self, query_with_history: str, log_prefix: str) -> Optional[str]:
+        """
+        Placeholder for executing the LightRAG (GraphRAG) pipeline.
+        This function will be implemented in a future PR.
+        It should take the query and return the context string.
+        """
+        logger.info("[%s] Skipping LightRAG pipeline (placeholder)", log_prefix)
+        return None
+
+    async def run(self, message, gen_references=False, message_id=""):
+        log_prefix = f"{message_id}|{message}"
+        logger.info("[%s] Start processing request", log_prefix)
+
+        # --- 1. Common Setup & History Processing ---
+        response = ""
         references = []
         related_questions = set()
-        response = ""
         document_url_list = []
         document_url_set = set()
+        context = ""
+        candidates = []  # Keep track of candidates for references/URLs
+        related_question_task = None
         need_generate_answer = True
         need_related_question = True
-        vector = self.embedding_model.embed_query(tot_history_querys+message)
-        logger.info("[%s] embedding query end", log_prefix)
-        # hyde_task = asyncio.create_task(self.generate_hyde_message(message))
-
-        results = await async_run(self.qa_context_manager.query, tot_history_querys + message, score_threshold=0.5, topk=6, vector=vector)
-        logger.info("[%s] find relevant qa pairs in vector db end", log_prefix)
-        for result in results:
-            result_text = json.loads(result.text)
-            if result_text["answer"] != "" and result.score > 0.9:
-                response = result_text["answer"]
-            if result.score < 0.8:
-                related_questions.add(result_text["question"])

-        # if len(related_questions) >= 3:
-        #     need_related_question = False
-
-        if response != "":
-            yield response
-
-            if self.use_related_question and need_related_question:
-                related_question_prompt = self.related_question_prompt.format(query=message, context=response)
-                related_question_task = asyncio.create_task(self.generate_related_question(related_question_prompt))
-
-        else:
-            results = await async_run(self.context_manager.query, tot_history_querys + message,
-                                      score_threshold=self.score_threshold, topk=self.topk * 6, vector=vector)
-            logger.info("[%s] find top %d relevant context in vector db end", log_prefix, len(results))
-            # hyde_message = await hyde_task
-            # new_vector = self.embedding_model.embed_query(hyde_message)
-            # results2 = await async_run(self.context_manager.query, message,
-            #                            score_threshold=self.score_threshold, topk=self.topk * 6, vector=new_vector)
-            # results_set = set([result.text for result in results])
-            # results.extend(result for result in results2 if result.text not in results_set)
-
-            if self.bot_context != "":
-                bot_context_result = DocumentWithScore(
-                    text=self.bot_context,  # type: ignore
-                    score=0,
-                )
-                results.append(bot_context_result)
-
-            if len(results) > 1:
-                results = await rerank(message, results)
-                logger.info("[%s] rerank candidates end", log_prefix)
-            else:
-                logger.info("[%s] don't need to rerank ", log_prefix)
-
-            candidates = results[:self.topk]
-
-            if self.enable_keyword_recall:
-                candidates = await self.filter_by_keywords(message, candidates)
-                logger.info("[%s] filter keyword end", log_prefix)
-            else:
-                logger.info("[%s] no need to filter keyword", log_prefix)
-
-            context = ""
-            if len(candidates) > 0:
-                # 500 is the estimated length of the prompt
-                context = get_packed_answer(candidates, max(self.context_window - 500, 0))
-            else:
+        messages = await self.history.messages
+        history_querys = [json.loads(msg.content)["query"] for msg in messages if msg.additional_kwargs.get("role") == "human"]
+        tot_history_querys = '\n'.join(history_querys[-self.memory_limit_count:]) + '\n' if self.memory else ''
+        query_with_history = tot_history_querys + message
+
+        # --- 2. QA Cache Check (Optional Shortcut) ---
+        logger.info("[%s] Checking QA cache", log_prefix)
+        vector = self.embedding_model.embed_query(query_with_history)  # Embedding needed for QA cache and standard RAG
+        logger.info("[%s] Query embedded", log_prefix)
+        qa_results = await async_run(self.qa_context_manager.query, query_with_history, score_threshold=0.5, topk=6, vector=vector)
+        logger.info("[%s] QA cache query returned %d results", log_prefix, len(qa_results))
+
+        cached_answer_found = False
+        for result in qa_results:
+            try:
+                result_text = json.loads(result.text)
+                if result_text.get("answer") and result.score > 0.9:  # High confidence match
+                    response = result_text["answer"]
+                    context = response  # Use cached answer as context for related questions
+                    cached_answer_found = True
+                    need_generate_answer = False  # No need to call LLM
+                    logger.info("[%s] Found high-confidence answer in QA cache.", log_prefix)
+                    yield response  # Start yielding cached answer
+                    break  # Stop after finding one good answer
+                elif result.score >= 0.8:  # Add potential related questions from cache
+                    related_questions.add(result_text["question"])
+            except (json.JSONDecodeError, KeyError) as e:
+                logger.warning("[%s] Failed to parse QA cache result: %s, error: %s", log_prefix, result.text, e)
+
+
+        # --- 3. Main RAG Processing (if no QA cache hit) ---
+        if not cached_answer_found:
+            logger.info("[%s] No high-confidence answer in QA cache, proceeding with RAG pipeline", log_prefix)
+
+            # --- 3a. Choose and Run RAG method(s) ---
+            # For now, we run standard RAG. LightRAG is a placeholder.
+            # In the future, logic can be added here to choose, combine, or run in parallel.
+            standard_context, candidates = await self._run_standard_rag(query_with_history, vector, log_prefix)
+            # lightrag_context = await self._run_light_rag(query_with_history, log_prefix)
+
+            # --- 3b. Select Context ---
+            # For now, prioritize standard RAG context. Future logic could combine or select.
+            context = standard_context
+            # if lightrag_context:  # Example of prioritizing LightRAG if available
+            #     context = lightrag_context
+            #     candidates = []  # Clear candidates if LightRAG context is used (no direct candidates from it yet)
+
+            # --- 3c. Handle No Context Found ---
+            if not context:
                 if self.oops != "":
                     response = self.oops
                     yield self.oops
                     need_generate_answer = False
+                    logger.info("[%s] No context found, yielding 'oops' message.", log_prefix)
                 if self.welcome_question:
                     related_questions.update(self.welcome_question)
-                    if len(related_questions) >= 3:
-                        need_related_question = False
-
-            if self.use_related_question and need_related_question:
-                related_question_prompt = self.related_question_prompt.format(query=message, context=context)
-                related_question_task = asyncio.create_task(self.generate_related_question(related_question_prompt))
-
-        if need_generate_answer:
-            history_querys.append(message)
-            related_questions = related_questions - set(history_querys[-5:])
-            if self.memory and len(messages) > 0:
-                history = self.predictor.get_latest_history(
-                    messages=messages,
-                    limit_length=max(min(self.context_window - 500 - len(context), self.memory_limit_length), 0),
-                    limit_count=self.memory_limit_count,
-                    use_ai_memory=self.use_ai_memory)
-                self.memory_count = len(history)
-
-            prompt = self.prompt.format(query=message, context=context)
-            logger.info("[%s] final prompt is\n%s", log_prefix, prompt)
-
-            async for msg in self.predictor.agenerate_stream(history, prompt, self.memory):
-                yield msg
-                response += msg
-
-            for result in candidates:
-                # filter bot_context
-                if result.score == 0:
-                    continue
-                references.append({
-                    "score": result.score,
-                    "text": result.text,
-                    "metadata": result.metadata
-                })
-                url = result.metadata.get("url")
-                if url and url not in document_url_set:
-                    document_url_set.add(result.metadata.get("url"))
-                    document_url_list.append(result.metadata.get("url"))
-
+                    logger.info("[%s] Adding welcome questions as related questions.", log_prefix)
+
+
+        # --- 4. Generate Related Questions (if enabled) ---
+        if self.use_related_question and need_related_question:
+            # Only start the task if we have some context (either from cache or RAG) or no context but welcome questions
+            if context or (not context and self.welcome_question):
+                # Check if we already have enough related questions from QA cache or welcome questions
+                if len(related_questions) < 3:
+                    related_question_prompt_context = context if context else "No context found."  # Provide some context even if empty
+                    related_question_prompt = self.related_question_prompt.format(query=message, context=related_question_prompt_context)
+                    related_question_task = asyncio.create_task(self.generate_related_question(related_question_prompt))
+                    logger.info("[%s] Created related question generation task.", log_prefix)
+                else:
+                    logger.info("[%s] Skipping related question generation task (already have %d).", log_prefix, len(related_questions))
+            else:
+                logger.info("[%s] Skipping related question generation task (no context and no welcome questions).", log_prefix)
+
+
+        # --- 5. Generate LLM Answer (if needed) ---
+        if need_generate_answer:
+            logger.info("[%s] Generating LLM answer.", log_prefix)
+            history = []
+            if self.memory and len(messages) > 0:
+                history_context_allowance = max(min(self.context_window - 500 - len(context), self.memory_limit_length), 0)
+                history = self.predictor.get_latest_history(
+                    messages=messages,
+                    limit_length=history_context_allowance,
+                    limit_count=self.memory_limit_count,
+                    use_ai_memory=self.use_ai_memory)
+                self.memory_count = len(history)
+                logger.info("[%s] Prepared %d history entries for LLM.", log_prefix, len(history))
+
+            prompt = self.prompt.format(query=message, context=context)
+            logger.debug("[%s] Final prompt for LLM:\n%s", log_prefix, prompt)  # Use debug level for potentially long prompts
+
+            async for msg_chunk in self.predictor.agenerate_stream(history, prompt, self.memory):
+                yield msg_chunk
+                response += msg_chunk
+            logger.info("[%s] LLM stream finished.", log_prefix)
+
+        # Populate references and URLs from the candidates used for the context
+        for result in candidates:
+            # Filter out bot_context placeholder if it exists and wasn't filtered earlier
+            if result.score == 0 and result.text == self.bot_context:
+                continue
+            references.append({
+                "score": result.score,
+                "text": result.text,
+                "metadata": result.metadata
+            })
+            url = result.metadata.get("url")
+            if url and url not in document_url_set:
+                document_url_set.add(url)
+                document_url_list.append(url)
+
+        # --- 6. Finalization: Save Messages & Yield Metadata ---
         await self.add_human_message(message, message_id)
-        logger.info("[%s] add human message end", log_prefix)
+        logger.info("[%s] Human message saved.", log_prefix)

+        # Ensure AI message includes references/URLs derived from the context used
         await self.add_ai_message(message, message_id, response, references, document_url_list)
-        logger.info("[%s] add ai message end and the pipeline is succeed", log_prefix)
+        logger.info("[%s] AI message saved.", log_prefix)

+        # Yield related questions if generated/collected
         if self.use_related_question:
-            if need_related_question:
-                related_question_generate = await related_question_task
-                related_questions.update(related_question_generate)
-            related_questions = list(related_questions)
-            random.shuffle(related_questions)
-            yield RELATED_QUESTIONS + str(related_questions[:3])
-
-        if gen_references:
+            final_related_questions = list(related_questions)
+            if related_question_task:
+                try:
+                    generated_questions = await related_question_task
+                    logger.info("[%s] Related question generation task finished.", log_prefix)
+                    # Avoid duplicates and filter out recent history
+                    history_querys.append(message)  # Add current message to history for filtering
+                    recent_queries = set(history_querys[-5:])
+                    for q in generated_questions:
+                        if q not in final_related_questions and q not in recent_queries:
+                            final_related_questions.append(q)
+                except Exception as e:
+                    logger.error("[%s] Related question generation failed: %s", log_prefix, e)
+
+            if final_related_questions:
+                random.shuffle(final_related_questions)
+                yield RELATED_QUESTIONS + json.dumps(final_related_questions[:3])
+                logger.info("[%s] Yielded related questions.", log_prefix)
+
+        # Yield references if requested
+        if gen_references and references:
             yield DOC_QA_REFERENCES + json.dumps(references)
+            logger.info("[%s] Yielded references.", log_prefix)

+        # Yield document URLs if available
         if document_url_list:
             yield DOCUMENT_URLS + json.dumps(document_url_list)
+            logger.info("[%s] Yielded document URLs.", log_prefix)

+        logger.info("[%s] Processing finished successfully.", log_prefix)

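For orientation on the streaming contract: run() is an async generator that first yields answer text (cached or LLM-generated), then optionally yields related questions, references, and document URLs as JSON strings behind the RELATED_QUESTIONS, DOC_QA_REFERENCES, and DOCUMENT_URLS prefixes. Below is a minimal, hypothetical consumer sketch, not part of this commit: the pipeline instance, its construction, and the message_id value are assumptions, and the prefix constants are assumed importable from the module in this diff since it references them at module scope.

import json

# Assumed import path: the constants are referenced in knowledge_pipeline.py,
# so they should resolve as attributes of that module.
from aperag.pipeline.knowledge_pipeline import (
    DOC_QA_REFERENCES,
    DOCUMENT_URLS,
    RELATED_QUESTIONS,
)

async def consume(pipeline, question: str):
    # `pipeline` is assumed to be an already-configured instance exposing
    # the run() async generator shown in this diff.
    answer_parts, related, refs, urls = [], [], [], []
    async for chunk in pipeline.run(question, gen_references=True, message_id="demo-1"):
        if chunk.startswith(RELATED_QUESTIONS):
            related = json.loads(chunk[len(RELATED_QUESTIONS):])  # up to 3 shuffled questions
        elif chunk.startswith(DOC_QA_REFERENCES):
            refs = json.loads(chunk[len(DOC_QA_REFERENCES):])  # [{"score", "text", "metadata"}, ...]
        elif chunk.startswith(DOCUMENT_URLS):
            urls = json.loads(chunk[len(DOCUMENT_URLS):])  # deduplicated source URLs
        else:
            answer_parts.append(chunk)  # streamed answer text
    return "".join(answer_parts), related, refs, urls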
0 commit comments
