@@ -2,6 +2,7 @@
 import logging
 from dotenv import load_dotenv
 
+from typing import Any
 from fastapi import APIRouter
 from langchain_google_vertexai import ChatVertexAI
 from langchain_google_genai import ChatGoogleGenerativeAI
@@ -88,6 +89,145 @@
 
 router = APIRouter(prefix="/graphs", tags=["graphs"])
 
+
+def extract_rag_context_sources(output: list) -> list[ContextSource]:
+    """Extract context sources from RAG agent output."""
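+    # Assumed output shape, inferred from the keys read below; the exact
+    # "retrieve_*" node names depend on the RAG graph and are not pinned here:
+    # [
+    #     {"classify": {...}},
+    #     {"retrieve_<source>": {"urls": [...], "context": "..."}},
+    #     ...,
+    #     {"rag_generate": {"messages": [...]}},
+    # ]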
+    context_sources = []
+    for element in output[1:-1]:  # Skip first (classify) and last (generate) nodes
+        if isinstance(element, dict):
+            for key, value in element.items():
+                if key.startswith("retrieve_") and isinstance(value, dict):
+                    urls = value.get("urls", [])
+                    context = value.get("context", "")
+                    # Create context sources from urls and context
+                    for url in urls:
+                        context_sources.append(
+                            ContextSource(context=context, source=url)
+                        )
+    return context_sources
+
+
+def extract_mcp_context_sources(output: list) -> tuple[list[ContextSource], list[str]]:
+    """Extract context sources and tools from MCP/arch agent output."""
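+    # Expected output shape (a sketch; the tool node names vary with the tools invoked):
+    # [
+    #     {"agent": {"tools": ["tool1", "tool2", ...]}},
+    #     {"tool1": {"urls": ["url1", ...], "context_list": ["context1", ...]}},
+    #     {"tool2": {"urls": ["url1", ...], "context_list": ["context1", ...]}},
+    #     {"generate": {"messages": ["response1", ...]}},
+    # ]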
+    context_sources = []
+    tools = []
+
+    if output and "agent" in output[0] and "tools" in output[0]["agent"]:
+        tools = output[0]["agent"]["tools"]
+        # One output node is expected per tool; cap the range so a shorter
+        # output cannot raise an IndexError.
+        for tool_index in range(min(len(tools), len(output) - 1)):
+            tool_output = list(output[tool_index + 1].values())[0]
+            urls = tool_output.get("urls", [])
+            context_list = tool_output.get("context_list", [])
+            for _url, _context in zip(urls, context_list):
+                context_sources.append(ContextSource(context=_context, source=_url))
+
+    return context_sources, tools
+
+
+def validate_output_structure(output: Any) -> bool:
+    """Validate that output has the expected structure."""
+    return isinstance(output, list) and len(output) > 2 and isinstance(output[-1], dict)
+
+
+def log_invalid_output(output: Any) -> None:
+    """Log details about invalid output structure."""
+    logging.error(
+        f"Invalid output structure: type={type(output)}, len={len(output) if isinstance(output, list) else 'N/A'}"
+    )
+    if isinstance(output, list) and len(output) > 0:
+        logging.error(
+            f"Last element keys: {output[-1].keys() if isinstance(output[-1], dict) else 'Not a dict'}"
+        )
+
+
+def get_agent_type(output: list) -> tuple[bool, str]:
+    """Determine agent type from output structure.
+    Returns: (is_rag_agent, generate_key)
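+    Example: an output whose last node is {"rag_generate": ...} yields (True, "rag_generate").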
+    """
+    is_rag_agent = "rag_generate" in output[-1]
+    generate_key = "rag_generate" if is_rag_agent else "generate"
+    return is_rag_agent, generate_key
+
+
+def extract_llm_response(output: list, generate_key: str) -> str | None:
+    """Extract LLM response from output if available."""
+    if generate_key not in output[-1]:
+        logging.error(f"Missing {generate_key} key")
+        return None
+
+    generate_data = output[-1][generate_key]
+
+    if "messages" not in generate_data:
+        logging.error(f"Missing messages in {generate_key}")
+        return None
+
+    messages = generate_data["messages"]
+    if not messages:
+        logging.error("No messages in generate output")
+        return None
+
+    # The first message is assumed to carry the final generated response.
+    return str(messages[0])
+
+
+def parse_agent_output(output: list) -> tuple[str, list[ContextSource], list[str]]:
+    """
+    Parse agent output and extract response, context sources, and tools.
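+
+    Returns:
+        (llm_response, context_sources, tools); on a malformed output the
+        response falls back to "LLM response extraction failed" with empty lists.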
+    """
+    # Default return values
+    default_response = "LLM response extraction failed"
+    context_sources: list[ContextSource] = []
+    tools: list[str] = []
+
+    # Validate output structure
+    if not validate_output_structure(output):
+        log_invalid_output(output)
+        return default_response, context_sources, tools
+
+    # Determine agent type
+    is_rag_agent, generate_key = get_agent_type(output)
+
+    # Extract LLM response
+    llm_response = extract_llm_response(output, generate_key)
+    if llm_response is None:
+        return default_response, context_sources, tools
+
+    # Extract context sources based on agent type
+    if is_rag_agent:
+        context_sources = extract_rag_context_sources(output)
+    else:
+        context_sources, tools = extract_mcp_context_sources(output)
+
+    return llm_response, context_sources, tools
+
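+# Illustrative call with a hypothetical MCP-shaped output:
+#   output = [{"agent": {"tools": ["search"]}},
+#             {"search": {"urls": ["u"], "context_list": ["c"]}},
+#             {"generate": {"messages": ["answer"]}}]
+#   parse_agent_output(output) -> ("answer", [ContextSource(context="c", source="u")], ["search"])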
+
 rg = RetrieverGraph(
     llm_model=llm,
     embeddings_config=embeddings_config,
@@ -123,67 +263,39 @@ async def get_agent_response(user_input: UserInput) -> ChatResponse:
         output = list(rg.graph.stream(inputs, stream_mode="updates"))
     else:
         raise ValueError("RetrieverGraph not initialized.")
-    urls: list[str] = []
-    context_list: list[str] = []
-    context_sources: list[ContextSource] = []
 
-    if (
-        isinstance(output, list)
-        and len(output) > 2
-        and "generate" in output[-1]
-        and "messages" in output[-1]["generate"]
-        and len(output[-1]["generate"]["messages"]) > 0
-    ):
-        llm_response = output[-1]["generate"]["messages"][0]
-        tools = output[0]["agent"]["tools"]
-        print(output)
-
-        for tool_index, tool in enumerate(tools):
-            """
-            output schema:
-            [
-                "agent": {"tools": ["tool1", "tool2", ...]},
-                "tool1": {"urls": ["url1", "url2", ...], "context_list": ["context1", "context2", ...]},
-                "tool2": {"urls": ["url1", "url2", ...], "context_list": ["context1", "context2", ...]},
-                "generate": "messages": ["response1", "response2", ...]
-            ]
-            """
-            urls = list(output[tool_index + 1].values())[0]["urls"]
-            context_list = list(output[tool_index + 1].values())[0]["context_list"]
-
-            for _url, _context in zip(urls, context_list):
-                context_sources.append(ContextSource(context=_context, source=_url))
-    else:
-        llm_response = "LLM response extraction failed"
-        logging.error("LLM response extraction failed")
+    # Parse the agent output into the response text, context sources, and tools.
+    llm_response, context_sources, tools = parse_agent_output(output)
 
+    response: dict[str, Any]
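+    # Mask the context and/or source fields depending on what the caller requested.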
     if user_input.list_sources and user_input.list_context:
         response = {
             "response": llm_response,
             "context_sources": context_sources,
-            "tool": tools,
+            "tools": tools,
         }
     elif user_input.list_sources:
         response = {
             "response": llm_response,
             "context_sources": [
                 ContextSource(context="", source=cs.source) for cs in context_sources
             ],
-            "tool": tools,
+            "tools": tools,
         }
     elif user_input.list_context:
         response = {
             "response": llm_response,
             "context_sources": [
                 ContextSource(context=cs.context, source="") for cs in context_sources
             ],
-            "tool": tools,
+            "tools": tools,
        }
     else:
         response = {
             "response": llm_response,
             "context_sources": [ContextSource(context="", source="")],
-            "tool": tools,
+            "tools": tools,
         }
 
     return ChatResponse(**response)