4141
4242config = get_config ()
4343
44+ # Suppress warnings before importing libraries that might trigger them
45+ warnings .filterwarnings ('ignore' , category = UserWarning )
46+ # Suppress specific Pydantic field shadowing warning from ADK library
47+ warnings .filterwarnings ('ignore' , message = 'Field name "config_type" in "SequentialAgent" shadows an attribute in parent "BaseAgent"' )
48+
4449# --- ADK Setup (Conditional Import) ---
4550ADK_AVAILABLE = False
4651
4752try :
4853 from google .adk .agents import Agent
4954 from google .adk .artifacts .in_memory_artifact_service import InMemoryArtifactService
50- from google .adk .models .lite_llm import LiteLlm
55+ from src .extensions .smartfix_litellm import SmartFixLiteLlm
56+ from src .extensions .smartfix_llm_agent import SmartFixLlmAgent
5157 from google .adk .runners import Runner
5258 from google .adk .sessions import InMemorySessionService
5359 from google .adk .tools .mcp_tool .mcp_toolset import MCPToolset , StdioServerParameters , StdioConnectionParams
6672 if not config .testing :
6773 sys .exit (1 ) # Only exit in production, not in tests
6874
69- warnings . filterwarnings ( 'ignore' , category = UserWarning )
75+ # Configure library loggers to reduce noise
7076library_logger = logging .getLogger ("google_adk.google.adk.tools.base_authenticated_tool" )
7177library_logger .setLevel (logging .ERROR )
7278
7379
async def _create_mcp_toolset(target_folder_str: str) -> MCPToolset:
    """Build the Filesystem MCP toolset, tuned for the host platform.

    Windows gets a longer stdio connection timeout because launching the
    npx-based server is noticeably slower there.

    Args:
        target_folder_str: Path handed to the filesystem MCP server as its root.

    Returns:
        A configured, not-yet-connected MCPToolset.
    """
    on_windows = platform.system() == 'Windows'
    if on_windows:
        debug_log("Using Windows-specific MCP connection settings")
    # 180s on Windows, 120s elsewhere.
    connection_timeout = 180 if on_windows else 120

    server_params = StdioServerParameters(
        command='npx',
        args=[
            '-y',                          # Arguments for the command
            '--cache', '/tmp/.npm-cache',  # Use explicit cache directory
            '--prefer-offline',            # Try to use cached packages first
            '@modelcontextprotocol/[email protected]',
            target_folder_str,
        ],
    )
    return MCPToolset(
        connection_params=StdioConnectionParams(
            server_params=server_params,
            timeout=connection_timeout,
        )
    )
103+
104+
105+ async def _get_tools_timeout () -> float :
106+ """Get platform-specific timeout for MCP tools connection."""
107+ if platform .system () == 'Windows' :
108+ return 120.0 # Much longer timeout for Windows
109+ else :
110+ return 30.0 # Increased timeout for Linux due to npm issues
111+
112+
async def _clear_npm_cache_if_needed(attempt: int, max_retries: int):
    """Force-clean the npm cache on the second retry attempt.

    A stale or corrupted npm cache is a plausible cause of repeated MCP
    server launch failures, so one cleanup pass gives the final attempt a
    fresh start.

    Args:
        attempt: Zero-based retry counter; the cleanup fires only when it is 2.
        max_retries: Accepted for signature symmetry with the retry loop;
            not consulted here.
    """
    if attempt != 2:
        return
    debug_log("Clearing npm cache due to repeated failures...")
    try:
        import subprocess
        subprocess.run(
            ['npm', 'cache', 'clean', '--force'],
            capture_output=True,
            timeout=30,
        )
        debug_log("NPM cache cleared successfully")
    except Exception as cache_error:
        # Best-effort: a failed cache clean must never abort the retry loop.
        debug_log(f"Failed to clear npm cache: {cache_error}")
124+
125+
async def _attempt_mcp_connection(fs_tools: MCPToolset, get_tools_timeout: float) -> List:
    """Fetch the tool list from the MCP server, bounded by a timeout.

    Wraps ``fs_tools.get_tools()`` in ``asyncio.wait_for`` so a hung server
    surfaces as ``asyncio.TimeoutError`` instead of blocking forever.
    """
    tools_coro = fs_tools.get_tools()
    return await asyncio.wait_for(tools_coro, timeout=get_tools_timeout)
129+
130+
74131async def get_mcp_tools (target_folder : Path , remediation_id : str ) -> MCPToolset :
75132 """Connects to MCP servers (Filesystem)"""
76133 debug_log ("Attempting to connect to MCP servers..." )
@@ -80,29 +137,39 @@ async def get_mcp_tools(target_folder: Path, remediation_id: str) -> MCPToolset:
80137 try :
81138 debug_log ("Connecting to MCP Filesystem server..." )
82139
83- fs_tools = MCPToolset (
84- connection_params = StdioConnectionParams (
85- server_params = StdioServerParameters (
86- command = 'npx' ,
87- args = [
88- '-y' , # Arguments for the command
89- '@modelcontextprotocol/[email protected] ' ,
90- target_folder_str ,
91- ],
92- ),
93- timeout = 50 ,
94- )
95- )
140+ fs_tools = await _create_mcp_toolset (target_folder_str )
141+ get_tools_timeout = await _get_tools_timeout ()
96142
97143 debug_log ("Getting tools list from Filesystem MCP server..." )
98- # Use a longer timeout on Windows
99- timeout_seconds = 30.0 if platform .system () == 'Windows' else 10.0
100- debug_log (f"Using { timeout_seconds } second timeout for get_tools" )
144+ debug_log (f"Using { get_tools_timeout } second timeout for get_tools" )
145+
146+ # Add retry mechanism for MCP connection reliability across all platforms
147+ max_retries = 3
148+ last_error = None
101149
102- # Wrap the get_tools call in wait_for to apply a timeout
103- tools_list = await asyncio .wait_for (fs_tools .get_tools (), timeout = timeout_seconds )
150+ for attempt in range (max_retries ):
151+ try :
152+ if attempt > 0 :
153+ debug_log (f"Retrying MCP connection (attempt { attempt + 1 } /{ max_retries } )" )
154+ await _clear_npm_cache_if_needed (attempt , max_retries )
155+ # Wait a bit before retry to let any broken connections clean up
156+ await asyncio .sleep (2 )
157+
158+ # Wrap the get_tools call in wait_for to apply a timeout
159+ tools_list = await _attempt_mcp_connection (fs_tools , get_tools_timeout )
160+ debug_log (f"Connected to Filesystem MCP server, got { len (tools_list )} tools" )
161+ break # Success, exit retry loop
162+
163+ except (asyncio .TimeoutError , asyncio .CancelledError , ConnectionError ) as retry_error :
164+ last_error = retry_error
165+ debug_log (f"MCP connection attempt { attempt + 1 } failed: { type (retry_error ).__name__ } : { str (retry_error )} " )
166+ if attempt == max_retries - 1 :
167+ # Last attempt failed, re-raise the error
168+ raise retry_error
169+ else :
170+ # This should not be reached, but just in case
171+ raise last_error if last_error else Exception ("Unknown MCP connection failure" )
104172
105- debug_log (f"Connected to Filesystem MCP server, got { len (tools_list )} tools" )
106173 for tool in tools_list :
107174 if hasattr (tool , 'name' ):
108175 debug_log (f" - Filesystem Tool: { tool .name } " )
@@ -141,13 +208,13 @@ async def create_agent(target_folder: Path, remediation_id: str, agent_type: str
141208 agent_name = f"contrast_{ agent_type } _agent"
142209
143210 try :
144- model_instance = LiteLlm (
211+ model_instance = SmartFixLiteLlm (
145212 model = config .AGENT_MODEL ,
146213 temperature = 0.2 , # Set low temperature for more deterministic output
147214 # seed=42, # The random seed for reproducibility (not supported by bedrock/anthropic atm call throws error)
148215 stream_options = {"include_usage" : True }
149216 )
150- root_agent = Agent (
217+ root_agent = SmartFixLlmAgent (
151218 model = model_instance ,
152219 name = agent_name ,
153220 instruction = agent_instruction ,
@@ -207,25 +274,7 @@ async def _check_event_limit(event_count: int, agent_type: str, max_events_limit
207274 return None , False
208275
209276
210- def _extract_token_usage (event ) -> tuple :
211- """Extract token usage information from event metadata."""
212- total_tokens = 0
213- prompt_tokens = 0
214- output_tokens = 0
215-
216- if event .usage_metadata is not None :
217- debug_log (f"Event usage metadata for this message: { event .usage_metadata } " )
218- if hasattr (event .usage_metadata , "total_token_count" ):
219- total_tokens = event .usage_metadata .total_token_count
220- if hasattr (event .usage_metadata , "prompt_token_count" ):
221- prompt_tokens = event .usage_metadata .prompt_token_count
222- if total_tokens > 0 and prompt_tokens > 0 :
223- output_tokens = total_tokens - prompt_tokens
224-
225- return total_tokens , prompt_tokens , output_tokens
226-
227-
228- def _process_agent_content (event , agent_type : str , prompt_tokens : int , output_tokens : int , total_tokens : int ) -> str :
277+ def _process_agent_content (event , agent_type : str ) -> str :
229278 """Process agent content/message from event."""
230279 if not event .content :
231280 return None
@@ -238,7 +287,6 @@ def _process_agent_content(event, agent_type: str, prompt_tokens: int, output_to
238287
239288 if message_text :
240289 log (f"\n *** { agent_type .upper ()} Agent Message: \033 [1;36m { message_text } \033 [0m" )
241- log (f"Tokens statistics. prompt tokens: { prompt_tokens } , output tokens { output_tokens } , total tokens: { total_tokens } " )
242290 return message_text
243291
244292 return None
@@ -293,21 +341,18 @@ async def _process_agent_event(event, event_count: int, agent_type: str, max_eve
293341 # Check if we've exceeded the event limit
294342 final_response , should_break = await _check_event_limit (event_count , agent_type , max_events_limit )
295343 if should_break :
296- return 0 , 0 , 0 , final_response , should_break
297-
298- # Extract token usage information
299- total_tokens , prompt_tokens , output_tokens = _extract_token_usage (event )
344+ return final_response , should_break
300345
301346 # Process agent content/message
302- content_response = _process_agent_content (event , agent_type , prompt_tokens , output_tokens , total_tokens )
347+ content_response = _process_agent_content (event , agent_type )
303348 if content_response :
304349 final_response = content_response
305350
306351 # Process function calls and responses
307352 _process_function_calls (event , agent_type , agent_tool_calls_telemetry )
308353 _process_function_responses (event , agent_type , agent_tool_calls_telemetry )
309354
310- return total_tokens , prompt_tokens , output_tokens , final_response , False
355+ return final_response , False
311356
312357
313358async def _handle_agent_exception (e : Exception , events_async , remediation_id : str ) -> bool :
@@ -335,7 +380,7 @@ async def _handle_agent_exception(e: Exception, events_async, remediation_id: st
335380 return is_asyncio_error
336381
337382
338- async def process_agent_run (runner , session , user_query , remediation_id : str , agent_type : str = None ) -> str :
383+ async def process_agent_run (runner , agent , session , user_query , remediation_id : str , agent_type : str = None ) -> str :
339384 """Runs the agent, allowing it to use tools, and returns the final text response."""
340385 agent_event_actions = []
341386 start_time = datetime .datetime .now ()
@@ -347,7 +392,6 @@ async def process_agent_run(runner, session, user_query, remediation_id: str, ag
347392
348393 # Initialize tracking variables
349394 event_count = 0
350- total_tokens = 0
351395 final_response = "AI agent did not provide a final summary."
352396 max_events_limit = config .MAX_EVENTS_PER_AGENT
353397
@@ -369,12 +413,9 @@ async def process_agent_run(runner, session, user_query, remediation_id: str, ag
369413 event_count += 1
370414
371415 # Process the event and get updated state
372- event_tokens , event_prompt_tokens , event_output_tokens , event_response , should_break = \
416+ event_response , should_break = \
373417 await _process_agent_event (event , event_count , agent_type , max_events_limit , agent_tool_calls_telemetry )
374418
375- # Update tracking variables
376- if event_tokens > 0 :
377- total_tokens = event_tokens
378419 if event_response :
379420 final_response = event_response
380421
@@ -417,6 +458,28 @@ async def process_agent_run(runner, session, user_query, remediation_id: str, ag
417458 debug_log (f"Closing MCP server connections for { agent_type .upper ()} agent..." )
418459 log (f"{ agent_type .upper ()} agent run finished." )
419460
461+ # Get accumulated statistics for telemetry
462+ try :
463+ stats_data = agent .gather_accumulated_stats_dict ()
464+ debug_log (agent .gather_accumulated_stats ()) # Log the JSON formatted version
465+
466+ # Extract telemetry values directly from the dictionary
467+ total_tokens = stats_data .get ("token_usage" , {}).get ("total_tokens" , 0 )
468+ raw_total_cost = stats_data .get ("cost_analysis" , {}).get ("total_cost" , 0.0 )
469+
470+ # Remove "$" prefix if present and convert to float
471+ if isinstance (raw_total_cost , str ) and raw_total_cost .startswith ("$" ):
472+ total_cost = float (raw_total_cost [1 :])
473+ elif isinstance (raw_total_cost , str ):
474+ total_cost = float (raw_total_cost )
475+ else :
476+ total_cost = raw_total_cost
477+ except (ValueError , KeyError , AttributeError ) as e :
478+ # Fallback values if stats retrieval fails
479+ debug_log (f"Could not retrieve statistics: { e } " )
480+ total_tokens = 0
481+ total_cost = 0.0
482+
420483 # Directly assign toolCalls rather than appending, to avoid nested arrays
421484 agent_event_telemetry ["toolCalls" ] = agent_tool_calls_telemetry
422485 agent_event_actions .append (agent_event_telemetry )
@@ -427,8 +490,8 @@ async def process_agent_run(runner, session, user_query, remediation_id: str, ag
427490 "agentType" : agent_type .upper (),
428491 "result" : agent_run_result ,
429492 "actions" : agent_event_actions ,
430- "totalTokens" : total_tokens , # Use the tracked token count from latest event
431- "totalCost" : 0.0 # Placeholder still as cost calculation would need more info
493+ "totalTokens" : total_tokens ,
494+ "totalCost" : total_cost
432495 }
433496 telemetry_handler .add_agent_event (agent_event_payload )
434497
@@ -917,8 +980,7 @@ def filter(self, record):
917980 )
918981
919982 # Pass the full model ID (though not used for cost calculation anymore, kept for consistency if needed elsewhere)
920- summary = await process_agent_run (runner , session , query , remediation_id , agent_type )
921-
983+ summary = await process_agent_run (runner , agent , session , query , remediation_id , agent_type )
922984 return summary
923985
924986# This patch is now handled in src/asyncio_win_patch.py and called from main.py
@@ -933,3 +995,5 @@ def _patched_loop_check_closed(self):
933995 return # ignore this error
934996 raise
935997 asyncio .BaseEventLoop ._check_closed = _patched_loop_check_closed
998+
999+ # %%
0 commit comments