Skip to content

Commit dc02da4

Browse files
Merge pull request #36 from Contrast-Security-OSS/AIML-51_update_merge_handler_for_copilot
AIML-51 Update merge handler for copilot
2 parents dc172c8 + 131b71c commit dc02da4

File tree

7 files changed

+111
-13
lines changed

7 files changed

+111
-13
lines changed

src/closed_handler.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from src import contrast_api
2626
from src.config import get_config # Using get_config function instead of direct import
2727
from src.utils import debug_log, extract_remediation_id_from_branch, extract_remediation_id_from_labels, log
28+
from src.git_handler import extract_issue_number_from_branch
2829
import src.telemetry_handler as telemetry_handler
2930

3031
def handle_closed_pr():
@@ -65,18 +66,27 @@ def handle_closed_pr():
6566

6667
debug_log(f"Branch name: {branch_name}")
6768

69+
labels = pull_request.get("labels", [])
70+
6871
# Extract remediation ID from branch name or PR labels
6972
remediation_id = None
7073

7174
# Check if this is a branch created by external agent (e.g., GitHub Copilot)
7275
if branch_name.startswith("copilot/fix"):
7376
debug_log("Branch appears to be created by external agent. Extracting remediation ID from PR labels.")
74-
# Get labels from the PR
75-
labels = pull_request.get("labels", [])
7677
remediation_id = extract_remediation_id_from_labels(labels)
78+
# Extract GitHub issue number from branch name
79+
issue_number = extract_issue_number_from_branch(branch_name)
80+
if issue_number:
81+
telemetry_handler.update_telemetry("additionalAttributes.externalIssueNumber", issue_number)
82+
debug_log(f"Extracted external issue number from branch name: {issue_number}")
83+
else:
84+
debug_log(f"Could not extract issue number from branch name: {branch_name}")
85+
telemetry_handler.update_telemetry("additionalAttributes.codingAgent", "EXTERNAL-COPILOT")
7786
else:
7887
# Use original method for branches created by SmartFix
7988
remediation_id = extract_remediation_id_from_branch(branch_name)
89+
telemetry_handler.update_telemetry("additionalAttributes.codingAgent", "INTERNAL-SMARTFIX")
8090

8191
if not remediation_id:
8292
if branch_name.startswith("copilot/fix"):
@@ -90,7 +100,6 @@ def handle_closed_pr():
90100
telemetry_handler.update_telemetry("additionalAttributes.remediationId", remediation_id)
91101

92102
# Try to extract vulnerability UUID from PR labels
93-
labels = pull_request.get("labels", [])
94103
vuln_uuid = "unknown"
95104

96105
for label in labels:

src/external_coding_agent.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ def generate_fixes(self, vuln_uuid: str, remediation_id: str, vuln_title: str) -
5353
return False
5454

5555
log(f"\n::group::--- Using External Coding Agent ({self.config.CODING_AGENT}) ---")
56-
56+
telemetry_handler.update_telemetry("additionalAttributes.codingAgent", "EXTERNAL-COPILOT")
57+
5758
# Hard-coded vulnerability label for now, will be passed as argument later
5859
vulnerability_label = f"contrast-vuln-id:VULN-{vuln_uuid}"
5960
remediation_label = f"smartfix-id:{remediation_id}"
@@ -75,11 +76,10 @@ def generate_fixes(self, vuln_uuid: str, remediation_id: str, vuln_title: str) -
7576
log(f"Failed to create issue with labels {vulnerability_label}, {remediation_label}", is_error=True)
7677
error_exit(remediation_id, FailureCategory.AGENT_FAILURE.value)
7778

78-
telemetry_handler.update_telemetry("additionalAttributes.githubIssueNumber", issue_number)
79+
telemetry_handler.update_telemetry("additionalAttributes.externalIssueNumber", issue_number)
7980

8081
# Poll for PR creation by the external agent
8182
log(f"Waiting for external agent to create a PR for issue #{issue_number}")
82-
telemetry_handler.update_telemetry("additionalAttributes.codingAgent", "EXTERNAL")
8383

8484
# Poll for a PR to be created by the external agent (100 attempts, 5 seconds apart = ~8.3 minutes max)
8585
pr_info = self._poll_for_pr(issue_number, remediation_id, vulnerability_label, remediation_label, max_attempts=100, sleep_seconds=5)

src/git_handler.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
import os
2121
import json
2222
import subprocess
23-
from typing import List
23+
import re
24+
from typing import List, Optional
2425
from src.utils import run_command, debug_log, log, error_exit
2526
from src.contrast_api import FailureCategory
2627
from src.config import get_config
@@ -648,6 +649,37 @@ def find_open_pr_for_issue(issue_number: int) -> dict:
648649
log(f"Error searching for PRs related to issue #{issue_number}: {e}", is_error=True)
649650
return None
650651

652+
def extract_issue_number_from_branch(branch_name: str) -> Optional[int]:
653+
"""
654+
Extracts the GitHub issue number from a branch name with format 'copilot/fix-<issue_number>'.
655+
656+
Args:
657+
branch_name: The branch name to extract the issue number from
658+
659+
Returns:
660+
Optional[int]: The issue number if found and valid, None otherwise
661+
"""
662+
if not branch_name:
663+
return None
664+
665+
# Use regex to match the exact pattern: copilot/fix-<number>
666+
# This ensures we only match the expected format and extract just the number
667+
pattern = r'^copilot/fix-(\d+)$'
668+
match = re.match(pattern, branch_name)
669+
670+
if match:
671+
try:
672+
issue_number = int(match.group(1))
673+
# Validate that it's a positive number (GitHub issue numbers start from 1)
674+
if issue_number > 0:
675+
return issue_number
676+
except ValueError:
677+
# This shouldn't happen since \d+ only matches digits, but being safe
678+
debug_log(f"Failed to convert extracted issue number '{match.group(1)}' to int")
679+
pass
680+
681+
return None
682+
651683
def add_labels_to_pr(pr_number: int, labels: List[str]) -> bool:
652684
"""
653685
Add labels to an existing pull request.

src/main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@ def main():
351351
contrast_api.send_telemetry_data()
352352
continue # Skip the built-in SmartFix code and PR creation
353353

354+
telemetry_handler.update_telemetry("additionalAttributes.codingAgent", "INTERNAL-SMARTFIX")
355+
354356
# --- Run AI Fix Agent (SmartFix) ---
355357
ai_fix_summary_full = agent_handler.run_ai_fix_agent(
356358
config.REPO_ROOT, fix_system_prompt, fix_user_prompt, remediation_id

src/merge_handler.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424
# Import from src package to ensure correct module resolution
2525
from src import contrast_api
2626
from src.config import get_config # Using get_config function instead of direct import
27-
from src.utils import debug_log, extract_remediation_id_from_branch, log
27+
from src.utils import debug_log, extract_remediation_id_from_branch, extract_remediation_id_from_labels, log
28+
from src.git_handler import extract_issue_number_from_branch
2829
import src.telemetry_handler as telemetry_handler
2930

3031
def handle_merged_pr():
@@ -65,18 +66,27 @@ def handle_merged_pr():
6566

6667
debug_log(f"Branch name: {branch_name}")
6768

69+
labels = pull_request.get("labels", [])
70+
6871
# Extract remediation ID from branch name or PR labels
6972
remediation_id = None
7073

7174
# Check if this is a branch created by external agent (e.g., GitHub Copilot)
7275
if branch_name.startswith("copilot/fix"):
7376
debug_log("Branch appears to be created by external agent. Extracting remediation ID from PR labels.")
74-
# Get labels from the PR
75-
labels = pull_request.get("labels", [])
7677
remediation_id = extract_remediation_id_from_labels(labels)
78+
# Extract GitHub issue number from branch name
79+
issue_number = extract_issue_number_from_branch(branch_name)
80+
if issue_number:
81+
telemetry_handler.update_telemetry("additionalAttributes.externalIssueNumber", issue_number)
82+
debug_log(f"Extracted external issue number from branch name: {issue_number}")
83+
else:
84+
debug_log(f"Could not extract issue number from branch name: {branch_name}")
85+
telemetry_handler.update_telemetry("additionalAttributes.codingAgent", "EXTERNAL-COPILOT")
7786
else:
7887
# Use original method for branches created by SmartFix
7988
remediation_id = extract_remediation_id_from_branch(branch_name)
89+
telemetry_handler.update_telemetry("additionalAttributes.codingAgent", "INTERNAL-SMARTFIX")
8090

8191
if not remediation_id:
8292
if branch_name.startswith("copilot/fix"):
@@ -90,7 +100,6 @@ def handle_merged_pr():
90100
telemetry_handler.update_telemetry("additionalAttributes.remediationId", remediation_id)
91101

92102
# Try to extract vulnerability UUID from PR labels
93-
labels = pull_request.get("labels", [])
94103
vuln_uuid = "unknown"
95104

96105
for label in labels:

test/test_external_coding_agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def test_generate_fixes_with_external_agent_pr_created(self, mock_log, mock_debu
135135
mock_log.assert_any_call(f"External agent created PR #123 at https://github.com/owner/repo/pull/123")
136136

137137
# Verify telemetry updates
138-
mock_update_telemetry.assert_any_call("additionalAttributes.codingAgent", "EXTERNAL")
138+
mock_update_telemetry.assert_any_call("additionalAttributes.codingAgent", "EXTERNAL-COPILOT")
139139
mock_update_telemetry.assert_any_call("resultInfo.prCreated", True)
140140
mock_update_telemetry.assert_any_call("additionalAttributes.prStatus", "OPEN")
141141
mock_update_telemetry.assert_any_call("additionalAttributes.prNumber", 123)
@@ -187,7 +187,7 @@ def test_generate_fixes_with_external_agent_pr_timeout(self, mock_log, mock_debu
187187
mock_log.assert_any_call("External agent failed to create a PR within the timeout period", is_error=True)
188188

189189
# Verify telemetry updates
190-
mock_update_telemetry.assert_any_call("additionalAttributes.codingAgent", "EXTERNAL")
190+
mock_update_telemetry.assert_any_call("additionalAttributes.codingAgent", "EXTERNAL-COPILOT")
191191
mock_update_telemetry.assert_any_call("resultInfo.prCreated", False)
192192
mock_update_telemetry.assert_any_call("resultInfo.failureReason", "PR creation timeout")
193193
mock_update_telemetry.assert_any_call("resultInfo.failureCategory", "AGENT_FAILURE")

test/test_git_handler.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,52 @@ def test_add_labels_to_pr_success(self, mock_debug_log, mock_log, mock_run_comma
388388
mock_log.assert_any_call("Adding labels to PR #123: ['contrast-vuln-id:VULN-12345', 'smartfix-id:remediation-67890']")
389389
mock_log.assert_any_call("Successfully added labels to PR #123: ['contrast-vuln-id:VULN-12345', 'smartfix-id:remediation-67890']")
390390

391+
def test_extract_issue_number_from_branch_success(self):
392+
"""Test extracting issue number from valid copilot branch name"""
393+
# Test cases with valid branch names
394+
test_cases = [
395+
("copilot/fix-123", 123),
396+
("copilot/fix-1", 1),
397+
("copilot/fix-999999", 999999),
398+
("copilot/fix-42", 42),
399+
]
400+
401+
for branch_name, expected_issue_number in test_cases:
402+
with self.subTest(branch_name=branch_name):
403+
result = git_handler.extract_issue_number_from_branch(branch_name)
404+
self.assertEqual(result, expected_issue_number)
405+
406+
def test_extract_issue_number_from_branch_invalid(self):
407+
"""Test extracting issue number from invalid branch names"""
408+
# Test cases with invalid branch names
409+
invalid_branches = [
410+
"main", # Wrong branch name
411+
"feature/new-feature", # Wrong branch name
412+
"copilot/fix-", # Missing issue number
413+
"copilot/fix-abc", # Non-numeric issue number
414+
"copilot/fix-123abc", # Invalid format
415+
"copilot/fix-123-extra", # Extra parts
416+
"smartfix/remediation-123", # Different prefix
417+
"", # Empty string
418+
]
419+
420+
for branch_name in invalid_branches:
421+
with self.subTest(branch_name=branch_name):
422+
result = git_handler.extract_issue_number_from_branch(branch_name)
423+
self.assertIsNone(result)
424+
425+
def test_extract_issue_number_from_branch_edge_cases(self):
426+
"""Test edge cases for extracting issue number from branch name"""
427+
# Test edge cases
428+
edge_cases = [
429+
("copilot/fix-2147483647", 2147483647), # Large number (max 32-bit int)
430+
]
431+
432+
for branch_name, expected_issue_number in edge_cases:
433+
with self.subTest(branch_name=branch_name):
434+
result = git_handler.extract_issue_number_from_branch(branch_name)
435+
self.assertEqual(result, expected_issue_number)
436+
391437
@patch('src.git_handler.ensure_label')
392438
@patch('src.git_handler.run_command')
393439
@patch('src.git_handler.log')

0 commit comments

Comments
 (0)