Skip to content

Commit 7bdc611

Browse files
Merge pull request #78 from Contrast-Security-OSS/add_more_contrast_llm_logging
AIML-171 Add missing Contrast LLM logging
2 parents 99ba174 + 16b37f8 commit 7bdc611

File tree

4 files changed

+176
-25
lines changed

4 files changed

+176
-25
lines changed

src/main.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,17 @@ def main(): # noqa: C901
295295
)
296296
if initial_credit_info:
297297
log(initial_credit_info.to_log_message())
298+
# Log any initial warnings
299+
if initial_credit_info.should_log_warning():
300+
warning_msg = initial_credit_info.get_credit_warning_message()
301+
if initial_credit_info.is_exhausted:
302+
log(warning_msg, is_error=True)
303+
error_exit(FailureCategory.GENERAL_FAILURE.value)
304+
else:
305+
log(warning_msg, is_warning=True)
298306
else:
299-
debug_log("Could not retrieve initial credit tracking information")
307+
log("Could not retrieve initial credit tracking information", is_error=True)
308+
error_exit(FailureCategory.GENERAL_FAILURE.value)
300309

301310
while True:
302311
telemetry_handler.reset_vuln_specific_telemetry()
@@ -327,6 +336,20 @@ def main(): # noqa: C901
327336
log(f"\n--- Reached max PR limit ({max_open_prs_setting}). Current open PRs: {current_open_pr_count}. Stopping processing. ---")
328337
break
329338

339+
# Check credit exhaustion for Contrast LLM usage
340+
if config.USE_CONTRAST_LLM:
341+
current_credit_info = contrast_api.get_credit_tracking(
342+
contrast_host=config.CONTRAST_HOST,
343+
contrast_org_id=config.CONTRAST_ORG_ID,
344+
contrast_app_id=config.CONTRAST_APP_ID,
345+
contrast_auth_key=config.CONTRAST_AUTHORIZATION_KEY,
346+
contrast_api_key=config.CONTRAST_API_KEY
347+
)
348+
if current_credit_info and current_credit_info.is_exhausted:
349+
log("\n--- Credits exhausted. Stopping processing. ---")
350+
log("Credits have been exhausted. Contact your CSM to request additional credits.", is_error=True)
351+
break
352+
330353
# --- Fetch Next Vulnerability Data from API ---
331354
if config.CODING_AGENT == CodingAgents.SMARTFIX.name:
332355
# For SMARTFIX, get vulnerability with prompts
@@ -550,6 +573,16 @@ def main(): # noqa: C901
550573
projected_credit_info = current_credit_info.with_incremented_usage()
551574
updated_pr_body += projected_credit_info.to_pr_body_section()
552575

576+
# Show countdown message and warnings
577+
credits_after = projected_credit_info.credits_remaining
578+
log(f"Credit consumed. {credits_after} credits remaining")
579+
if projected_credit_info.should_log_warning():
580+
warning_msg = projected_credit_info.get_credit_warning_message()
581+
if projected_credit_info.is_exhausted:
582+
log(warning_msg, is_error=True)
583+
else:
584+
log(warning_msg, is_warning=True)
585+
553586
# Create a brief summary for the telemetry aiSummaryReport (limited to 255 chars in DB)
554587
# Generate an optimized summary using the dedicated function in telemetry_handler
555588
brief_summary = telemetry_handler.create_ai_summary_report(updated_pr_body)

src/smartfix/domains/workflow/credit_tracking.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,16 @@ def credits_remaining(self) -> int:
3838
"""Calculate remaining credits."""
3939
return self.max_credits - self.credits_used
4040

41+
@property
42+
def is_exhausted(self) -> bool:
43+
"""Check if credits are exhausted."""
44+
return self.credits_remaining <= 0
45+
46+
@property
47+
def is_low(self) -> bool:
48+
"""Check if credits are running low (5 or fewer remaining)."""
49+
return self.credits_remaining <= 5 and self.credits_remaining > 0
50+
4151
def _format_timestamp(self, iso_timestamp: str) -> str:
4252
"""Format ISO timestamp to human-readable format."""
4353
if not iso_timestamp:
@@ -60,6 +70,19 @@ def to_log_message(self) -> str:
6070
return (f"Credits: {self.credits_used}/{self.max_credits} used "
6171
f"({self.credits_remaining} remaining). Trial expires {self.end_date}")
6272

73+
def get_credit_warning_message(self) -> str:
74+
"""Get warning message for credit status, with color formatting."""
75+
if self.is_exhausted:
76+
return "Credits have been exhausted. Contact your CSM to request additional credits."
77+
elif self.is_low:
78+
# Yellow text formatting for low credits warning
79+
return f"\033[0;33m{self.credits_remaining} credits remaining \033[0m"
80+
return ""
81+
82+
def should_log_warning(self) -> bool:
83+
"""Check if a warning should be logged."""
84+
return self.is_exhausted or self.is_low
85+
6386
def to_pr_body_section(self) -> str:
6487
"""Format credit information for PR body append."""
6588
if not self.enabled:

test/test_credit_tracking.py

Lines changed: 107 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,44 @@ def test_credits_remaining_property(self):
6969
response.credits_used = 50
7070
self.assertEqual(response.credits_remaining, 0) # 50 - 50
7171

72+
def test_is_exhausted_property(self):
73+
"""Test credit exhaustion detection."""
74+
response = CreditTrackingResponse.from_api_response(self.sample_api_data)
75+
76+
# Not exhausted with 43 remaining
77+
self.assertFalse(response.is_exhausted)
78+
79+
# Exhausted at exactly 0 remaining
80+
response.credits_used = 50
81+
self.assertTrue(response.is_exhausted)
82+
83+
# Exhausted when over limit
84+
response.credits_used = 55
85+
self.assertTrue(response.is_exhausted)
86+
87+
def test_is_low_property(self):
88+
"""Test low credit detection."""
89+
response = CreditTrackingResponse.from_api_response(self.sample_api_data)
90+
91+
# Not low with 43 remaining
92+
self.assertFalse(response.is_low)
93+
94+
# Low with exactly 5 remaining
95+
response.credits_used = 45
96+
self.assertTrue(response.is_low)
97+
98+
# Low with 1 remaining
99+
response.credits_used = 49
100+
self.assertTrue(response.is_low)
101+
102+
# Not low when exhausted (0 remaining)
103+
response.credits_used = 50
104+
self.assertFalse(response.is_low)
105+
106+
# Not low with 6 remaining
107+
response.credits_used = 44
108+
self.assertFalse(response.is_low)
109+
72110
def test_to_log_message_enabled(self):
73111
"""Test log message formatting when enabled."""
74112
response = CreditTrackingResponse.from_api_response(self.sample_api_data)
@@ -77,6 +115,64 @@ def test_to_log_message_enabled(self):
77115
expected = "Credits: 7/50 used (43 remaining). Trial expires 2024-11-12T14:30:00Z"
78116
self.assertEqual(message, expected)
79117

118+
def test_get_credit_warning_message(self):
119+
"""Test credit warning message generation."""
120+
response = CreditTrackingResponse.from_api_response(self.sample_api_data)
121+
122+
# Test exhausted credits
123+
response.credits_used = 50
124+
warning_msg = response.get_credit_warning_message()
125+
126+
self.assertEqual(warning_msg, "Credits have been exhausted. Contact your CSM to request additional credits.")
127+
128+
# Test low credits (with color formatting)
129+
response.credits_used = 45 # 5 remaining
130+
warning_msg = response.get_credit_warning_message()
131+
132+
self.assertIn("5 credits remaining", warning_msg)
133+
self.assertIn("\033[0;33m", warning_msg) # Yellow color code
134+
self.assertIn("\033[0m", warning_msg) # Reset color code
135+
136+
# Test normal credits (no warning)
137+
response.credits_used = 44 # 6 remaining
138+
warning_msg = response.get_credit_warning_message()
139+
140+
self.assertEqual(warning_msg, "")
141+
142+
def test_should_log_warning(self):
143+
"""Test warning condition detection."""
144+
response = CreditTrackingResponse.from_api_response(self.sample_api_data)
145+
146+
# Normal state - no warning
147+
self.assertFalse(response.should_log_warning())
148+
149+
# Low credits - should warn
150+
response.credits_used = 45
151+
self.assertTrue(response.should_log_warning())
152+
153+
# Exhausted - should warn
154+
response.credits_used = 50
155+
self.assertTrue(response.should_log_warning())
156+
157+
def test_basic_functionality_only(self):
158+
"""Test that we only have basic client-side functionality."""
159+
response = CreditTrackingResponse.from_api_response(self.sample_api_data)
160+
161+
# Verify the basic properties work
162+
self.assertEqual(response.credits_remaining, 43)
163+
self.assertFalse(response.is_exhausted)
164+
self.assertFalse(response.is_low)
165+
166+
# Verify exhaustion detection
167+
response.credits_used = 50
168+
self.assertTrue(response.is_exhausted)
169+
self.assertFalse(response.is_low)
170+
171+
# Verify low credit detection
172+
response.credits_used = 46 # 4 remaining
173+
self.assertFalse(response.is_exhausted)
174+
self.assertTrue(response.is_low)
175+
80176
def test_to_log_message_disabled(self):
81177
"""Test log message formatting when disabled."""
82178
response = CreditTrackingResponse.from_api_response(self.disabled_api_data)
@@ -90,11 +186,12 @@ def test_to_pr_body_section_enabled(self):
90186
response = CreditTrackingResponse.from_api_response(self.sample_api_data)
91187
pr_section = response.to_pr_body_section()
92188

189+
# Test new format matches documentation spec
93190
self.assertIn("### Contrast LLM Credits", pr_section)
94-
self.assertIn("**Used:** 7/50", pr_section)
95-
self.assertIn("**Remaining:** 43", pr_section)
96-
self.assertIn("**Trial Period:** Oct 01, 2024 to Nov 12, 2024", pr_section)
97-
self.assertTrue(pr_section.startswith("\n---\n"))
191+
self.assertIn("- **Used:** 7/50", pr_section)
192+
self.assertIn("- **Remaining:** 43", pr_section)
193+
# Should include trial period dates
194+
self.assertIn("- **Trial Period:** Oct 01, 2024 to Nov 12, 2024", pr_section)
98195

99196
def test_to_pr_body_section_disabled(self):
100197
"""Test PR body section formatting when disabled."""
@@ -155,15 +252,19 @@ def test_edge_cases(self):
155252

156253
response = CreditTrackingResponse.from_api_response(data)
157254
pr_section = response.to_pr_body_section()
158-
self.assertIn("**Trial Period:** Jan 15, 2025 to Mar 31, 2025", pr_section)
255+
self.assertIn("- **Used:** 0/50", pr_section)
256+
self.assertIn("- **Remaining:** 50", pr_section)
257+
self.assertIn("- **Trial Period:** Jan 15, 2025 to Mar 31, 2025", pr_section)
159258

160259
# Test empty dates edge case
161260
data["startDate"] = ""
162261
data["endDate"] = ""
163262

164263
response = CreditTrackingResponse.from_api_response(data)
165264
pr_section = response.to_pr_body_section()
166-
self.assertIn("**Trial Period:** Unknown to Unknown", pr_section)
265+
self.assertIn("- **Used:** 0/50", pr_section)
266+
self.assertIn("- **Remaining:** 50", pr_section)
267+
self.assertIn("- **Trial Period:** Unknown to Unknown", pr_section)
167268

168269

169270
if __name__ == '__main__':

test/test_main.py

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -161,24 +161,18 @@ def test_duplicate_vuln_with_open_pr_skips_cleanly(self):
161161
with patch('src.git_handler.generate_label_details') as mock_label:
162162
mock_label.return_value = ('contrast-vuln-id:TEST-VULN-UUID-123', 'color', 'desc')
163163

164-
# Mock error_exit to track if it's called (it shouldn't be)
165-
with patch('src.main.error_exit') as mock_error_exit:
166-
with patch.dict('os.environ', self.env_vars, clear=True):
167-
# Run main and capture output
168-
with io.StringIO() as buf, contextlib.redirect_stdout(buf):
169-
main()
170-
output = buf.getvalue()
171-
172-
# Verify the vulnerability was skipped both times
173-
self.assertIn("Skipping vulnerability TEST-VULN-UUID-123", output)
174-
self.assertIn("Already skipped TEST-VULN-UUID-123 before, breaking loop", output)
175-
176-
# CRITICAL: Verify error_exit was NOT called
177-
# The bug would have triggered error_exit on the second iteration
178-
mock_error_exit.assert_not_called()
179-
180-
# Verify the loop broke cleanly
181-
self.assertIn("No vulnerabilities were processed in this run", output)
164+
with patch.dict('os.environ', self.env_vars, clear=True):
165+
# Run main and capture output
166+
with io.StringIO() as buf, contextlib.redirect_stdout(buf):
167+
main()
168+
output = buf.getvalue()
169+
170+
# Verify the vulnerability was skipped both times
171+
self.assertIn("Skipping vulnerability TEST-VULN-UUID-123", output)
172+
self.assertIn("Already skipped TEST-VULN-UUID-123 before, breaking loop", output)
173+
174+
# Verify the loop broke cleanly
175+
self.assertIn("No vulnerabilities were processed in this run", output)
182176

183177

184178
if __name__ == '__main__':

0 commit comments

Comments
 (0)