Skip to content

Commit 2edd07f

Browse files
slister1001 and Copilot authored
[evaluation] refactor: red_team module clean up (#42292)
* init * updates * further refactoring * run black reformatter * fix tests * run black code formatter * Update sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_result_processor.py Co-authored-by: Copilot <[email protected]> * updates --------- Co-authored-by: Copilot <[email protected]>
1 parent 39c384b commit 2edd07f

File tree

12 files changed

+4022
-3267
lines changed

12 files changed

+4022
-3267
lines changed

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_evaluation_processor.py

Lines changed: 365 additions & 0 deletions
Large diffs are not rendered by default.

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_mlflow_integration.py

Lines changed: 322 additions & 0 deletions
Large diffs are not rendered by default.

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_orchestrator_manager.py

Lines changed: 649 additions & 0 deletions
Large diffs are not rendered by default.

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_red_team.py

Lines changed: 656 additions & 3057 deletions
Large diffs are not rendered by default.

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_result_processor.py

Lines changed: 610 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,37 @@
11
# ---------------------------------------------------------
22
# Copyright (c) Microsoft Corporation. All rights reserved.
33
# ---------------------------------------------------------
4+
"""
5+
Utility modules for Red Team Agent.
6+
7+
This package provides centralized utilities for retry logic, file operations,
8+
progress tracking, and exception handling used across red team components.
9+
"""
10+
11+
from .retry_utils import RetryManager, create_standard_retry_manager, create_retry_decorator
12+
from .file_utils import FileManager, create_file_manager
13+
from .progress_utils import ProgressManager, create_progress_manager
14+
from .exception_utils import (
15+
ExceptionHandler,
16+
RedTeamError,
17+
ErrorCategory,
18+
ErrorSeverity,
19+
create_exception_handler,
20+
exception_context,
21+
)
22+
23+
# Public API of the utils package, listed in the order of the modules that
# provide each name.
__all__ = [
    # retry_utils
    "RetryManager",
    "create_standard_retry_manager",
    "create_retry_decorator",
    # file_utils
    "FileManager",
    "create_file_manager",
    # progress_utils
    "ProgressManager",
    "create_progress_manager",
    # exception_utils
    "ExceptionHandler",
    "RedTeamError",
    "ErrorCategory",
    "ErrorSeverity",
    "create_exception_handler",
    "exception_context",
]
Lines changed: 345 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,345 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
"""
5+
Exception handling utilities for Red Team Agent.
6+
7+
This module provides centralized exception handling, error categorization,
8+
and error reporting utilities for red team operations.
9+
"""
10+
11+
import logging
12+
import traceback
13+
import asyncio
14+
from typing import Optional, Any, Dict, Union
15+
from enum import Enum
16+
17+
18+
class ErrorCategory(Enum):
    """Kinds of failure a red team operation can run into.

    Each member's value is the lowercase string used when the category is
    rendered in messages and logs.
    """

    # Transport-level problems
    NETWORK = "network"
    # Credential / identity problems
    AUTHENTICATION = "authentication"
    # Setup or settings problems
    CONFIGURATION = "configuration"
    # Problems while handling data payloads
    DATA_PROCESSING = "data_processing"
    # Problems raised by an orchestrator
    ORCHESTRATOR = "orchestrator"
    # Problems raised during evaluation
    EVALUATION = "evaluation"
    # File system problems
    FILE_IO = "file_io"
    # An operation ran out of time
    TIMEOUT = "timeout"
    # Anything that fits none of the categories above
    UNKNOWN = "unknown"
30+
31+
32+
class ErrorSeverity(Enum):
    """How serious an error is, from least to most severe."""

    # Warning level, operation can continue
    LOW = "low"
    # Error level, task failed but scan can continue
    MEDIUM = "medium"
    # Critical error, scan should be aborted
    HIGH = "high"
    # Unrecoverable error
    FATAL = "fatal"
39+
40+
41+
class RedTeamError(Exception):
    """Root exception type for Red Team operations.

    Besides the message, an instance carries the error category, its
    severity, an arbitrary context mapping and, optionally, the exception it
    was created from.
    """

    def __init__(
        self,
        message: str,
        category: ErrorCategory = ErrorCategory.UNKNOWN,
        severity: ErrorSeverity = ErrorSeverity.MEDIUM,
        context: Optional[Dict[str, Any]] = None,
        original_exception: Optional[Exception] = None,
    ):
        """Create the error.

        :param message: Human-readable description; also passed to ``Exception``
        :param category: Error category (defaults to ``UNKNOWN``)
        :param severity: Error severity (defaults to ``MEDIUM``)
        :param context: Extra details; a falsy value is replaced by a new empty dict
        :param original_exception: The exception this error wraps, if any
        """
        super().__init__(message)
        # None and an empty mapping both end up as a fresh dict.
        details = context if context else {}
        self.message, self.category, self.severity = message, category, severity
        self.context = details
        self.original_exception = original_exception
58+
59+
60+
class ExceptionHandler:
    """Centralized exception handling for Red Team operations.

    The handler maps arbitrary exceptions to an :class:`ErrorCategory` and an
    :class:`ErrorSeverity`, wraps them in a :class:`RedTeamError`, logs them,
    and keeps a per-category counter that drives :meth:`should_abort_scan`.
    """

    def __init__(self, logger: Optional[logging.Logger] = None):
        """Initialize exception handler.

        :param logger: Logger instance for error reporting; when omitted, the
            logger named after this module is used
        """
        self.logger = logger or logging.getLogger(__name__)
        # One counter per category, all starting at zero.
        self.error_counts: Dict[ErrorCategory, int] = {category: 0 for category in ErrorCategory}

    def categorize_exception(self, exception: Exception) -> ErrorCategory:
        """Categorize an exception based on its type and message.

        Checks run in this order: HTTP status code on ``exception.response``,
        network exception types, timeout types, ``OSError``, and finally
        keyword matching on the lowercased message.

        :param exception: The exception to categorize
        :return: The appropriate error category
        """
        import httpx
        import httpcore

        # HTTP status code specific errors. This must run before the network
        # type check: httpx.HTTPStatusError derives from httpx.HTTPError, so
        # testing the broad tuple first would classify every 401/403 response
        # as NETWORK and leave the branches below unreachable for httpx errors.
        response = getattr(exception, "response", None)
        status_code = getattr(response, "status_code", None)
        if isinstance(status_code, int):
            if 500 <= status_code < 600:
                return ErrorCategory.NETWORK
            if status_code == 401:
                return ErrorCategory.AUTHENTICATION
            if status_code == 403:
                return ErrorCategory.CONFIGURATION

        # Network-related errors
        network_exceptions = (
            httpx.ConnectTimeout,
            httpx.ReadTimeout,
            httpx.ConnectError,
            httpx.HTTPError,
            httpx.TimeoutException,
            httpcore.ReadTimeout,
            ConnectionError,
            ConnectionRefusedError,
            ConnectionResetError,
        )

        if isinstance(exception, network_exceptions):
            return ErrorCategory.NETWORK

        # Timeout errors (separate from network to handle asyncio.TimeoutError).
        # Checked before OSError because the builtin TimeoutError is an OSError.
        if isinstance(exception, (TimeoutError, asyncio.TimeoutError)):
            return ErrorCategory.TIMEOUT

        # File I/O errors. IOError is an alias of OSError, and FileNotFoundError
        # and PermissionError are subclasses, so OSError alone covers them.
        if isinstance(exception, OSError):
            return ErrorCategory.FILE_IO

        # String-based categorization
        message = str(exception).lower()

        # First category (in declaration order) with a keyword found in the
        # message wins.
        keyword_mappings = {
            ErrorCategory.AUTHENTICATION: ["authentication", "unauthorized"],
            ErrorCategory.CONFIGURATION: ["configuration", "config"],
            ErrorCategory.ORCHESTRATOR: ["orchestrator"],
            ErrorCategory.EVALUATION: ["evaluation", "evaluate", "model_error"],
            ErrorCategory.DATA_PROCESSING: ["data", "json"],
        }

        for category, keywords in keyword_mappings.items():
            if any(keyword in message for keyword in keywords):
                return category

        return ErrorCategory.UNKNOWN

    def determine_severity(
        self, exception: Exception, category: ErrorCategory, context: Optional[Dict[str, Any]] = None
    ) -> ErrorSeverity:
        """Determine the severity of an exception.

        :param exception: The exception to evaluate
        :param category: The error category
        :param context: Additional context for severity determination; the
            ``"critical_file"`` key raises FILE_IO errors to HIGH
        :return: The appropriate error severity
        """
        context = context or {}

        # Critical system errors
        if isinstance(exception, (MemoryError, SystemExit, KeyboardInterrupt)):
            return ErrorSeverity.FATAL

        # Authentication and configuration are typically high severity
        if category in (ErrorCategory.AUTHENTICATION, ErrorCategory.CONFIGURATION):
            return ErrorSeverity.HIGH

        # File I/O errors can be high severity if they involve critical files
        if category == ErrorCategory.FILE_IO:
            if context.get("critical_file", False):
                return ErrorSeverity.HIGH
            return ErrorSeverity.MEDIUM

        # Network and timeout errors are usually medium severity (retryable)
        if category in (ErrorCategory.NETWORK, ErrorCategory.TIMEOUT):
            return ErrorSeverity.MEDIUM

        # Task-specific errors are medium severity
        if category in (ErrorCategory.ORCHESTRATOR, ErrorCategory.EVALUATION, ErrorCategory.DATA_PROCESSING):
            return ErrorSeverity.MEDIUM

        return ErrorSeverity.LOW

    def handle_exception(
        self,
        exception: Exception,
        context: Optional[Dict[str, Any]] = None,
        task_name: Optional[str] = None,
        reraise: bool = False,
    ) -> RedTeamError:
        """Handle an exception with proper categorization and logging.

        :param exception: The exception to handle
        :param context: Additional context information
        :param task_name: Name of the task where the exception occurred
        :param reraise: Whether to reraise the exception after handling
        :return: A RedTeamError with categorized information
        :raises RedTeamError: If ``reraise`` is true
        """
        context = context or {}

        # If it's already a RedTeamError, just log and return/reraise.
        # Note that this path does not touch error_counts.
        if isinstance(exception, RedTeamError):
            self._log_error(exception, task_name)
            if reraise:
                raise exception
            return exception

        # Categorize the exception
        category = self.categorize_exception(exception)
        severity = self.determine_severity(exception, category, context)

        # Update error counts
        self.error_counts[category] += 1

        # Create RedTeamError
        message = f"{category.value.title()} error"
        if task_name:
            message += f" in {task_name}"
        message += f": {str(exception)}"

        red_team_error = RedTeamError(
            message=message, category=category, severity=severity, context=context, original_exception=exception
        )

        # Log the error
        self._log_error(red_team_error, task_name)

        if reraise:
            # Chain explicitly so the cause survives even when this method is
            # called outside an ``except`` block.
            raise red_team_error from exception

        return red_team_error

    def _log_error(self, error: RedTeamError, task_name: Optional[str] = None) -> None:
        """Log an error with appropriate level based on severity.

        :param error: The RedTeamError to log
        :param task_name: Optional task name for context
        """
        # Determine log level based on severity
        if error.severity == ErrorSeverity.FATAL:
            log_level = logging.CRITICAL
        elif error.severity == ErrorSeverity.HIGH:
            log_level = logging.ERROR
        elif error.severity == ErrorSeverity.MEDIUM:
            log_level = logging.WARNING
        else:
            log_level = logging.INFO

        # Create log message
        message_parts = []
        if task_name:
            message_parts.append(f"[{task_name}]")
        message_parts.append(f"[{error.category.value}]")
        message_parts.append(f"[{error.severity.value}]")
        message_parts.append(error.message)

        log_message = " ".join(message_parts)

        # Log with appropriate level
        self.logger.log(log_level, log_message)

        # Log additional context if available
        if error.context:
            self.logger.debug(f"Error context: {error.context}")

        # Log original exception traceback for debugging. The traceback is
        # taken from the stored exception itself; traceback.format_exc() would
        # instead describe whatever exception is currently being handled,
        # which is "NoneType: None" when called outside an ``except`` block.
        original = error.original_exception
        if original is not None and self.logger.isEnabledFor(logging.DEBUG):
            formatted = "".join(traceback.format_exception(type(original), original, original.__traceback__))
            self.logger.debug(f"Original exception traceback:\n{formatted}")

    def should_abort_scan(self) -> bool:
        """Determine if the scan should be aborted based on error patterns.

        :return: True if more than 2 authentication/configuration errors or
            more than 10 network errors have been counted
        """
        # Abort if we have too many errors in the high-severity categories
        high_severity_categories = [ErrorCategory.AUTHENTICATION, ErrorCategory.CONFIGURATION]
        high_severity_count = sum(self.error_counts[cat] for cat in high_severity_categories)

        if high_severity_count > 2:
            return True

        # Abort if we have too many network errors (indicates systemic issue)
        if self.error_counts[ErrorCategory.NETWORK] > 10:
            return True

        return False

    def get_error_summary(self) -> Dict[str, Any]:
        """Get a summary of all errors encountered.

        :return: Dictionary with the total error count, a copy of the
            per-category counters, the category with the highest count (or
            None when nothing was counted) and the abort decision
        """
        total_errors = sum(self.error_counts.values())

        return {
            "total_errors": total_errors,
            "error_counts_by_category": dict(self.error_counts),
            "most_common_category": max(self.error_counts, key=self.error_counts.get) if total_errors > 0 else None,
            "should_abort": self.should_abort_scan(),
        }

    def log_error_summary(self) -> None:
        """Log a summary of all errors encountered."""
        summary = self.get_error_summary()

        if summary["total_errors"] == 0:
            self.logger.info("No errors encountered during operation")
            return

        self.logger.info(f"Error Summary: {summary['total_errors']} total errors")

        # Only categories that actually occurred are listed.
        for category, count in summary["error_counts_by_category"].items():
            if count > 0:
                self.logger.info(f"  {category}: {count}")

        if summary["most_common_category"]:
            self.logger.info(f"Most common error type: {summary['most_common_category']}")
303+
304+
def create_exception_handler(logger: Optional[logging.Logger] = None) -> ExceptionHandler:
    """Build a new ExceptionHandler.

    :param logger: Logger the handler should report errors to; passed through unchanged
    :return: The newly constructed ExceptionHandler
    """
    handler = ExceptionHandler(logger=logger)
    return handler
311+
312+
313+
# Convenience context manager for handling exceptions
314+
class exception_context:
    """Context manager for handling exceptions in Red Team operations.

    An exception raised inside the ``with`` body is passed to
    ``handler.handle_exception`` and the resulting :class:`RedTeamError` is
    stored on ``self.error``. Ordinary exceptions are then suppressed; fatal
    ones are re-raised as the ``RedTeamError`` unless ``reraise_fatal`` is
    false.
    """

    def __init__(
        self,
        handler: ExceptionHandler,
        task_name: str,
        context: Optional[Dict[str, Any]] = None,
        reraise_fatal: bool = True,
    ):
        """Store the handler and the metadata attached to handled errors.

        :param handler: ExceptionHandler used to categorize and log exceptions
        :param task_name: Name of the task run inside the ``with`` body
        :param context: Additional context information passed to the handler
        :param reraise_fatal: Whether errors of FATAL severity are re-raised
        """
        self.handler = handler
        self.task_name = task_name
        self.context = context or {}
        self.reraise_fatal = reraise_fatal
        # Populated by __exit__ when the body raised.
        self.error: Optional[RedTeamError] = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Nothing was raised in the body.
        if exc_val is None:
            return False

        self.error = self.handler.handle_exception(
            exception=exc_val, context=self.context, task_name=self.task_name, reraise=False
        )
        is_fatal = self.error.severity == ErrorSeverity.FATAL

        # Reraise fatal errors unless specifically disabled
        if self.reraise_fatal and is_fatal:
            raise self.error

        # Control-flow signals that are not Exception subclasses (for example
        # asyncio.CancelledError or GeneratorExit) are recorded above but must
        # keep propagating: swallowing them would silently cancel task
        # cancellation and generator shutdown.
        if not is_fatal and not isinstance(exc_val, Exception):
            return False

        # Suppress the original exception (we've handled it)
        return True

0 commit comments

Comments
 (0)