Skip to content

Commit c82fdb4

Browse files
Authored by slister1001, Nagkumar Arkalgud, and nagkumar91
Update redteam complexity mapping (Azure#40321)
* Update complexity mapping
* Remove unused debug_mode param
* Update strategy converter mapping
* Get base64 tense working properly
* Add api version if it's available in model config
* Fix printing errors
* Make red_team a separate sub-package and fix unit tests
* Update sample
* Update setup.py to pin pyrit and termcolor
* Update setup.py
* Fix unit tests

Co-authored-by: Nagkumar Arkalgud <[email protected]>
Co-authored-by: Nagkumar Arkalgud <[email protected]>
1 parent 3fe7515 commit c82fdb4

24 files changed

+109
-127
lines changed

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/__init__.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -52,22 +52,6 @@
5252
except ImportError:
5353
print("[INFO] Could not import AIAgentConverter. Please install the dependency with `pip install azure-ai-projects`.")
5454

55-
# RedTeam requires a dependency on pyrit, but python 3.9 is not supported by pyrit.
56-
# So we only import it if it's available and the user has pyrit.
57-
try:
58-
from ._red_team._red_team import RedTeam
59-
from ._red_team._attack_strategy import AttackStrategy
60-
from ._red_team._attack_objective_generator import RiskCategory
61-
from ._red_team._red_team_result import RedTeamOutput
62-
_patch_all.extend([
63-
"RedTeam",
64-
"RedTeamOutput",
65-
"AttackStrategy",
66-
"RiskCategory",
67-
])
68-
except ImportError:
69-
print("[INFO] Could not import RedTeam. Please install the dependency with `pip install azure-ai-evaluation[redteam]`.")
70-
7155

7256
__all__ = [
7357
"evaluate",

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_red_team/_utils/__init__.py

Lines changed: 0 additions & 3 deletions
This file was deleted.
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# ---------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# ---------------------------------------------------------
4+
5+
from ._red_team import RedTeam
6+
from ._attack_strategy import AttackStrategy
7+
from ._attack_objective_generator import RiskCategory
8+
from ._red_team_result import RedTeamOutput
9+
10+
__all__ = [
11+
"RedTeam",
12+
"AttackStrategy",
13+
"RiskCategory",
14+
"RedTeamOutput",
15+
]

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_red_team/_red_team.py renamed to sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_red_team.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -441,11 +441,11 @@ async def _get_attack_objectives(
441441
self.logger.debug(f"API call: get_attack_objectives({risk_cat_value}, app: {application_scenario}, strategy: {strategy})")
442442
# strategy param specifies whether to get a strategy-specific dataset from the RAI service
443443
# right now, only tense requires strategy-specific dataset
444-
if strategy == "tense":
444+
if "tense" in strategy:
445445
objectives_response = await self.generated_rai_client.get_attack_objectives(
446446
risk_category=risk_cat_value,
447447
application_scenario=application_scenario or "",
448-
strategy=strategy
448+
strategy="tense"
449449
)
450450
else:
451451
objectives_response = await self.generated_rai_client.get_attack_objectives(
@@ -679,7 +679,7 @@ async def _prompt_sending_orchestrator(
679679
continue
680680
except Exception as e:
681681
log_error(self.logger, f"Error processing batch {batch_idx+1}", e, f"{strategy_name}/{risk_category}")
682-
print(f"ERROR: Strategy {strategy_name}, Risk {risk_category}, Batch {batch_idx+1}: {str(e)}")
682+
self.logger.debug(f"ERROR: Strategy {strategy_name}, Risk {risk_category}, Batch {batch_idx+1}: {str(e)}")
683683
# Continue with other batches even if one fails
684684
continue
685685
else:
@@ -701,14 +701,14 @@ async def _prompt_sending_orchestrator(
701701
self.task_statuses[single_batch_task_key] = TASK_STATUS["TIMEOUT"]
702702
except Exception as e:
703703
log_error(self.logger, "Error processing prompts", e, f"{strategy_name}/{risk_category}")
704-
print(f"ERROR: Strategy {strategy_name}, Risk {risk_category}: {str(e)}")
704+
self.logger.debug(f"ERROR: Strategy {strategy_name}, Risk {risk_category}: {str(e)}")
705705

706706
self.task_statuses[task_key] = TASK_STATUS["COMPLETED"]
707707
return orchestrator
708708

709709
except Exception as e:
710710
log_error(self.logger, "Failed to initialize orchestrator", e, f"{strategy_name}/{risk_category}")
711-
print(f"CRITICAL: Failed to create orchestrator for {strategy_name}/{risk_category}: {str(e)}")
711+
self.logger.debug(f"CRITICAL: Failed to create orchestrator for {strategy_name}/{risk_category}: {str(e)}")
712712
self.task_statuses[task_key] = TASK_STATUS["FAILED"]
713713
raise
714714

@@ -1344,7 +1344,7 @@ async def _process_attack(
13441344
orchestrator = await call_orchestrator(self.chat_target, all_prompts, converter, strategy_name, risk_category.value, timeout)
13451345
except PyritException as e:
13461346
log_error(self.logger, f"Error calling orchestrator for {strategy_name} strategy", e)
1347-
print(f"Orchestrator error for {strategy_name}/{risk_category.value}: {str(e)}")
1347+
self.logger.debug(f"Orchestrator error for {strategy_name}/{risk_category.value}: {str(e)}")
13481348
self.task_statuses[task_key] = TASK_STATUS["FAILED"]
13491349
self.failed_tasks += 1
13501350

@@ -1399,7 +1399,7 @@ async def _process_attack(
13991399

14001400
except Exception as e:
14011401
log_error(self.logger, f"Unexpected error processing {strategy_name} strategy for {risk_category.value}", e)
1402-
print(f"Critical error in task {strategy_name}/{risk_category.value}: {str(e)}")
1402+
self.logger.debug(f"Critical error in task {strategy_name}/{risk_category.value}: {str(e)}")
14031403
self.task_statuses[task_key] = TASK_STATUS["FAILED"]
14041404
self.failed_tasks += 1
14051405

@@ -1419,7 +1419,6 @@ async def scan(
14191419
application_scenario: Optional[str] = None,
14201420
parallel_execution: bool = True,
14211421
max_parallel_tasks: int = 5,
1422-
debug_mode: bool = False,
14231422
timeout: int = 120) -> RedTeamOutput:
14241423
"""Run a red team scan against the target using the specified strategies.
14251424
@@ -1441,8 +1440,6 @@ async def scan(
14411440
:type parallel_execution: bool
14421441
:param max_parallel_tasks: Maximum number of parallel orchestrator tasks to run (default: 5)
14431442
:type max_parallel_tasks: int
1444-
:param debug_mode: Whether to run in debug mode (more verbose output)
1445-
:type debug_mode: bool
14461443
:param timeout: The timeout in seconds for API calls (default: 120)
14471444
:type timeout: int
14481445
:return: The output from the red team scan
@@ -1522,7 +1519,7 @@ def filter(self, record):
15221519
if not self.attack_objective_generator:
15231520
error_msg = "Attack objective generator is required for red team agent."
15241521
log_error(self.logger, error_msg)
1525-
print(f"{error_msg}")
1522+
self.logger.debug(f"{error_msg}")
15261523
raise EvaluationException(
15271524
message=error_msg,
15281525
internal_message="Attack objective generator is not provided.",
@@ -1676,17 +1673,13 @@ def filter(self, record):
16761673
for risk_category in self.risk_categories:
16771674
progress_bar.set_postfix({"current": f"fetching {strategy_name}/{risk_category.value}"})
16781675
self.logger.debug(f"Fetching objectives for {strategy_name} strategy and {risk_category.value} risk category")
1679-
16801676
objectives = await self._get_attack_objectives(
16811677
risk_category=risk_category,
16821678
application_scenario=application_scenario,
16831679
strategy=strategy_name
16841680
)
16851681
all_objectives[strategy_name][risk_category.value] = objectives
16861682

1687-
# Print status about objective count for this strategy/risk
1688-
if debug_mode:
1689-
print(f" - {risk_category.value}: {len(objectives)} objectives")
16901683

16911684
self.logger.info("Completed fetching all attack objectives")
16921685

@@ -1754,7 +1747,7 @@ def filter(self, record):
17541747
continue
17551748
except Exception as e:
17561749
log_error(self.logger, f"Error processing batch {i//max_parallel_tasks+1}", e)
1757-
print(f"Error in batch {i//max_parallel_tasks+1}: {str(e)}")
1750+
self.logger.debug(f"Error in batch {i//max_parallel_tasks+1}: {str(e)}")
17581751
continue
17591752
else:
17601753
# Sequential execution
@@ -1776,7 +1769,7 @@ def filter(self, record):
17761769
continue
17771770
except Exception as e:
17781771
log_error(self.logger, f"Error processing task {i+1}/{len(orchestrator_tasks)}", e)
1779-
print(f"Error in task {i+1}: {str(e)}")
1772+
self.logger.debug(f"Error in task {i+1}: {str(e)}")
17801773
continue
17811774

17821775
progress_bar.close()

0 commit comments

Comments (0)