Skip to content

Commit 27baf20

Browse files
authored
Red team: ensure coverage across risk subcategories (#43636)
* redteam bug fixes for model usage computation and auth
* redteam ensure coverage across risk sub categories
* updates
* add tests
* updates
* fix tests
* formatting
* updates
* updates
* updates
1 parent 9ac1d8a commit 27baf20

File tree

5 files changed: +443 additions, -53 deletions

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_attack_objective_generator.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,22 @@ def _load_and_validate_custom_prompts(self) -> None:
238238
}
239239
self.logger.info(f"Prompt distribution by risk category: {category_counts}")
240240

241-
# Automatically extract risk categories from valid prompts if not provided
242-
if not self.risk_categories:
243-
categories_with_prompts = [cat for cat, prompts in self.valid_prompts_by_category.items() if prompts]
244-
self.risk_categories = [RiskCategory(cat) for cat in categories_with_prompts]
241+
# Merge risk categories from custom prompts with explicitly provided risk_categories
242+
categories_with_prompts = [cat for cat, prompts in self.valid_prompts_by_category.items() if prompts]
243+
categories_from_prompts = [RiskCategory(cat) for cat in categories_with_prompts]
244+
245+
if self.risk_categories:
246+
# Combine explicitly provided categories with those from custom prompts
247+
combined_categories = list(set(self.risk_categories + categories_from_prompts))
248+
self.logger.info(
249+
f"Merging provided risk categories {[cat.value for cat in self.risk_categories]} "
250+
f"with categories from custom prompts {[cat.value for cat in categories_from_prompts]} "
251+
f"-> Combined: {[cat.value for cat in combined_categories]}"
252+
)
253+
self.risk_categories = combined_categories
254+
else:
255+
# No risk categories provided, use only those from custom prompts
256+
self.risk_categories = categories_from_prompts
245257
self.logger.info(
246258
f"Automatically set risk categories based on valid prompts: {[cat.value for cat in self.risk_categories]}"
247259
)

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_red_team.py

Lines changed: 138 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@
5959
from pyrit.prompt_target import PromptChatTarget
6060

6161
# Local imports - constants and utilities
62-
from ._utils.constants import TASK_STATUS
62+
from ._utils.constants import TASK_STATUS, MAX_SAMPLING_ITERATIONS_MULTIPLIER, RISK_TO_NUM_SUBTYPE_MAP
6363
from ._utils.logging_utils import (
6464
setup_logger,
6565
log_section_header,
@@ -76,6 +76,7 @@
7676
from ._utils.retry_utils import create_standard_retry_manager
7777
from ._utils.file_utils import create_file_manager
7878
from ._utils.metric_mapping import get_attack_objective_from_risk_category
79+
from ._utils.objective_utils import extract_risk_subtype, get_objective_id
7980

8081
from ._orchestrator_manager import OrchestratorManager
8182
from ._evaluation_processor import EvaluationProcessor
@@ -352,9 +353,20 @@ async def _get_attack_objectives(
352353
risk_cat_value = get_attack_objective_from_risk_category(risk_category).lower()
353354
num_objectives = attack_objective_generator.num_objectives
354355

356+
# Calculate num_objectives_with_subtypes based on max subtypes across all risk categories
357+
# Use attack_objective_generator.risk_categories as self.risk_categories may not be set yet
358+
risk_categories = getattr(self, "risk_categories", None) or attack_objective_generator.risk_categories
359+
max_num_subtypes = max((RISK_TO_NUM_SUBTYPE_MAP.get(rc, 0) for rc in risk_categories), default=0)
360+
num_objectives_with_subtypes = max(num_objectives, max_num_subtypes)
361+
362+
self.logger.debug(
363+
f"Calculated num_objectives_with_subtypes for {risk_cat_value}: "
364+
f"max(num_objectives={num_objectives}, max_subtypes={max_num_subtypes}) = {num_objectives_with_subtypes}"
365+
)
366+
355367
log_subsection_header(
356368
self.logger,
357-
f"Getting attack objectives for {risk_cat_value}, strategy: {strategy}",
369+
f"Getting attack objectives for {risk_cat_value}, strategy: {strategy}, num_objectives: {num_objectives}, num_objectives_with_subtypes: {num_objectives_with_subtypes}",
358370
)
359371

360372
# Check if we already have baseline objectives for this risk category
@@ -370,7 +382,7 @@ async def _get_attack_objectives(
370382
if custom_objectives:
371383
# Use custom objectives for this risk category
372384
return await self._get_custom_attack_objectives(
373-
risk_cat_value, num_objectives, strategy, current_key, is_agent_target
385+
risk_cat_value, num_objectives, num_objectives_with_subtypes, strategy, current_key, is_agent_target
374386
)
375387
else:
376388
# No custom objectives for this risk category, but risk_categories was specified
@@ -391,7 +403,9 @@ async def _get_attack_objectives(
391403
baseline_key,
392404
current_key,
393405
num_objectives,
406+
num_objectives_with_subtypes,
394407
is_agent_target,
408+
client_id,
395409
)
396410
else:
397411
# Risk category not in requested list, return empty
@@ -409,6 +423,7 @@ async def _get_attack_objectives(
409423
baseline_key,
410424
current_key,
411425
num_objectives,
426+
num_objectives_with_subtypes,
412427
is_agent_target,
413428
client_id,
414429
)
@@ -417,6 +432,7 @@ async def _get_custom_attack_objectives(
417432
self,
418433
risk_cat_value: str,
419434
num_objectives: int,
435+
num_objectives_with_subtypes: int,
420436
strategy: str,
421437
current_key: tuple,
422438
is_agent_target: Optional[bool] = None,
@@ -437,15 +453,97 @@ async def _get_custom_attack_objectives(
437453

438454
self.logger.info(f"Found {len(custom_objectives)} custom objectives for {risk_cat_value}")
439455

440-
# Sample if we have more than needed
441-
if len(custom_objectives) > num_objectives:
442-
selected_cat_objectives = random.sample(custom_objectives, num_objectives)
456+
# Deduplicate objectives by ID to avoid selecting the same logical objective multiple times
457+
seen_ids = set()
458+
deduplicated_objectives = []
459+
for obj in custom_objectives:
460+
obj_id = get_objective_id(obj)
461+
if obj_id not in seen_ids:
462+
seen_ids.add(obj_id)
463+
deduplicated_objectives.append(obj)
464+
465+
if len(deduplicated_objectives) < len(custom_objectives):
466+
self.logger.debug(
467+
f"Deduplicated {len(custom_objectives)} objectives to {len(deduplicated_objectives)} unique objectives by ID"
468+
)
469+
470+
# Group objectives by risk_subtype if present
471+
objectives_by_subtype = {}
472+
objectives_without_subtype = []
473+
474+
for obj in deduplicated_objectives:
475+
risk_subtype = extract_risk_subtype(obj)
476+
477+
if risk_subtype:
478+
if risk_subtype not in objectives_by_subtype:
479+
objectives_by_subtype[risk_subtype] = []
480+
objectives_by_subtype[risk_subtype].append(obj)
481+
else:
482+
objectives_without_subtype.append(obj)
483+
484+
# Determine sampling strategy based on risk_subtype presence
485+
# Use num_objectives_with_subtypes for initial sampling to ensure coverage
486+
if objectives_by_subtype:
487+
# We have risk subtypes - sample evenly across them
488+
num_subtypes = len(objectives_by_subtype)
489+
objectives_per_subtype = max(1, num_objectives_with_subtypes // num_subtypes)
490+
443491
self.logger.info(
444-
f"Sampled {num_objectives} objectives from {len(custom_objectives)} available for {risk_cat_value}"
492+
f"Found {num_subtypes} risk subtypes in custom objectives. "
493+
f"Sampling {objectives_per_subtype} objectives per subtype to reach ~{num_objectives_with_subtypes} total."
445494
)
495+
496+
selected_cat_objectives = []
497+
for subtype, subtype_objectives in objectives_by_subtype.items():
498+
num_to_sample = min(objectives_per_subtype, len(subtype_objectives))
499+
sampled = random.sample(subtype_objectives, num_to_sample)
500+
selected_cat_objectives.extend(sampled)
501+
self.logger.debug(
502+
f"Sampled {num_to_sample} objectives from risk_subtype '{subtype}' "
503+
f"({len(subtype_objectives)} available)"
504+
)
505+
506+
# If we need more objectives to reach num_objectives_with_subtypes, sample from objectives without subtype
507+
if len(selected_cat_objectives) < num_objectives_with_subtypes and objectives_without_subtype:
508+
remaining = num_objectives_with_subtypes - len(selected_cat_objectives)
509+
num_to_sample = min(remaining, len(objectives_without_subtype))
510+
selected_cat_objectives.extend(random.sample(objectives_without_subtype, num_to_sample))
511+
self.logger.debug(f"Added {num_to_sample} objectives without risk_subtype to reach target count")
512+
513+
# If we still need more, round-robin through subtypes again
514+
if len(selected_cat_objectives) < num_objectives_with_subtypes:
515+
remaining = num_objectives_with_subtypes - len(selected_cat_objectives)
516+
subtype_list = list(objectives_by_subtype.keys())
517+
# Track selected objective IDs in a set for O(1) membership checks
518+
# Use the objective's 'id' field if available, generate UUID-based ID otherwise
519+
selected_ids = {get_objective_id(obj) for obj in selected_cat_objectives}
520+
idx = 0
521+
while remaining > 0 and subtype_list:
522+
subtype = subtype_list[idx % len(subtype_list)]
523+
available = [
524+
obj for obj in objectives_by_subtype[subtype] if get_objective_id(obj) not in selected_ids
525+
]
526+
if available:
527+
selected_obj = random.choice(available)
528+
selected_cat_objectives.append(selected_obj)
529+
selected_ids.add(get_objective_id(selected_obj))
530+
remaining -= 1
531+
idx += 1
532+
# Prevent infinite loop if we run out of unique objectives
533+
if idx > len(subtype_list) * MAX_SAMPLING_ITERATIONS_MULTIPLIER:
534+
break
535+
536+
self.logger.info(f"Sampled {len(selected_cat_objectives)} objectives across {num_subtypes} risk subtypes")
446537
else:
447-
selected_cat_objectives = custom_objectives
448-
self.logger.info(f"Using all {len(custom_objectives)} available objectives for {risk_cat_value}")
538+
# No risk subtypes - use num_objectives_with_subtypes for sampling
539+
if len(custom_objectives) > num_objectives_with_subtypes:
540+
selected_cat_objectives = random.sample(custom_objectives, num_objectives_with_subtypes)
541+
self.logger.info(
542+
f"Sampled {num_objectives_with_subtypes} objectives from {len(custom_objectives)} available for {risk_cat_value}"
543+
)
544+
else:
545+
selected_cat_objectives = custom_objectives
546+
self.logger.info(f"Using all {len(custom_objectives)} available objectives for {risk_cat_value}")
449547
target_type_str = "agent" if is_agent_target else "model" if is_agent_target is not None else None
450548
# Handle jailbreak strategy - need to apply jailbreak prefixes to messages
451549
if strategy == "jailbreak":
@@ -456,17 +554,8 @@ async def _get_custom_attack_objectives(
456554
# Extract content from selected objectives
457555
selected_prompts = []
458556
for obj in selected_cat_objectives:
459-
risk_subtype = None
460557
# Extract risk-subtype from target_harms if present
461-
target_harms = obj.get("metadata", {}).get("target_harms", [])
462-
if target_harms and isinstance(target_harms, list):
463-
for harm in target_harms:
464-
if isinstance(harm, dict) and "risk-subtype" in harm:
465-
subtype_value = harm.get("risk-subtype")
466-
# Only store non-empty risk-subtype values
467-
if subtype_value and subtype_value.strip():
468-
risk_subtype = subtype_value
469-
break # Use the first non-empty risk-subtype found
558+
risk_subtype = extract_risk_subtype(obj)
470559

471560
if "messages" in obj and len(obj["messages"]) > 0:
472561
message = obj["messages"][0]
@@ -494,6 +583,7 @@ async def _get_rai_attack_objectives(
494583
baseline_key: tuple,
495584
current_key: tuple,
496585
num_objectives: int,
586+
num_objectives_with_subtypes: int,
497587
is_agent_target: Optional[bool] = None,
498588
client_id: Optional[str] = None,
499589
) -> List[str]:
@@ -533,9 +623,8 @@ async def _get_rai_attack_objectives(
533623
objectives_response = await self._apply_xpia_prompts(objectives_response, target_type_str)
534624

535625
except Exception as e:
536-
self.logger.error(f"Error calling get_attack_objectives: {str(e)}")
537-
self.logger.warning("API call failed, returning empty objectives list")
538-
return []
626+
self.logger.warning(f"Error calling get_attack_objectives: {str(e)}")
627+
objectives_response = {}
539628

540629
# Check if the response is valid
541630
if not objectives_response or (
@@ -585,9 +674,9 @@ async def _get_rai_attack_objectives(
585674
self.logger.warning("Empty or invalid response, returning empty list")
586675
return []
587676

588-
# Filter and select objectives
677+
# Filter and select objectives using num_objectives_with_subtypes
589678
selected_cat_objectives = self._filter_and_select_objectives(
590-
objectives_response, strategy, baseline_objectives_exist, baseline_key, num_objectives
679+
objectives_response, strategy, baseline_objectives_exist, baseline_key, num_objectives_with_subtypes
591680
)
592681

593682
# Extract content and cache
@@ -845,6 +934,12 @@ def _filter_and_select_objectives(
845934
# This is the baseline strategy or we don't have baseline objectives yet
846935
self.logger.debug(f"Using random selection for {strategy} strategy")
847936
selected_cat_objectives = random.sample(objectives_response, min(num_objectives, len(objectives_response)))
937+
selection_msg = (
938+
f"Selected {len(selected_cat_objectives)} objectives using num_objectives={num_objectives} "
939+
f"(available: {len(objectives_response)})"
940+
)
941+
self.logger.info(selection_msg)
942+
tqdm.write(f"[INFO] {selection_msg}")
848943

849944
if len(selected_cat_objectives) < num_objectives:
850945
self.logger.warning(
@@ -857,16 +952,7 @@ def _extract_objective_content(self, selected_objectives: List) -> List[str]:
857952
"""Extract content from selected objectives and build prompt-to-context mapping."""
858953
selected_prompts = []
859954
for obj in selected_objectives:
860-
risk_subtype = None
861-
# Extract risk-subtype from target_harms if present
862-
target_harms = obj.get("metadata", {}).get("target_harms", [])
863-
if target_harms and isinstance(target_harms, list):
864-
for harm in target_harms:
865-
if isinstance(harm, dict) and "risk-subtype" in harm:
866-
subtype_value = harm.get("risk-subtype")
867-
if subtype_value:
868-
risk_subtype = subtype_value
869-
break
955+
risk_subtype = extract_risk_subtype(obj)
870956
if "messages" in obj and len(obj["messages"]) > 0:
871957
message = obj["messages"][0]
872958
if isinstance(message, dict) and "content" in message:
@@ -953,20 +1039,9 @@ def _cache_attack_objectives(
9531039
# Process list format and organize by category for caching
9541040
for obj in selected_objectives:
9551041
obj_id = obj.get("id", f"obj-{uuid.uuid4()}")
956-
target_harms = obj.get("metadata", {}).get("target_harms", [])
9571042
content = ""
9581043
context = ""
959-
risk_subtype = None
960-
961-
# Extract risk-subtype from target_harms if present
962-
if target_harms and isinstance(target_harms, list):
963-
for harm in target_harms:
964-
if isinstance(harm, dict) and "risk-subtype" in harm:
965-
subtype_value = harm.get("risk-subtype")
966-
# Only store non-empty risk-subtype values
967-
if subtype_value:
968-
risk_subtype = subtype_value
969-
break # Use the first non-empty risk-subtype found
1044+
risk_subtype = extract_risk_subtype(obj)
9701045

9711046
if "messages" in obj and len(obj["messages"]) > 0:
9721047

@@ -1400,6 +1475,19 @@ async def _fetch_all_objectives(
14001475
log_section_header(self.logger, "Fetching attack objectives")
14011476
all_objectives = {}
14021477

1478+
# Calculate and log num_objectives_with_subtypes once globally
1479+
num_objectives = self.attack_objective_generator.num_objectives
1480+
max_num_subtypes = max((RISK_TO_NUM_SUBTYPE_MAP.get(rc, 0) for rc in self.risk_categories), default=0)
1481+
num_objectives_with_subtypes = max(num_objectives, max_num_subtypes)
1482+
1483+
if num_objectives_with_subtypes != num_objectives:
1484+
warning_msg = (
1485+
f"Using {num_objectives_with_subtypes} objectives per risk category instead of requested {num_objectives} "
1486+
f"to ensure adequate coverage of {max_num_subtypes} subtypes"
1487+
)
1488+
self.logger.warning(warning_msg)
1489+
tqdm.write(f"[WARNING] {warning_msg}")
1490+
14031491
# First fetch baseline objectives for all risk categories
14041492
self.logger.info("Fetching baseline objectives for all risk categories")
14051493
for risk_category in self.risk_categories:
@@ -1413,9 +1501,10 @@ async def _fetch_all_objectives(
14131501
if "baseline" not in all_objectives:
14141502
all_objectives["baseline"] = {}
14151503
all_objectives["baseline"][risk_category.value] = baseline_objectives
1416-
tqdm.write(
1417-
f"📝 Fetched baseline objectives for {risk_category.value}: {len(baseline_objectives)} objectives"
1418-
)
1504+
status_msg = f"📝 Fetched baseline objectives for {risk_category.value}: {len(baseline_objectives)}/{num_objectives_with_subtypes} objectives"
1505+
if len(baseline_objectives) < num_objectives_with_subtypes:
1506+
status_msg += f" (⚠️ fewer than expected)"
1507+
tqdm.write(status_msg)
14191508

14201509
# Then fetch objectives for other strategies
14211510
strategy_count = len(flattened_attack_strategies)

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_utils/constants.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,20 @@
4747
# Task timeouts and status codes
4848
INTERNAL_TASK_TIMEOUT = 120
4949

50+
# Sampling constants
51+
# Multiplier for the maximum number of sampling iterations when round-robin sampling from risk subtypes.
52+
# This prevents infinite loops while allowing sufficient attempts to find unique objectives.
53+
# With N subtypes, this allows up to N * MAX_SAMPLING_ITERATIONS_MULTIPLIER total iterations.
54+
MAX_SAMPLING_ITERATIONS_MULTIPLIER = 100
55+
56+
# Map of risk categories to their maximum number of subtypes
57+
# Used to calculate num_objectives_with_subtypes for adequate subtype coverage
58+
RISK_TO_NUM_SUBTYPE_MAP = {
59+
RiskCategory.ProhibitedActions: 32,
60+
RiskCategory.TaskAdherence: 9,
61+
RiskCategory.SensitiveDataLeakage: 19,
62+
}
63+
5064
# Task status definitions
5165
TASK_STATUS = {
5266
"PENDING": "pending",

0 commit comments

Comments
 (0)