Skip to content

Commit 6046e2a

Browse files
authored
fix default retriever routing behavior (#172)
* Fix CI evaluation failures by refactoring agent output parsing - Extract helper functions to parse RAG vs MCP agent outputs - Add proper validation for agent output structure - Handle both "rag_generate" and "generate" response keys - Fix field name from "tool" to "tools" in response model - Add comprehensive error logging and fallback handling - Reduce cyclomatic complexity through function extraction --------- Signed-off-by: Jack Luar <[email protected]>
1 parent 66cbfd0 commit 6046e2a

File tree

2 files changed

+125
-38
lines changed

2 files changed

+125
-38
lines changed

backend/src/agents/retriever_graph.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555

5656
def classify(self, state: AgentState) -> dict[str, list[str]]:
5757
"""Determine if architecture/config, execute, or RAG. Handle misc."""
58-
if self.inbuilt_tool_calling:
58+
if self.inbuilt_tool_calling and self.enable_mcp:
5959
question = state["messages"][-1].content
6060
model = self.llm.bind_tools(
6161
[rag_info, mcp_info, arch_info], # type: ignore
@@ -83,6 +83,10 @@ def classify(self, state: AgentState) -> dict[str, list[str]]:
8383

8484
logging.info(result)
8585
return {"agent_type": [result]}
86+
elif self.inbuilt_tool_calling and not self.enable_mcp:
87+
# When MCP is disabled but inbuilt tool calling is enabled, just use RAG
88+
logging.info("MCP disabled, defaulting to RAG agent")
89+
return {"agent_type": ["rag_agent"]}
8690
else:
8791
logging.info("classify task")
8892

backend/src/api/routers/graphs.py

Lines changed: 120 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
from dotenv import load_dotenv
44

5+
from typing import Any
56
from fastapi import APIRouter
67
from langchain_google_vertexai import ChatVertexAI
78
from langchain_google_genai import ChatGoogleGenerativeAI
@@ -88,6 +89,117 @@
8889

8990
router = APIRouter(prefix="/graphs", tags=["graphs"])
9091

92+
93+
def extract_rag_context_sources(output: list) -> list[ContextSource]:
94+
"""Extract context sources from RAG agent output."""
95+
context_sources = []
96+
for element in output[1:-1]: # Skip first (classify) and last (generate) nodes
97+
if isinstance(element, dict):
98+
for key, value in element.items():
99+
if key.startswith("retrieve_") and isinstance(value, dict):
100+
urls = value.get("urls", [])
101+
context = value.get("context", "")
102+
# Create context sources from urls and context
103+
for url in urls:
104+
context_sources.append(
105+
ContextSource(context=context, source=url)
106+
)
107+
return context_sources
108+
109+
110+
def extract_mcp_context_sources(output: list) -> tuple[list[ContextSource], list[str]]:
111+
"""Extract context sources and tools from MCP/arch agent output."""
112+
context_sources = []
113+
tools = []
114+
115+
if "agent" in output[0] and "tools" in output[0]["agent"]:
116+
tools = output[0]["agent"]["tools"]
117+
for tool_index in range(len(tools)):
118+
tool_output = list(output[tool_index + 1].values())[0]
119+
urls = tool_output.get("urls", [])
120+
context_list = tool_output.get("context_list", [])
121+
for _url, _context in zip(urls, context_list):
122+
context_sources.append(ContextSource(context=_context, source=_url))
123+
124+
return context_sources, tools
125+
126+
127+
def validate_output_structure(output: Any) -> bool:
    """Return True when *output* looks like a usable graph stream result.

    A usable result is a list of more than two updates whose final update
    is a dict (the generate node).
    """
    if not isinstance(output, list):
        return False
    if len(output) <= 2:
        return False
    return isinstance(output[-1], dict)
130+
131+
132+
def log_invalid_output(output: Any) -> None:
    """Log diagnostic details about an unexpected agent output structure.

    Emits the output's type and length and, when the last element is
    available, its keys, so the failing graph node can be identified from
    the logs.

    Args:
        output: Whatever ``rg.graph.stream(...)`` produced; expected to be
            a list of dict updates but may be anything on failure.
    """
    # Lazy %-style args defer formatting to the logging framework (idiomatic
    # and cheaper when the record is filtered out); rendered text unchanged.
    logging.error(
        "Invalid output structure: type=%s, len=%s",
        type(output),
        len(output) if isinstance(output, list) else "N/A",
    )
    if isinstance(output, list) and len(output) > 0:
        logging.error(
            "Last element keys: %s",
            output[-1].keys() if isinstance(output[-1], dict) else "Not a dict",
        )
141+
142+
143+
def get_agent_type(output: list) -> tuple[bool, str]:
    """Classify which agent produced *output*.

    The RAG agent writes its final message under ``rag_generate``; every
    other agent uses ``generate``.

    Returns: (is_rag_agent, generate_key)
    """
    final_update = output[-1]
    if "rag_generate" in final_update:
        return True, "rag_generate"
    return False, "generate"
150+
151+
152+
def extract_llm_response(output: list, generate_key: str) -> str | None:
153+
"""Extract LLM response from output if available."""
154+
if generate_key not in output[-1]:
155+
logging.error(f"Missing {generate_key} key")
156+
return None
157+
158+
generate_data = output[-1][generate_key]
159+
160+
if "messages" not in generate_data:
161+
logging.error(f"Missing messages in {generate_key}")
162+
return None
163+
164+
messages = generate_data["messages"]
165+
if not messages or len(messages) == 0:
166+
logging.error("No messages in generate output")
167+
return None
168+
169+
return str(messages[0])
170+
171+
172+
def parse_agent_output(output: list) -> tuple[str, list[ContextSource], list[str]]:
    """Turn raw graph stream updates into (response, context_sources, tools).

    Falls back to a fixed failure message with empty sources and tools
    whenever the output structure or the LLM response cannot be parsed.
    """
    fallback_message = "LLM response extraction failed"
    empty_sources: list[ContextSource] = []
    empty_tools: list[str] = []

    # Bail out early on anything that is not a well-formed update list.
    if not validate_output_structure(output):
        log_invalid_output(output)
        return fallback_message, empty_sources, empty_tools

    is_rag_agent, generate_key = get_agent_type(output)

    llm_response = extract_llm_response(output, generate_key)
    if llm_response is None:
        return fallback_message, empty_sources, empty_tools

    # Context extraction differs per agent; only MCP/arch agents report tools.
    if is_rag_agent:
        return llm_response, extract_rag_context_sources(output), empty_tools

    context_sources, tools = extract_mcp_context_sources(output)
    return llm_response, context_sources, tools
201+
202+
91203
rg = RetrieverGraph(
92204
llm_model=llm,
93205
embeddings_config=embeddings_config,
@@ -96,7 +208,7 @@
96208
inbuilt_tool_calling=True,
97209
fast_mode=fast_mode,
98210
debug=debug,
99-
enable_mcp=enable_mcp
211+
enable_mcp=enable_mcp,
100212
)
101213
rg.initialize()
102214

@@ -123,67 +235,38 @@ async def get_agent_response(user_input: UserInput) -> ChatResponse:
123235
output = list(rg.graph.stream(inputs, stream_mode="updates"))
124236
else:
125237
raise ValueError("RetrieverGraph not initialized.")
126-
urls: list[str] = []
127-
context_list: list[str] = []
128-
context_sources: list[ContextSource] = []
129238

130-
if (
131-
isinstance(output, list)
132-
and len(output) > 2
133-
and "generate" in output[-1]
134-
and "messages" in output[-1]["generate"]
135-
and len(output[-1]["generate"]["messages"]) > 0
136-
):
137-
llm_response = output[-1]["generate"]["messages"][0]
138-
tools = output[0]["agent"]["tools"]
139-
print(output)
140-
141-
for tool_index, tool in enumerate(tools):
142-
"""
143-
output schema:
144-
[
145-
"agent": {"tools": ["tool1", "tool2", ...]},
146-
"tool1": {"urls": ["url1", "url2", ...], "context_list": ["context1", "context2", ...]},
147-
"tool2": {"urls": ["url1", "url2", ...], "context_list": ["context1", "context2", ...]},
148-
"generate": "messages": ["response1", "response2", ...]
149-
]
150-
"""
151-
urls = list(output[tool_index + 1].values())[0]["urls"]
152-
context_list = list(output[tool_index + 1].values())[0]["context_list"]
153-
154-
for _url, _context in zip(urls, context_list):
155-
context_sources.append(ContextSource(context=_context, source=_url))
156-
else:
157-
llm_response = "LLM response extraction failed"
158-
logging.error("LLM response extraction failed")
239+
# Use the extracted function to parse agent output
240+
llm_response, context_sources, tools = parse_agent_output(output)
159241

242+
response: dict[str, Any]
160243
if user_input.list_sources and user_input.list_context:
161244
response = {
162245
"response": llm_response,
163246
"context_sources": context_sources,
164-
"tool": tools,
247+
"tools": tools,
165248
}
166249
elif user_input.list_sources:
167250
response = {
168251
"response": llm_response,
169252
"context_sources": [
170253
ContextSource(context="", source=cs.source) for cs in context_sources
171254
],
172-
"tool": tools,
255+
"tools": tools,
173256
}
174257
elif user_input.list_context:
175258
response = {
176259
"response": llm_response,
177260
"context_sources": [
178261
ContextSource(context=cs.context, source="") for cs in context_sources
179262
],
180-
"tool": tools,
263+
"tools": tools,
181264
}
182265
else:
183266
response = {
184267
"response": llm_response,
185268
"context_sources": [ContextSource(context="", source="")],
186-
"tool": tools,
269+
"tools": tools,
187270
}
188271

189272
return ChatResponse(**response)

0 commit comments

Comments
 (0)