22import logging
33from dotenv import load_dotenv
44
5+ from typing import Any
56from fastapi import APIRouter
67from langchain_google_vertexai import ChatVertexAI
78from langchain_google_genai import ChatGoogleGenerativeAI
8889
# All endpoints in this module are mounted under the /graphs prefix.
router = APIRouter(prefix="/graphs", tags=["graphs"])
9091
92+
def extract_rag_context_sources(output: list) -> list[ContextSource]:
    """Extract context sources from RAG agent output."""
    # The first element (classify node) and the last (generate node) carry no
    # retrieval data, so only the middle node updates are inspected.
    sources: list[ContextSource] = []
    for node in output[1:-1]:
        if not isinstance(node, dict):
            continue
        for name, payload in node.items():
            if not (name.startswith("retrieve_") and isinstance(payload, dict)):
                continue
            # One shared context string is paired with each retrieved URL.
            shared_context = payload.get("context", "")
            sources.extend(
                ContextSource(context=shared_context, source=link)
                for link in payload.get("urls", [])
            )
    return sources
108+
109+
def extract_mcp_context_sources(output: list) -> tuple[list[ContextSource], list[str]]:
    """Extract context sources and tools from MCP/arch agent output.

    Args:
        output: Streamed node updates. ``output[0]`` is expected to hold the
            agent node with its selected tools, followed by one update per
            tool call containing ``urls`` and ``context_list``.

    Returns:
        A ``(context_sources, tools)`` tuple; both lists are empty when the
        output does not match the expected shape.
    """
    context_sources: list[ContextSource] = []
    tools: list[str] = []

    first = output[0] if output else {}
    if isinstance(first, dict) and "agent" in first and "tools" in first["agent"]:
        tools = first["agent"]["tools"]
        # Each tool's update follows the agent node in order. Slicing (instead
        # of indexing with tool_index + 1) cannot raise IndexError if the
        # stream produced fewer updates than tools (e.g. a tool call failed).
        for element in output[1 : len(tools) + 1]:
            if not (isinstance(element, dict) and element):
                continue  # skip empty/malformed updates instead of crashing
            tool_output = next(iter(element.values()))
            urls = tool_output.get("urls", [])
            context_list = tool_output.get("context_list", [])
            # zip truncates to the shorter list, pairing each URL with its context.
            for url, context in zip(urls, context_list):
                context_sources.append(ContextSource(context=context, source=url))

    return context_sources, tools
125+
126+
def validate_output_structure(output: Any) -> bool:
    """Validate that output has the expected structure."""
    # Expected shape: a list of at least three node updates whose final
    # entry (the generate node) is a mapping.
    if not isinstance(output, list):
        return False
    if len(output) <= 2:
        return False
    return isinstance(output[-1], dict)
130+
131+
def log_invalid_output(output: Any) -> None:
    """Log details about invalid output structure."""
    # Length is only meaningful for lists; everything else is reported as N/A.
    length = len(output) if isinstance(output, list) else "N/A"
    logging.error(f"Invalid output structure: type={type(output)}, len={length}")
    if isinstance(output, list) and len(output) > 0:
        tail = output[-1]
        keys = tail.keys() if isinstance(tail, dict) else "Not a dict"
        logging.error(f"Last element keys: {keys}")
141+
142+
def get_agent_type(output: list) -> tuple[bool, str]:
    """Determine agent type from output structure.

    Returns: (is_rag_agent, generate_key)
    """
    # A RAG agent's final node update is keyed "rag_generate";
    # any other agent uses the plain "generate" key.
    if "rag_generate" in output[-1]:
        return True, "rag_generate"
    return False, "generate"
150+
151+
152+ def extract_llm_response (output : list , generate_key : str ) -> str | None :
153+ """Extract LLM response from output if available."""
154+ if generate_key not in output [- 1 ]:
155+ logging .error (f"Missing { generate_key } key" )
156+ return None
157+
158+ generate_data = output [- 1 ][generate_key ]
159+
160+ if "messages" not in generate_data :
161+ logging .error (f"Missing messages in { generate_key } " )
162+ return None
163+
164+ messages = generate_data ["messages" ]
165+ if not messages or len (messages ) == 0 :
166+ logging .error ("No messages in generate output" )
167+ return None
168+
169+ return str (messages [0 ])
170+
171+
def parse_agent_output(output: list) -> tuple[str, list[ContextSource], list[str]]:
    """
    Parse agent output and extract response, context sources, and tools.
    """
    failure = "LLM response extraction failed"
    no_sources: list[ContextSource] = []
    no_tools: list[str] = []

    # Bail out early on anything that does not look like a node-update stream.
    if not validate_output_structure(output):
        log_invalid_output(output)
        return failure, no_sources, no_tools

    is_rag_agent, generate_key = get_agent_type(output)

    response = extract_llm_response(output, generate_key)
    if response is None:
        return failure, no_sources, no_tools

    # RAG agents report sources per retrieve_* node and expose no tools;
    # MCP/arch agents report sources per tool call alongside the tool list.
    if is_rag_agent:
        return response, extract_rag_context_sources(output), no_tools

    sources, tools = extract_mcp_context_sources(output)
    return response, sources, tools
201+
202+
91203rg = RetrieverGraph (
92204 llm_model = llm ,
93205 embeddings_config = embeddings_config ,
96208 inbuilt_tool_calling = True ,
97209 fast_mode = fast_mode ,
98210 debug = debug ,
99- enable_mcp = enable_mcp
211+ enable_mcp = enable_mcp ,
100212)
101213rg .initialize ()
102214
@@ -123,67 +235,38 @@ async def get_agent_response(user_input: UserInput) -> ChatResponse:
123235 output = list (rg .graph .stream (inputs , stream_mode = "updates" ))
124236 else :
125237 raise ValueError ("RetrieverGraph not initialized." )
126- urls : list [str ] = []
127- context_list : list [str ] = []
128- context_sources : list [ContextSource ] = []
129238
130- if (
131- isinstance (output , list )
132- and len (output ) > 2
133- and "generate" in output [- 1 ]
134- and "messages" in output [- 1 ]["generate" ]
135- and len (output [- 1 ]["generate" ]["messages" ]) > 0
136- ):
137- llm_response = output [- 1 ]["generate" ]["messages" ][0 ]
138- tools = output [0 ]["agent" ]["tools" ]
139- print (output )
140-
141- for tool_index , tool in enumerate (tools ):
142- """
143- output schema:
144- [
145- "agent": {"tools": ["tool1", "tool2", ...]},
146- "tool1": {"urls": ["url1", "url2", ...], "context_list": ["context1", "context2", ...]},
147- "tool2": {"urls": ["url1", "url2", ...], "context_list": ["context1", "context2", ...]},
148- "generate": "messages": ["response1", "response2", ...]
149- ]
150- """
151- urls = list (output [tool_index + 1 ].values ())[0 ]["urls" ]
152- context_list = list (output [tool_index + 1 ].values ())[0 ]["context_list" ]
153-
154- for _url , _context in zip (urls , context_list ):
155- context_sources .append (ContextSource (context = _context , source = _url ))
156- else :
157- llm_response = "LLM response extraction failed"
158- logging .error ("LLM response extraction failed" )
239+ # Use the extracted function to parse agent output
240+ llm_response , context_sources , tools = parse_agent_output (output )
159241
242+ response : dict [str , Any ]
160243 if user_input .list_sources and user_input .list_context :
161244 response = {
162245 "response" : llm_response ,
163246 "context_sources" : context_sources ,
164- "tool " : tools ,
247+ "tools " : tools ,
165248 }
166249 elif user_input .list_sources :
167250 response = {
168251 "response" : llm_response ,
169252 "context_sources" : [
170253 ContextSource (context = "" , source = cs .source ) for cs in context_sources
171254 ],
172- "tool " : tools ,
255+ "tools " : tools ,
173256 }
174257 elif user_input .list_context :
175258 response = {
176259 "response" : llm_response ,
177260 "context_sources" : [
178261 ContextSource (context = cs .context , source = "" ) for cs in context_sources
179262 ],
180- "tool " : tools ,
263+ "tools " : tools ,
181264 }
182265 else :
183266 response = {
184267 "response" : llm_response ,
185268 "context_sources" : [ContextSource (context = "" , source = "" )],
186- "tool " : tools ,
269+ "tools " : tools ,
187270 }
188271
189272 return ChatResponse (** response )
0 commit comments