Skip to content

Commit 1c8caa0

Browse files
author
Jacob Mages-Haskins
committed
AIML-84 Respond to PR feedback and rename extension files to smartfix_* & reduce filesystem MCP server timeouts
1 parent e552fae commit 1c8caa0

File tree

5 files changed

+95
-95
lines changed

5 files changed

+95
-95
lines changed

src/agent_handler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@
5252
try:
5353
from google.adk.agents import Agent
5454
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
55-
from src.extensions.extended_litellm import ExtendedLiteLlm
56-
from src.extensions.extended_llm_agent import ExtendedLlmAgent
55+
from src.extensions.smartfix_litellm import SmartFixLiteLlm
56+
from src.extensions.smartfix_llm_agent import SmartFixLlmAgent
5757
from google.adk.runners import Runner
5858
from google.adk.sessions import InMemorySessionService
5959
from google.adk.tools.mcp_tool.mcp_toolset import MCPToolset, StdioServerParameters, StdioConnectionParams
@@ -80,10 +80,10 @@
8080
async def _create_mcp_toolset(target_folder_str: str) -> MCPToolset:
8181
"""Create MCP toolset with platform-specific configuration."""
8282
if platform.system() == 'Windows':
83-
connection_timeout = 300
83+
connection_timeout = 180
8484
debug_log("Using Windows-specific MCP connection settings")
8585
else:
86-
connection_timeout = 180
86+
connection_timeout = 120
8787

8888
return MCPToolset(
8989
connection_params=StdioConnectionParams(
@@ -208,13 +208,13 @@ async def create_agent(target_folder: Path, remediation_id: str, agent_type: str
208208
agent_name = f"contrast_{agent_type}_agent"
209209

210210
try:
211-
model_instance = ExtendedLiteLlm(
211+
model_instance = SmartFixLiteLlm(
212212
model=config.AGENT_MODEL,
213213
temperature=0.2, # Set low temperature for more deterministic output
214214
# seed=42, # The random seed for reproducibility (not supported by bedrock/anthropic atm call throws error)
215215
stream_options={"include_usage": True}
216216
)
217-
root_agent = ExtendedLlmAgent(
217+
root_agent = SmartFixLlmAgent(
218218
model=model_instance,
219219
name=agent_name,
220220
instruction=agent_instruction,

src/extensions/extended_litellm.py renamed to src/extensions/smartfix_litellm.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -
22
# #%L
3-
# Extended LiteLLM with Prompt Caching
3+
# SmartFix LiteLLM with Prompt Caching
44
# %%
55
# Copyright (C) 2025 Contrast Security, Inc.
66
# %%
@@ -58,7 +58,7 @@ def reset(self):
5858

5959
def add_usage(self, input_tokens: int, output_tokens: int, cache_read_tokens: int,
6060
cache_write_tokens: int, new_input_cost: float, cache_read_cost: float,
61-
cache_write_cost: float, output_cost: float):
61+
cache_write_cost: float, output_cost: float) -> None:
6262
"""Add usage statistics from a single LLM call."""
6363
# Accumulate tokens
6464
self.total_new_input_tokens += input_tokens
@@ -114,8 +114,8 @@ def cache_savings_percentage(self):
114114
return 0.0
115115

116116

117-
class ExtendedLiteLlm(LiteLlm):
118-
"""Extended LiteLlm with automatic prompt caching and comprehensive cost analysis.
117+
class SmartFixLiteLlm(LiteLlm):
118+
"""SmartFix LiteLlm with automatic prompt caching and comprehensive cost analysis.
119119
120120
This class extends the base LiteLlm to automatically apply prompt caching
121121
and provide detailed cost analysis for all LLM interactions:
@@ -131,16 +131,16 @@ class ExtendedLiteLlm(LiteLlm):
131131
Example usage:
132132
```python
133133
# Bedrock Claude - Works with all features and cost tracking
134-
model = ExtendedLiteLlm(model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0")
134+
model = SmartFixLiteLlm(model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0")
135135
136136
# Direct Anthropic - Works with all features and cost tracking
137-
model = ExtendedLiteLlm(model="anthropic/claude-3-7-sonnet-20250219")
137+
model = SmartFixLiteLlm(model="anthropic/claude-3-7-sonnet-20250219")
138138
139139
# OpenAI - works with cost tracking (no caching applied)
140-
model = ExtendedLiteLlm(model="openai/gpt-4o")
140+
model = SmartFixLiteLlm(model="openai/gpt-4o")
141141
142142
# Other models - work normally with cost tracking
143-
model = ExtendedLiteLlm(model="gemini/gemini-1.5-pro")
143+
model = SmartFixLiteLlm(model="gemini/gemini-1.5-pro")
144144
```
145145
"""
146146

src/extensions/extended_llm_agent.py renamed to src/extensions/smartfix_llm_agent.py

Lines changed: 38 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# -
22
# #%L
3-
# Extended LLM Agent
3+
# SmartFix LLM Agent
44
# %%
55
# Copyright (C) 2025 Contrast Security, Inc.
66
# %%
@@ -31,32 +31,32 @@
3131
from google.adk.agents import LlmAgent
3232
from src.utils import debug_log
3333

34-
# Import ExtendedLiteLlm at module level so it's available for model_rebuild()
35-
ExtendedLiteLlm = None
34+
# Import SmartFixLiteLlm at module level so it's available for model_rebuild()
35+
SmartFixLiteLlm = None
3636
try:
37-
from .extended_litellm import ExtendedLiteLlm
37+
from .smartfix_litellm import SmartFixLiteLlm
3838
except ImportError:
3939
pass
4040

4141

42-
class ExtendedLlmAgent(LlmAgent):
43-
"""Extended LLM Agent that preserves ExtendedLiteLlm statistics across calls.
42+
class SmartFixLlmAgent(LlmAgent):
43+
"""SmartFix LLM Agent that preserves SmartFixLiteLlm statistics across calls.
4444
45-
This class solves the issue where ExtendedLiteLlm accumulated statistics
45+
This class solves the issue where SmartFixLiteLlm accumulated statistics
4646
aren't preserved when using the model within a Google ADK Agent. It ensures
47-
that the same ExtendedLiteLlm instance is used for all LLM calls and provides
47+
that the same SmartFixLiteLlm instance is used for all LLM calls and provides
4848
convenient methods to access accumulated statistics.
4949
5050
Example usage:
5151
```python
52-
from src.extensions.extended_litellm import ExtendedLiteLlm
53-
from src.extensions.extended_llm_agent import ExtendedLlmAgent
52+
from src.extensions.smartfix_litellm import SmartFixLiteLlm
53+
from src.extensions.smartfix_llm_agent import SmartFixLlmAgent
5454
5555
# Create the extended model
56-
model = ExtendedLiteLlm(model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0")
56+
model = SmartFixLiteLlm(model="bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0")
5757
5858
# Create the extended agent
59-
agent = ExtendedLlmAgent(
59+
agent = SmartFixLlmAgent(
6060
name="my-agent",
6161
model=model,
6262
instruction="You are a helpful assistant."
@@ -77,48 +77,48 @@ class ExtendedLlmAgent(LlmAgent):
7777
original_extended_model: Optional[Any] = Field(
7878
default=None,
7979
exclude=True,
80-
description="Reference to the original ExtendedLiteLlm instance for stats access"
80+
description="Reference to the original SmartFixLiteLlm instance for stats access"
8181
)
8282

8383
@model_validator(mode='after')
8484
def _preserve_extended_model_reference(self):
85-
"""Preserve a reference to the ExtendedLiteLlm instance if provided."""
85+
"""Preserve a reference to the SmartFixLiteLlm instance if provided."""
8686
# Import here to avoid circular imports
8787
try:
88-
from .extended_litellm import ExtendedLiteLlm
88+
from .smartfix_litellm import SmartFixLiteLlm
8989

90-
if isinstance(self.model, ExtendedLiteLlm):
90+
if isinstance(self.model, SmartFixLiteLlm):
9191
# Store reference to the original instance
9292
self.original_extended_model = self.model
93-
debug_log(f"[EXTENDED_AGENT] Preserved reference to ExtendedLiteLlm instance for agent: {self.name}")
93+
debug_log(f"[SMARTFIX_AGENT] Preserved reference to SmartFixLiteLlm instance for agent: {self.name}")
9494

9595
except ImportError:
96-
# ExtendedLiteLlm not available, ignore
96+
# SmartFixLiteLlm not available, ignore
9797
pass
9898

9999
return self
100100

101101
@override
102102
@property
103103
def canonical_model(self):
104-
"""Override to ensure we return the original ExtendedLiteLlm instance."""
105-
# If we have a preserved ExtendedLiteLlm instance, return it
104+
"""Override to ensure we return the original SmartFixLiteLlm instance."""
105+
# If we have a preserved SmartFixLiteLlm instance, return it
106106
if self.original_extended_model is not None:
107107
return self.original_extended_model
108108

109109
# Otherwise, use the parent's canonical_model
110110
return super().canonical_model
111111

112112
def has_extended_model(self) -> bool:
113-
"""Check if this agent is using an ExtendedLiteLlm model."""
113+
"""Check if this agent is using an SmartFixLiteLlm model."""
114114
try:
115-
from .extended_litellm import ExtendedLiteLlm
116-
return isinstance(self.canonical_model, ExtendedLiteLlm)
115+
from .smartfix_litellm import SmartFixLiteLlm
116+
return isinstance(self.canonical_model, SmartFixLiteLlm)
117117
except ImportError:
118118
return False
119119

120120
def get_extended_model(self) -> Optional[Any]:
121-
"""Get the ExtendedLiteLlm instance if available."""
121+
"""Get the SmartFixLiteLlm instance if available."""
122122
if self.has_extended_model():
123123
return self.canonical_model
124124
return None
@@ -127,18 +127,18 @@ def gather_accumulated_stats_dict(self) -> dict:
127127
"""Get accumulated token usage and cost statistics as dictionary.
128128
129129
This method provides programmatic access to the accumulated statistics
130-
from the ExtendedLiteLlm instance being used by this agent.
130+
from the SmartFixLiteLlm instance being used by this agent.
131131
132132
Returns:
133133
dict: Dictionary containing accumulated statistics
134134
135135
Raises:
136-
ValueError: If the agent is not using an ExtendedLiteLlm model.
136+
ValueError: If the agent is not using an SmartFixLiteLlm model.
137137
"""
138138
extended_model = self.get_extended_model()
139139
if extended_model is None:
140140
raise ValueError(
141-
f"Agent '{self.name}' is not using an ExtendedLiteLlm model. "
141+
f"Agent '{self.name}' is not using an SmartFixLiteLlm model. "
142142
"Cannot access accumulated statistics. "
143143
f"Current model type: {type(self.canonical_model).__name__}"
144144
)
@@ -149,18 +149,18 @@ def gather_accumulated_stats(self) -> str:
149149
"""Get accumulated token usage and cost statistics as JSON string.
150150
151151
This method provides programmatic access to the accumulated statistics
152-
from the ExtendedLiteLlm instance being used by this agent.
152+
from the SmartFixLiteLlm instance being used by this agent.
153153
154154
Returns:
155155
str: JSON formatted string containing accumulated statistics
156156
157157
Raises:
158-
ValueError: If the agent is not using an ExtendedLiteLlm model.
158+
ValueError: If the agent is not using an SmartFixLiteLlm model.
159159
"""
160160
extended_model = self.get_extended_model()
161161
if extended_model is None:
162162
raise ValueError(
163-
f"Agent '{self.name}' is not using an ExtendedLiteLlm model. "
163+
f"Agent '{self.name}' is not using an SmartFixLiteLlm model. "
164164
"Cannot access accumulated statistics. "
165165
f"Current model type: {type(self.canonical_model).__name__}"
166166
)
@@ -171,12 +171,12 @@ def reset_accumulated_stats(self) -> None:
171171
"""Reset accumulated statistics to start fresh.
172172
173173
Raises:
174-
ValueError: If the agent is not using an ExtendedLiteLlm model.
174+
ValueError: If the agent is not using an SmartFixLiteLlm model.
175175
"""
176176
extended_model = self.get_extended_model()
177177
if extended_model is None:
178178
raise ValueError(
179-
f"Agent '{self.name}' is not using an ExtendedLiteLlm model. "
179+
f"Agent '{self.name}' is not using an SmartFixLiteLlm model. "
180180
"Cannot reset accumulated statistics. "
181181
f"Current model type: {type(self.canonical_model).__name__}"
182182
)
@@ -197,12 +197,12 @@ def get_accumulated_stats_summary(self) -> dict:
197197
- And more detailed breakdowns
198198
199199
Raises:
200-
ValueError: If the agent is not using an ExtendedLiteLlm model.
200+
ValueError: If the agent is not using an SmartFixLiteLlm model.
201201
"""
202202
extended_model = self.get_extended_model()
203203
if extended_model is None:
204204
raise ValueError(
205-
f"Agent '{self.name}' is not using an ExtendedLiteLlm model. "
205+
f"Agent '{self.name}' is not using an SmartFixLiteLlm model. "
206206
"Cannot access accumulated statistics. "
207207
f"Current model type: {type(self.canonical_model).__name__}"
208208
)
@@ -236,7 +236,7 @@ def get_model_info(self) -> dict:
236236
dict: Dictionary containing model information including:
237237
- model_name: The model name/identifier
238238
- model_type: The class name of the model
239-
- is_extended: Whether it's an ExtendedLiteLlm instance
239+
- is_extended: Whether it's an SmartFixLiteLlm instance
240240
- has_stats: Whether accumulated statistics are available
241241
"""
242242
model = self.canonical_model
@@ -251,11 +251,11 @@ def get_model_info(self) -> dict:
251251
}
252252

253253

254-
# Rebuild the model schema after ExtendedLiteLlm is available
254+
# Rebuild the model schema after SmartFixLiteLlm is available
255255
# This resolves forward references and ensures Pydantic can fully validate the model
256-
if ExtendedLiteLlm is not None:
256+
if SmartFixLiteLlm is not None:
257257
try:
258-
ExtendedLlmAgent.model_rebuild()
258+
SmartFixLlmAgent.model_rebuild()
259259
except Exception:
260260
# If rebuild fails for any reason, just continue
261261
# The class will still work, just without perfect type validation

test/test_extended_litellm.py renamed to test/test_smartfix_litellm.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
#
2020

2121
"""
22-
Unit tests for ExtendedLiteLlm and TokenCostAccumulator classes.
22+
Unit tests for SmartFixLiteLlm and TokenCostAccumulator classes.
2323
2424
This module tests the extended LiteLLM functionality including:
2525
- Token cost accumulation and statistics gathering
@@ -44,7 +44,7 @@
4444
_ = get_config(testing=True)
4545

4646
# Import the classes under test AFTER config initialization
47-
from src.extensions.extended_litellm import ExtendedLiteLlm, TokenCostAccumulator # noqa: E402
47+
from src.extensions.smartfix_litellm import SmartFixLiteLlm, TokenCostAccumulator # noqa: E402
4848

4949

5050
class TestTokenCostAccumulator(unittest.TestCase):
@@ -262,21 +262,21 @@ def test_reset(self):
262262
self.assertEqual(self.accumulator.call_count, 0)
263263

264264

265-
class TestExtendedLiteLlm(unittest.TestCase):
266-
"""Test cases for ExtendedLiteLlm class."""
265+
class TestSmartFixLiteLlm(unittest.TestCase):
266+
"""Test cases for SmartFixLiteLlm class."""
267267

268268
def setUp(self):
269269
"""Set up test fixtures before each test method."""
270270
# Mock LiteLlm initialization to avoid dependencies
271271
with patch('litellm.completion'):
272-
self.extended_model = ExtendedLiteLlm(model="test-model")
272+
self.extended_model = SmartFixLiteLlm(model="test-model")
273273

274274
def test_initialization(self):
275-
"""Test that ExtendedLiteLlm initializes correctly."""
275+
"""Test that SmartFixLiteLlm initializes correctly."""
276276
self.assertEqual(self.extended_model.model, "test-model")
277277
self.assertIsInstance(self.extended_model.cost_accumulator, TokenCostAccumulator)
278278

279-
@patch('src.extensions.extended_litellm.debug_log')
279+
@patch('src.extensions.smartfix_litellm.debug_log')
280280
def test_gather_accumulated_stats_dict(self, mock_debug_log):
281281
"""Test statistics dictionary generation."""
282282
# Add some usage to the accumulator
@@ -299,7 +299,7 @@ def test_gather_accumulated_stats_dict(self, mock_debug_log):
299299
self.assertIn('cost_analysis', stats)
300300
self.assertIn('averages', stats)
301301

302-
@patch('src.extensions.extended_litellm.debug_log')
302+
@patch('src.extensions.smartfix_litellm.debug_log')
303303
def test_gather_accumulated_stats_json(self, mock_debug_log):
304304
"""Test JSON statistics generation."""
305305
# Add some usage to the accumulator
@@ -321,7 +321,7 @@ def test_gather_accumulated_stats_json(self, mock_debug_log):
321321
self.assertEqual(stats_dict['call_count'], 1)
322322
self.assertIn('token_usage', stats_dict)
323323

324-
@patch('src.extensions.extended_litellm.debug_log')
324+
@patch('src.extensions.smartfix_litellm.debug_log')
325325
def test_reset_accumulated_stats(self, mock_debug_log):
326326
"""Test that reset clears accumulated statistics."""
327327
# Add some usage first
@@ -349,15 +349,15 @@ def test_reset_accumulated_stats(self, mock_debug_log):
349349
mock_debug_log.assert_called_with("Accumulated statistics have been reset.")
350350

351351

352-
class TestExtendedLiteLlmIntegration(unittest.TestCase):
353-
"""Integration tests for ExtendedLiteLlm functionality."""
352+
class TestSmartFixLiteLlmIntegration(unittest.TestCase):
353+
"""Integration tests for SmartFixLiteLlm functionality."""
354354

355355
@patch('litellm.completion')
356-
@patch('src.extensions.extended_litellm.debug_log')
356+
@patch('src.extensions.smartfix_litellm.debug_log')
357357
def test_cost_accumulator_integration(self, mock_debug_log, mock_completion):
358-
"""Test that cost accumulator integrates properly with ExtendedLiteLlm."""
359-
# Create a real ExtendedLiteLlm instance
360-
model = ExtendedLiteLlm(model="test-integration-model")
358+
"""Test that cost accumulator integrates properly with SmartFixLiteLlm."""
359+
# Create a real SmartFixLiteLlm instance
360+
model = SmartFixLiteLlm(model="test-integration-model")
361361

362362
# Verify it has a cost accumulator
363363
self.assertIsInstance(model.cost_accumulator, TokenCostAccumulator)

0 commit comments

Comments
 (0)