Skip to content

Commit e82182f

Browse files
Merge pull request #50 from Contrast-Security-OSS/AIML-84_extend_litellm_for_prompt_caching
AIML-84 Extend Google ADK's LiteLlm classes for prompt caching
2 parents ab374df + 1c8caa0 commit e82182f

File tree

5 files changed

+1350
-59
lines changed

5 files changed

+1350
-59
lines changed

src/agent_handler.py

Lines changed: 123 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,19 @@
4141

4242
config = get_config()
4343

44+
# Suppress warnings before importing libraries that might trigger them
45+
warnings.filterwarnings('ignore', category=UserWarning)
46+
# Suppress specific Pydantic field shadowing warning from ADK library
47+
warnings.filterwarnings('ignore', message='Field name "config_type" in "SequentialAgent" shadows an attribute in parent "BaseAgent"')
48+
4449
# --- ADK Setup (Conditional Import) ---
4550
ADK_AVAILABLE = False
4651

4752
try:
4853
from google.adk.agents import Agent
4954
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
50-
from google.adk.models.lite_llm import LiteLlm
55+
from src.extensions.smartfix_litellm import SmartFixLiteLlm
56+
from src.extensions.smartfix_llm_agent import SmartFixLlmAgent
5157
from google.adk.runners import Runner
5258
from google.adk.sessions import InMemorySessionService
5359
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset, StdioServerParameters, StdioConnectionParams
@@ -66,11 +72,62 @@
6672
if not config.testing:
6773
sys.exit(1) # Only exit in production, not in tests
6874

69-
warnings.filterwarnings('ignore', category=UserWarning)
75+
# Configure library loggers to reduce noise
7076
library_logger = logging.getLogger("google_adk.google.adk.tools.base_authenticated_tool")
7177
library_logger.setLevel(logging.ERROR)
7278

7379

80+
async def _create_mcp_toolset(target_folder_str: str) -> MCPToolset:
81+
"""Create MCP toolset with platform-specific configuration."""
82+
if platform.system() == 'Windows':
83+
connection_timeout = 180
84+
debug_log("Using Windows-specific MCP connection settings")
85+
else:
86+
connection_timeout = 120
87+
88+
return MCPToolset(
89+
connection_params=StdioConnectionParams(
90+
server_params=StdioServerParameters(
91+
command='npx',
92+
args=[
93+
'-y', # Arguments for the command
94+
'--cache', '/tmp/.npm-cache', # Use explicit cache directory
95+
'--prefer-offline', # Try to use cached packages first
96+
'@modelcontextprotocol/[email protected]',
97+
target_folder_str,
98+
],
99+
),
100+
timeout=connection_timeout,
101+
)
102+
)
103+
104+
105+
async def _get_tools_timeout() -> float:
106+
"""Get platform-specific timeout for MCP tools connection."""
107+
if platform.system() == 'Windows':
108+
return 120.0 # Much longer timeout for Windows
109+
else:
110+
return 30.0 # Increased timeout for Linux due to npm issues
111+
112+
113+
async def _clear_npm_cache_if_needed(attempt: int, max_retries: int):
114+
"""Clear npm cache on second retry if needed."""
115+
if attempt == 2:
116+
debug_log("Clearing npm cache due to repeated failures...")
117+
try:
118+
import subprocess
119+
subprocess.run(['npm', 'cache', 'clean', '--force'],
120+
capture_output=True, timeout=30)
121+
debug_log("NPM cache cleared successfully")
122+
except Exception as cache_error:
123+
debug_log(f"Failed to clear npm cache: {cache_error}")
124+
125+
126+
async def _attempt_mcp_connection(fs_tools: MCPToolset, get_tools_timeout: float) -> List:
127+
"""Attempt to connect to MCP server and get tools."""
128+
return await asyncio.wait_for(fs_tools.get_tools(), timeout=get_tools_timeout)
129+
130+
74131
async def get_mcp_tools(target_folder: Path, remediation_id: str) -> MCPToolset:
75132
"""Connects to MCP servers (Filesystem)"""
76133
debug_log("Attempting to connect to MCP servers...")
@@ -80,29 +137,39 @@ async def get_mcp_tools(target_folder: Path, remediation_id: str) -> MCPToolset:
80137
try:
81138
debug_log("Connecting to MCP Filesystem server...")
82139

83-
fs_tools = MCPToolset(
84-
connection_params=StdioConnectionParams(
85-
server_params=StdioServerParameters(
86-
command='npx',
87-
args=[
88-
'-y', # Arguments for the command
89-
'@modelcontextprotocol/[email protected]',
90-
target_folder_str,
91-
],
92-
),
93-
timeout=50,
94-
)
95-
)
140+
fs_tools = await _create_mcp_toolset(target_folder_str)
141+
get_tools_timeout = await _get_tools_timeout()
96142

97143
debug_log("Getting tools list from Filesystem MCP server...")
98-
# Use a longer timeout on Windows
99-
timeout_seconds = 30.0 if platform.system() == 'Windows' else 10.0
100-
debug_log(f"Using {timeout_seconds} second timeout for get_tools")
144+
debug_log(f"Using {get_tools_timeout} second timeout for get_tools")
145+
146+
# Add retry mechanism for MCP connection reliability across all platforms
147+
max_retries = 3
148+
last_error = None
101149

102-
# Wrap the get_tools call in wait_for to apply a timeout
103-
tools_list = await asyncio.wait_for(fs_tools.get_tools(), timeout=timeout_seconds)
150+
for attempt in range(max_retries):
151+
try:
152+
if attempt > 0:
153+
debug_log(f"Retrying MCP connection (attempt {attempt + 1}/{max_retries})")
154+
await _clear_npm_cache_if_needed(attempt, max_retries)
155+
# Wait a bit before retry to let any broken connections clean up
156+
await asyncio.sleep(2)
157+
158+
# Wrap the get_tools call in wait_for to apply a timeout
159+
tools_list = await _attempt_mcp_connection(fs_tools, get_tools_timeout)
160+
debug_log(f"Connected to Filesystem MCP server, got {len(tools_list)} tools")
161+
break # Success, exit retry loop
162+
163+
except (asyncio.TimeoutError, asyncio.CancelledError, ConnectionError) as retry_error:
164+
last_error = retry_error
165+
debug_log(f"MCP connection attempt {attempt + 1} failed: {type(retry_error).__name__}: {str(retry_error)}")
166+
if attempt == max_retries - 1:
167+
# Last attempt failed, re-raise the error
168+
raise retry_error
169+
else:
170+
# This should not be reached, but just in case
171+
raise last_error if last_error else Exception("Unknown MCP connection failure")
104172

105-
debug_log(f"Connected to Filesystem MCP server, got {len(tools_list)} tools")
106173
for tool in tools_list:
107174
if hasattr(tool, 'name'):
108175
debug_log(f" - Filesystem Tool: {tool.name}")
@@ -141,13 +208,13 @@ async def create_agent(target_folder: Path, remediation_id: str, agent_type: str
141208
agent_name = f"contrast_{agent_type}_agent"
142209

143210
try:
144-
model_instance = LiteLlm(
211+
model_instance = SmartFixLiteLlm(
145212
model=config.AGENT_MODEL,
146213
temperature=0.2, # Set low temperature for more deterministic output
147214
# seed=42, # The random seed for reproducibility (not supported by bedrock/anthropic atm call throws error)
148215
stream_options={"include_usage": True}
149216
)
150-
root_agent = Agent(
217+
root_agent = SmartFixLlmAgent(
151218
model=model_instance,
152219
name=agent_name,
153220
instruction=agent_instruction,
@@ -207,25 +274,7 @@ async def _check_event_limit(event_count: int, agent_type: str, max_events_limit
207274
return None, False
208275

209276

210-
def _extract_token_usage(event) -> tuple:
211-
"""Extract token usage information from event metadata."""
212-
total_tokens = 0
213-
prompt_tokens = 0
214-
output_tokens = 0
215-
216-
if event.usage_metadata is not None:
217-
debug_log(f"Event usage metadata for this message: {event.usage_metadata}")
218-
if hasattr(event.usage_metadata, "total_token_count"):
219-
total_tokens = event.usage_metadata.total_token_count
220-
if hasattr(event.usage_metadata, "prompt_token_count"):
221-
prompt_tokens = event.usage_metadata.prompt_token_count
222-
if total_tokens > 0 and prompt_tokens > 0:
223-
output_tokens = total_tokens - prompt_tokens
224-
225-
return total_tokens, prompt_tokens, output_tokens
226-
227-
228-
def _process_agent_content(event, agent_type: str, prompt_tokens: int, output_tokens: int, total_tokens: int) -> str:
277+
def _process_agent_content(event, agent_type: str) -> str:
229278
"""Process agent content/message from event."""
230279
if not event.content:
231280
return None
@@ -238,7 +287,6 @@ def _process_agent_content(event, agent_type: str, prompt_tokens: int, output_to
238287

239288
if message_text:
240289
log(f"\n*** {agent_type.upper()} Agent Message: \033[1;36m {message_text} \033[0m")
241-
log(f"Tokens statistics. prompt tokens: {prompt_tokens}, output tokens {output_tokens}, total tokens: {total_tokens}")
242290
return message_text
243291

244292
return None
@@ -293,21 +341,18 @@ async def _process_agent_event(event, event_count: int, agent_type: str, max_eve
293341
# Check if we've exceeded the event limit
294342
final_response, should_break = await _check_event_limit(event_count, agent_type, max_events_limit)
295343
if should_break:
296-
return 0, 0, 0, final_response, should_break
297-
298-
# Extract token usage information
299-
total_tokens, prompt_tokens, output_tokens = _extract_token_usage(event)
344+
return final_response, should_break
300345

301346
# Process agent content/message
302-
content_response = _process_agent_content(event, agent_type, prompt_tokens, output_tokens, total_tokens)
347+
content_response = _process_agent_content(event, agent_type)
303348
if content_response:
304349
final_response = content_response
305350

306351
# Process function calls and responses
307352
_process_function_calls(event, agent_type, agent_tool_calls_telemetry)
308353
_process_function_responses(event, agent_type, agent_tool_calls_telemetry)
309354

310-
return total_tokens, prompt_tokens, output_tokens, final_response, False
355+
return final_response, False
311356

312357

313358
async def _handle_agent_exception(e: Exception, events_async, remediation_id: str) -> bool:
@@ -335,7 +380,7 @@ async def _handle_agent_exception(e: Exception, events_async, remediation_id: st
335380
return is_asyncio_error
336381

337382

338-
async def process_agent_run(runner, session, user_query, remediation_id: str, agent_type: str = None) -> str:
383+
async def process_agent_run(runner, agent, session, user_query, remediation_id: str, agent_type: str = None) -> str:
339384
"""Runs the agent, allowing it to use tools, and returns the final text response."""
340385
agent_event_actions = []
341386
start_time = datetime.datetime.now()
@@ -347,7 +392,6 @@ async def process_agent_run(runner, session, user_query, remediation_id: str, ag
347392

348393
# Initialize tracking variables
349394
event_count = 0
350-
total_tokens = 0
351395
final_response = "AI agent did not provide a final summary."
352396
max_events_limit = config.MAX_EVENTS_PER_AGENT
353397

@@ -369,12 +413,9 @@ async def process_agent_run(runner, session, user_query, remediation_id: str, ag
369413
event_count += 1
370414

371415
# Process the event and get updated state
372-
event_tokens, event_prompt_tokens, event_output_tokens, event_response, should_break = \
416+
event_response, should_break = \
373417
await _process_agent_event(event, event_count, agent_type, max_events_limit, agent_tool_calls_telemetry)
374418

375-
# Update tracking variables
376-
if event_tokens > 0:
377-
total_tokens = event_tokens
378419
if event_response:
379420
final_response = event_response
380421

@@ -417,6 +458,28 @@ async def process_agent_run(runner, session, user_query, remediation_id: str, ag
417458
debug_log(f"Closing MCP server connections for {agent_type.upper()} agent...")
418459
log(f"{agent_type.upper()} agent run finished.")
419460

461+
# Get accumulated statistics for telemetry
462+
try:
463+
stats_data = agent.gather_accumulated_stats_dict()
464+
debug_log(agent.gather_accumulated_stats()) # Log the JSON formatted version
465+
466+
# Extract telemetry values directly from the dictionary
467+
total_tokens = stats_data.get("token_usage", {}).get("total_tokens", 0)
468+
raw_total_cost = stats_data.get("cost_analysis", {}).get("total_cost", 0.0)
469+
470+
# Remove "$" prefix if present and convert to float
471+
if isinstance(raw_total_cost, str) and raw_total_cost.startswith("$"):
472+
total_cost = float(raw_total_cost[1:])
473+
elif isinstance(raw_total_cost, str):
474+
total_cost = float(raw_total_cost)
475+
else:
476+
total_cost = raw_total_cost
477+
except (ValueError, KeyError, AttributeError) as e:
478+
# Fallback values if stats retrieval fails
479+
debug_log(f"Could not retrieve statistics: {e}")
480+
total_tokens = 0
481+
total_cost = 0.0
482+
420483
# Directly assign toolCalls rather than appending, to avoid nested arrays
421484
agent_event_telemetry["toolCalls"] = agent_tool_calls_telemetry
422485
agent_event_actions.append(agent_event_telemetry)
@@ -427,8 +490,8 @@ async def process_agent_run(runner, session, user_query, remediation_id: str, ag
427490
"agentType": agent_type.upper(),
428491
"result": agent_run_result,
429492
"actions": agent_event_actions,
430-
"totalTokens": total_tokens, # Use the tracked token count from latest event
431-
"totalCost": 0.0 # Placeholder still as cost calculation would need more info
493+
"totalTokens": total_tokens,
494+
"totalCost": total_cost
432495
}
433496
telemetry_handler.add_agent_event(agent_event_payload)
434497

@@ -917,8 +980,7 @@ def filter(self, record):
917980
)
918981

919982
# Pass the full model ID (though not used for cost calculation anymore, kept for consistency if needed elsewhere)
920-
summary = await process_agent_run(runner, session, query, remediation_id, agent_type)
921-
983+
summary = await process_agent_run(runner, agent, session, query, remediation_id, agent_type)
922984
return summary
923985

924986
# This patch is now handled in src/asyncio_win_patch.py and called from main.py
@@ -933,3 +995,5 @@ def _patched_loop_check_closed(self):
933995
return # ignore this error
934996
raise
935997
asyncio.BaseEventLoop._check_closed = _patched_loop_check_closed
998+
999+
# %%

0 commit comments

Comments
 (0)