|
| 1 | +from .configuration import Configuration |
| 2 | +from langchain_core.runnables import RunnableConfig |
| 3 | +from .state import State |
| 4 | +from typing import Any, Dict |
| 5 | +from app.utils.connector_service import ConnectorService |
| 6 | +from app.utils.reranker_service import RerankerService |
| 7 | +from app.config import config as app_config |
| 8 | +from .prompts import citation_system_prompt |
| 9 | +from langchain_core.messages import HumanMessage, SystemMessage |
| 10 | + |
async def fetch_relevant_documents(state: State, config: RunnableConfig) -> Dict[str, Any]:
    """
    Fetch relevant documents for the sub-section using specified connectors.

    For every sub-question derived from the sub-section title, each selected
    connector (YouTube, Extension, Crawled URLs, Files, Tavily API, Slack,
    Notion) is searched. The combined results are deduplicated and, when a
    reranker is configured, reranked against the sub-section title so the
    most relevant chunks come first.

    Args:
        state: Graph state carrying the database session.
        config: Runnable config carrying user id, search space id, top_k,
            the sub-questions, and the connectors to search.

    Returns:
        Dict containing the deduplicated (and possibly reranked) documents
        in the "relevant_documents_fetched" key.
    """
    # Get configuration
    configuration = Configuration.from_runnable_config(config)

    # Extract state parameters
    db_session = state.db_session

    # Extract config parameters
    user_id = configuration.user_id
    search_space_id = configuration.search_space_id
    top_k = configuration.top_k

    # Initialize services
    connector_service = ConnectorService(db_session)
    reranker_service = RerankerService.get_reranker_instance(app_config)

    # Dispatch table: connector identifier -> (search coroutine, whether the
    # search is scoped to a search space). TAVILY_API is an external web
    # search and does not take a search_space_id.
    connector_dispatch = {
        "YOUTUBE_VIDEO": (connector_service.search_youtube, True),
        "EXTENSION": (connector_service.search_extension, True),
        "CRAWLED_URL": (connector_service.search_crawled_urls, True),
        "FILE": (connector_service.search_files, True),
        "TAVILY_API": (connector_service.search_tavily, False),
        "SLACK_CONNECTOR": (connector_service.search_slack, True),
        "NOTION_CONNECTOR": (connector_service.search_notion, True),
    }

    all_raw_documents = []  # Store all raw documents before reranking

    for user_query in configuration.sub_questions:
        # Reformulate query (optional, consider if needed for each sub-question)
        # reformulated_query = await QueryService.reformulate_query(user_query)
        reformulated_query = user_query  # Using original sub-question for now

        # Process each selected connector
        for connector in configuration.connectors_to_search:
            dispatch_entry = connector_dispatch.get(connector)
            if dispatch_entry is None:
                # Unknown connector identifiers are silently skipped, matching
                # the behavior of an if/elif chain with no final else.
                continue

            search_method, scoped_to_search_space = dispatch_entry
            search_kwargs = {
                "user_query": reformulated_query,
                "user_id": user_id,
                "top_k": top_k,
            }
            if scoped_to_search_space:
                search_kwargs["search_space_id"] = search_space_id

            # Each search returns (summary, chunks); only the chunks are kept.
            _, chunks = await search_method(**search_kwargs)
            all_raw_documents.extend(chunks)

    # Deduplicate documents based on chunk_id or content hash so the reranker
    # does not score the same text twice.
    seen_chunk_ids = set()
    seen_content_hashes = set()
    deduplicated_docs = []

    for doc in all_raw_documents:
        chunk_id = doc.get("chunk_id")
        content_hash = hash(doc.get("content", ""))

        # Skip if we've seen this chunk_id or content before
        if (chunk_id and chunk_id in seen_chunk_ids) or content_hash in seen_content_hashes:
            continue

        # Add to our tracking sets and keep this document
        if chunk_id:
            seen_chunk_ids.add(chunk_id)
        seen_content_hashes.add(content_hash)
        deduplicated_docs.append(doc)

    # If a reranker is available, rerank the deduplicated documents; otherwise
    # return them as-is (in retrieval order).
    reranked_docs = deduplicated_docs
    if deduplicated_docs and reranker_service:
        # Use the main sub_section_title for reranking context so scoring
        # reflects the section as a whole rather than one sub-question.
        rerank_query = configuration.sub_section_title

        # Convert documents to the format expected by the reranker.
        reranker_input_docs = [
            {
                "chunk_id": doc.get("chunk_id", f"chunk_{i}"),
                "content": doc.get("content", ""),
                "score": doc.get("score", 0.0),
                "document": {
                    "id": doc.get("document", {}).get("id", ""),
                    "title": doc.get("document", {}).get("title", ""),
                    "document_type": doc.get("document", {}).get("document_type", ""),
                    "metadata": doc.get("document", {}).get("metadata", {}),
                },
            }
            for i, doc in enumerate(deduplicated_docs)
        ]

        # Rerank documents using the main title query
        reranked_docs = reranker_service.rerank_documents(rerank_query, reranker_input_docs)

        # Sort by score in descending order (highest relevance first).
        reranked_docs.sort(key=lambda x: x.get("score", 0), reverse=True)

    # Update state with fetched documents
    return {
        "relevant_documents_fetched": reranked_docs
    }
| 161 | + |
| 162 | + |
| 163 | + |
async def write_sub_section(state: State, config: RunnableConfig) -> Dict[str, Any]:
    """
    Write the sub-section using the fetched documents.

    This node takes the relevant documents fetched in the previous node and
    uses an LLM to generate a comprehensive answer to the sub-section
    questions with proper citations. Citations follow IEEE format using each
    document's source_id.

    Args:
        state: Graph state carrying "relevant_documents_fetched".
        config: Runnable config carrying the sub-questions.

    Returns:
        Dict containing the final answer in the "final_answer" key.
    """

    # Get configuration and relevant documents
    configuration = Configuration.from_runnable_config(config)
    documents = state.relevant_documents_fetched

    # Initialize LLM
    llm = app_config.fast_llm_instance

    # Bail out early when there is nothing to cite (empty containers and
    # None are both falsy, so a separate length check is unnecessary).
    if not documents:
        return {
            "final_answer": "No relevant documents were found to answer this question. Please try refining your search or providing more specific questions."
        }

    # Prepare documents for citation formatting
    formatted_documents = []
    for i, doc in enumerate(documents):
        # Extract content and metadata
        content = doc.get("content", "")
        doc_info = doc.get("document", {})
        document_id = doc_info.get("id", f"{i+1}")  # Use document ID or index+1 as source_id

        # Format document according to the citation system prompt's expected format
        formatted_doc = f"""
    <document>
        <metadata>
            <source_id>{document_id}</source_id>
        </metadata>
        <content>
            {content}
        </content>
    </document>
    """
        formatted_documents.append(formatted_doc)

    # Create the query that combines the section title and questions
    # section_title = configuration.sub_section_title
    questions = "\n".join([f"- {q}" for q in configuration.sub_questions])
    documents_text = "\n".join(formatted_documents)

    # Construct a clear, structured query for the LLM
    human_message_content = f"""
    Please write a comprehensive answer for the title:

    Address the following questions:
    <questions>
    {questions}
    </questions>

    Use the provided documents as your source material and cite them properly using the IEEE citation format [X] where X is the source_id.
    <documents>
    {documents_text}
    </documents>
    """

    # Create messages for the LLM: the citation rules as the system prompt,
    # the questions plus source documents as the human turn.
    messages = [
        SystemMessage(content=citation_system_prompt),
        HumanMessage(content=human_message_content)
    ]

    # Call the LLM and get the response
    response = await llm.ainvoke(messages)
    final_answer = response.content

    return {
        "final_answer": final_answer
    }
| 244 | + |
0 commit comments