Skip to content

Commit 0ca7073

Browse files
Copilotslister1001Copilot
authored
Add context support to red_team attack objectives for enhanced evaluation (#42128)
* Initial plan * Implement context support in red_team attack objectives and evaluation Co-authored-by: slister1001 <[email protected]> * Add documentation and examples for red_team context support Co-authored-by: slister1001 <[email protected]> * Apply black formatting to red_team files Co-authored-by: slister1001 <[email protected]> * Update sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_red_team.py Co-authored-by: Copilot <[email protected]> * Update sdk/evaluation/azure-ai-evaluation/samples/red_team_samples.py Co-authored-by: Copilot <[email protected]> * Add ISA and XPIA risk categories with similar processing to ProtectedMaterial Co-authored-by: slister1001 <[email protected]> * Change RiskCategory.ISA to RiskCategory.UngroundedAttributes Co-authored-by: slister1001 <[email protected]> * updates * updates * small indentation fix * run code formatter * updates for callback * run code formatter * code formatter * updates for multiturn and crescendo with context * small fix for crescendo * remove _internalriskcategory typehints * fix code format --------- Co-authored-by: copilot-swe-agent[bot] <[email protected]> Co-authored-by: slister1001 <[email protected]> Co-authored-by: Copilot <[email protected]> Co-authored-by: Sydney Lister <[email protected]>
1 parent c4f30f2 commit 0ca7073

18 files changed

+430
-83
lines changed

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/_common/onedp/models/_models.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1961,19 +1961,24 @@ class Message(_Model):
19611961
:vartype role: str
19621962
:ivar content: The content.
19631963
:vartype content: str
1964+
:ivar context: The context.
1965+
:vartype context: str
19641966
"""
19651967

19661968
role: Optional[str] = rest_field(name="Role", visibility=["read", "create", "update", "delete", "query"])
19671969
"""The role."""
19681970
content: Optional[str] = rest_field(name="Content", visibility=["read", "create", "update", "delete", "query"])
19691971
"""The content."""
1972+
context: Optional[str] = rest_field(name="Context", visibility=["read", "create", "update", "delete", "query"])
1973+
"""The context."""
19701974

19711975
@overload
19721976
def __init__(
19731977
self,
19741978
*,
19751979
role: Optional[str] = None,
19761980
content: Optional[str] = None,
1981+
context: Optional[str] = None,
19771982
) -> None: ...
19781983

19791984
@overload

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_attack_objective_generator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ class RiskCategory(str, Enum):
2020
SelfHarm = "self_harm"
2121
ProtectedMaterial = "protected_material"
2222
CodeVulnerability = "code_vulnerability"
23+
UngroundedAttributes = "ungrounded_attributes"
24+
XPIA = "xpia"
2325

2426

2527
@experimental

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_callback_chat_target.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def __init__(
1919
*,
2020
callback: Callable[[List[Dict], bool, Optional[str], Optional[Dict[str, Any]]], Dict],
2121
stream: bool = False,
22+
prompt_to_context: Optional[Dict[str, str]] = None,
2223
) -> None:
2324
"""
2425
Initializes an instance of the _CallbackChatTarget class.
@@ -32,10 +33,12 @@ def __init__(
3233
Args:
3334
callback (Callable): The callback function that sends a prompt to a target and receives a response.
3435
stream (bool, optional): Indicates whether the target supports streaming. Defaults to False.
36+
prompt_to_context (Optional[Dict[str, str]], optional): Mapping from prompt content to context. Defaults to None.
3537
"""
3638
PromptChatTarget.__init__(self)
3739
self._callback = callback
3840
self._stream = stream
41+
self._prompt_to_context = prompt_to_context or {}
3942

4043
async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> PromptRequestResponse:
4144

@@ -48,8 +51,18 @@ async def send_prompt_async(self, *, prompt_request: PromptRequestResponse) -> P
4851

4952
logger.info(f"Sending the following prompt to the prompt target: {request}")
5053

54+
# Get context for the current prompt if available
55+
current_prompt_content = request.converted_value
56+
context_data = self._prompt_to_context.get(current_prompt_content, "")
57+
context_dict = {"context": context_data} if context_data else {}
58+
59+
# If context is not available via prompt_to_context, it can be fetched from the memory
60+
if not context_dict:
61+
memory_label_context = request.labels.get("context", None)
62+
context_dict = {"context": memory_label_context} if memory_label_context else {}
63+
5164
# response_context contains "messages", "stream", "session_state, "context"
52-
response_context = await self._callback(messages=messages, stream=self._stream, session_state=None, context=None) # type: ignore
65+
response_context = await self._callback(messages=messages, stream=self._stream, session_state=None, context=context_dict) # type: ignore
5366

5467
response_text = response_context["messages"][-1]["content"]
5568
response_entry = construct_response_from_request(request=request, response_text_pieces=[response_text])

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_evaluation_processor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@
3232
from ._attack_strategy import AttackStrategy
3333
from ._attack_objective_generator import RiskCategory
3434
from ._utils.constants import RESULTS_EXT, TASK_STATUS
35-
from ._utils.metric_mapping import get_annotation_task_from_risk_category, get_metric_from_risk_category
35+
from ._utils.metric_mapping import (
36+
get_annotation_task_from_risk_category,
37+
get_metric_from_risk_category,
38+
get_attack_objective_from_risk_category,
39+
)
3640
from ._utils.logging_utils import log_error
3741
from ._utils.formatting_utils import get_strategy_name
3842

@@ -97,12 +101,19 @@ async def evaluate_conversation(
97101
# Extract all assistant messages for evaluation
98102
assistant_messages = [msg["content"] for msg in messages if msg.get("role") == "assistant"]
99103

104+
context = [msg["context"] for msg in messages if msg.get("role") == "user"]
105+
100106
if assistant_messages:
101107
# Create query-response pair with empty query and all assistant messages
102108
query_response = {
103109
"query": "query",
104110
"response": " ".join(assistant_messages),
105111
}
112+
113+
# Add context to query_response if found
114+
if context[0] is not None:
115+
query_response["context"] = context[0]
116+
106117
try:
107118
self.logger.debug(f"Evaluating conversation {idx+1} for {risk_category.value}/{strategy_name}")
108119

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_orchestrator_manager.py

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
from ._utils._rai_service_eval_chat_target import RAIServiceEvalChatTarget
3939
from ._utils.constants import DATA_EXT, TASK_STATUS
4040
from ._utils.logging_utils import log_strategy_start, log_error
41+
from ._utils.formatting_utils import write_pyrit_outputs_to_file
4142

4243

4344
def network_retry_decorator(retry_config, logger, strategy_name, risk_category_name, prompt_idx=None):
@@ -169,6 +170,7 @@ async def _prompt_sending_orchestrator(
169170
timeout: int = 120,
170171
red_team_info: Dict = None,
171172
task_statuses: Dict = None,
173+
prompt_to_context: Dict[str, str] = None,
172174
) -> Orchestrator:
173175
"""Send prompts via the PromptSendingOrchestrator.
174176
@@ -224,9 +226,6 @@ async def _prompt_sending_orchestrator(
224226
task_statuses[task_key] = TASK_STATUS["COMPLETED"]
225227
return orchestrator
226228

227-
# Debug log the first few characters of each prompt
228-
self.logger.debug(f"First prompt (truncated): {all_prompts[0][:50]}...")
229-
230229
# Initialize output path for memory labelling
231230
base_path = str(uuid.uuid4())
232231

@@ -313,6 +312,7 @@ async def _multi_turn_orchestrator(
313312
timeout: int = 120,
314313
red_team_info: Dict = None,
315314
task_statuses: Dict = None,
315+
prompt_to_context: Dict[str, str] = None,
316316
) -> Orchestrator:
317317
"""Send prompts via the RedTeamingOrchestrator (multi-turn orchestrator).
318318
@@ -381,6 +381,7 @@ async def _multi_turn_orchestrator(
381381
for prompt_idx, prompt in enumerate(all_prompts):
382382
prompt_start_time = datetime.now()
383383
self.logger.debug(f"Processing prompt {prompt_idx+1}/{len(all_prompts)}")
384+
context = prompt_to_context.get(prompt, None) if prompt_to_context else None
384385
try:
385386
azure_rai_service_scorer = AzureRAIServiceTrueFalseScorer(
386387
client=self.generated_rai_client,
@@ -390,6 +391,7 @@ async def _multi_turn_orchestrator(
390391
credential=self.credential,
391392
risk_category=risk_category,
392393
azure_ai_project=self.azure_ai_project,
394+
context=context,
393395
)
394396

395397
azure_rai_service_target = AzureRAIServiceTarget(
@@ -411,9 +413,6 @@ async def _multi_turn_orchestrator(
411413
use_score_as_feedback=False,
412414
)
413415

414-
# Debug log the first few characters of the current prompt
415-
self.logger.debug(f"Current prompt (truncated): {prompt[:50]}...")
416-
417416
try:
418417
# Create retry-enabled function using the reusable decorator
419418
@network_retry_decorator(
@@ -423,10 +422,7 @@ async def send_prompt_with_retry():
423422
return await asyncio.wait_for(
424423
orchestrator.run_attack_async(
425424
objective=prompt,
426-
memory_labels={
427-
"risk_strategy_path": output_path,
428-
"batch": 1,
429-
},
425+
memory_labels={"risk_strategy_path": output_path, "batch": 1, "context": context},
430426
),
431427
timeout=calculated_timeout,
432428
)
@@ -438,6 +434,13 @@ async def send_prompt_with_retry():
438434
f"Successfully processed prompt {prompt_idx+1} for {strategy_name}/{risk_category_name} in {prompt_duration:.2f} seconds"
439435
)
440436

437+
# Write outputs to file after each prompt is processed
438+
write_pyrit_outputs_to_file(
439+
output_path=output_path,
440+
logger=self.logger,
441+
prompt_to_context=prompt_to_context,
442+
)
443+
441444
# Print progress to console
442445
if prompt_idx < len(all_prompts) - 1: # Don't print for the last prompt
443446
print(
@@ -492,6 +495,7 @@ async def _crescendo_orchestrator(
492495
timeout: int = 120,
493496
red_team_info: Dict = None,
494497
task_statuses: Dict = None,
498+
prompt_to_context: Dict[str, str] = None,
495499
) -> Orchestrator:
496500
"""Send prompts via the CrescendoOrchestrator with optimized performance.
497501
@@ -542,12 +546,14 @@ async def _crescendo_orchestrator(
542546
for prompt_idx, prompt in enumerate(all_prompts):
543547
prompt_start_time = datetime.now()
544548
self.logger.debug(f"Processing prompt {prompt_idx+1}/{len(all_prompts)}")
549+
context = prompt_to_context.get(prompt, None) if prompt_to_context else None
545550
try:
546551
red_llm_scoring_target = RAIServiceEvalChatTarget(
547552
logger=self.logger,
548553
credential=self.credential,
549554
risk_category=risk_category,
550555
azure_ai_project=self.azure_ai_project,
556+
context=context,
551557
)
552558

553559
azure_rai_service_target = AzureRAIServiceTarget(
@@ -577,11 +583,9 @@ async def _crescendo_orchestrator(
577583
credential=self.credential,
578584
risk_category=risk_category,
579585
azure_ai_project=self.azure_ai_project,
586+
context=context,
580587
)
581588

582-
# Debug log the first few characters of the current prompt
583-
self.logger.debug(f"Current prompt (truncated): {prompt[:50]}...")
584-
585589
try:
586590
# Create retry-enabled function using the reusable decorator
587591
@network_retry_decorator(
@@ -594,6 +598,7 @@ async def send_prompt_with_retry():
594598
memory_labels={
595599
"risk_strategy_path": output_path,
596600
"batch": prompt_idx + 1,
601+
"context": context,
597602
},
598603
),
599604
timeout=calculated_timeout,
@@ -606,6 +611,13 @@ async def send_prompt_with_retry():
606611
f"Successfully processed prompt {prompt_idx+1} for {strategy_name}/{risk_category_name} in {prompt_duration:.2f} seconds"
607612
)
608613

614+
# Write outputs to file after each prompt is processed
615+
write_pyrit_outputs_to_file(
616+
output_path=output_path,
617+
logger=self.logger,
618+
prompt_to_context=prompt_to_context,
619+
)
620+
609621
# Print progress to console
610622
if prompt_idx < len(all_prompts) - 1: # Don't print for the last prompt
611623
print(

0 commit comments

Comments
 (0)