Commit 8db6767

[RedTeam] Use memory labels for better timeout/error handling (Azure#40420)

* init
* improving logic for timeout handling with memory labels
* use keyword args for _write_pyrit_outputs_to_file
* add memory label logic for single-batch prompt processing
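In short, the commit tags every prompt sent through PyRIT with a risk_strategy_path memory label and reads results back out of central memory keyed on that label, so whatever completed before a timeout or error can still be flushed to disk. A minimal sketch of the pattern, assuming an already-constructed PyRIT orchestrator; send_with_labels and its 120-second default are hypothetical, not SDK code:

import asyncio
import itertools

from pyrit.memory import CentralMemory


async def send_with_labels(orchestrator, prompts, output_path, timeout=120):
    # Tag the whole batch so its results can be located in central memory later.
    try:
        await asyncio.wait_for(
            orchestrator.send_prompts_async(
                prompt_list=prompts,
                memory_labels={"risk_strategy_path": output_path, "batch": 1},
            ),
            timeout=timeout,
        )
    except asyncio.TimeoutError:
        pass  # anything that completed before the timeout is already in memory

    # Pull back every request/response piece carrying the label, then group
    # pieces into conversations by conversation_id (as the diff below does).
    memory = CentralMemory.get_memory_instance()
    pieces = memory.get_prompt_request_pieces(labels={"risk_strategy_path": output_path})
    return [
        [piece.to_chat_message() for piece in group]
        for _, group in itertools.groupby(pieces, key=lambda p: p.conversation_id)
    ]

Because retrieval is keyed on the label rather than on the orchestrator's in-process state, the same lookup works from the happy path, a timeout handler, or an error handler.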
1 parent 8b2fc6a commit 8db6767

1 file changed: +59 −26 lines

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_red_team.py

@@ -51,6 +51,7 @@
 from pyrit.common import initialize_pyrit, DUCK_DB
 from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget
 from pyrit.models import ChatMessage
+from pyrit.memory import CentralMemory
 from pyrit.orchestrator.single_turn.prompt_sending_orchestrator import PromptSendingOrchestrator
 from pyrit.orchestrator import Orchestrator
 from pyrit.exceptions import PyritException
@@ -667,6 +668,17 @@ async def _prompt_sending_orchestrator(
         # Use a batched approach for send_prompts_async to prevent overwhelming
         # the model with too many concurrent requests
         batch_size = min(len(all_prompts), 3)  # Process 3 prompts at a time max
+
+        # Initialize output path for memory labelling
+        base_path = str(uuid.uuid4())
+
+        # If scan output directory exists, place the file there
+        if hasattr(self, 'scan_output_dir') and self.scan_output_dir:
+            output_path = os.path.join(self.scan_output_dir, f"{base_path}{DATA_EXT}")
+        else:
+            output_path = f"{base_path}{DATA_EXT}"
+
+        self.red_team_info[strategy_name][risk_category]["data_file"] = output_path
 
         # Process prompts concurrently within each batch
         if len(all_prompts) > batch_size:
@@ -681,8 +693,8 @@ async def _prompt_sending_orchestrator(
                 try:
                     # Use wait_for to implement a timeout
                     await asyncio.wait_for(
-                        orchestrator.send_prompts_async(prompt_list=batch),
-                        timeout=timeout  # Use provided timeout
+                        orchestrator.send_prompts_async(prompt_list=batch, memory_labels={"risk_strategy_path": output_path, "batch": batch_idx+1}),
+                        timeout=timeout  # Use provided timeout
                     )
                     batch_duration = (datetime.now() - batch_start_time).total_seconds()
                     self.logger.debug(f"Successfully processed batch {batch_idx+1} for {strategy_name}/{risk_category} in {batch_duration:.2f} seconds")
@@ -699,12 +711,14 @@ async def _prompt_sending_orchestrator(
                     batch_task_key = f"{strategy_name}_{risk_category}_batch_{batch_idx+1}"
                     self.task_statuses[batch_task_key] = TASK_STATUS["TIMEOUT"]
                     self.red_team_info[strategy_name][risk_category]["status"] = TASK_STATUS["INCOMPLETE"]
+                    self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category, batch_idx=batch_idx+1)
                     # Continue with partial results rather than failing completely
                     continue
                 except Exception as e:
                     log_error(self.logger, f"Error processing batch {batch_idx+1}", e, f"{strategy_name}/{risk_category}")
                     self.logger.debug(f"ERROR: Strategy {strategy_name}, Risk {risk_category}, Batch {batch_idx+1}: {str(e)}")
                     self.red_team_info[strategy_name][risk_category]["status"] = TASK_STATUS["INCOMPLETE"]
+                    self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category, batch_idx=batch_idx+1)
                     # Continue with other batches even if one fails
                     continue
         else:
@@ -713,7 +727,7 @@ async def _prompt_sending_orchestrator(
             batch_start_time = datetime.now()
             try:
                 await asyncio.wait_for(
-                    orchestrator.send_prompts_async(prompt_list=all_prompts),
+                    orchestrator.send_prompts_async(prompt_list=all_prompts, memory_labels={"risk_strategy_path": output_path, "batch": 1}),
                     timeout=timeout  # Use provided timeout
                 )
                 batch_duration = (datetime.now() - batch_start_time).total_seconds()
@@ -725,10 +739,12 @@ async def _prompt_sending_orchestrator(
                 single_batch_task_key = f"{strategy_name}_{risk_category}_single_batch"
                 self.task_statuses[single_batch_task_key] = TASK_STATUS["TIMEOUT"]
                 self.red_team_info[strategy_name][risk_category]["status"] = TASK_STATUS["INCOMPLETE"]
+                self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category, batch_idx=1)
             except Exception as e:
                 log_error(self.logger, "Error processing prompts", e, f"{strategy_name}/{risk_category}")
                 self.logger.debug(f"ERROR: Strategy {strategy_name}, Risk {risk_category}: {str(e)}")
                 self.red_team_info[strategy_name][risk_category]["status"] = TASK_STATUS["INCOMPLETE"]
+                self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category, batch_idx=1)
 
         self.task_statuses[task_key] = TASK_STATUS["COMPLETED"]
         return orchestrator
@@ -739,39 +755,55 @@ async def _prompt_sending_orchestrator(
             self.task_statuses[task_key] = TASK_STATUS["FAILED"]
             raise
 
-    def _write_pyrit_outputs_to_file(self, orchestrator: Orchestrator) -> str:
+    def _write_pyrit_outputs_to_file(self, *, orchestrator: Orchestrator, strategy_name: str, risk_category: str, batch_idx: Optional[int] = None) -> str:
         """Write PyRIT outputs to a file with a name based on orchestrator, converter, and risk category.
 
         :param orchestrator: The orchestrator that generated the outputs
         :type orchestrator: Orchestrator
         :return: Path to the output file
         :rtype: Union[str, os.PathLike]
         """
-        base_path = str(uuid.uuid4())
-
-        # If scan output directory exists, place the file there
-        if hasattr(self, 'scan_output_dir') and self.scan_output_dir:
-            output_path = os.path.join(self.scan_output_dir, f"{base_path}{DATA_EXT}")
-        else:
-            output_path = f"{base_path}{DATA_EXT}"
-
+        output_path = self.red_team_info[strategy_name][risk_category]["data_file"]
         self.logger.debug(f"Writing PyRIT outputs to file: {output_path}")
+        memory = CentralMemory.get_memory_instance()
 
-        memory = orchestrator.get_memory()
-
-        # Get conversations as a List[List[ChatMessage]]
-        conversations = [[item.to_chat_message() for item in group] for conv_id, group in itertools.groupby(memory, key=lambda x: x.conversation_id)]
-
-        # Convert to json lines
-        json_lines = ""
-        for conversation in conversations:  # each conversation is a List[ChatMessage]
-            json_lines += json.dumps({"conversation": {"messages": [self._message_to_dict(message) for message in conversation]}}) + "\n"
+        memory_label = {"risk_strategy_path": output_path}
 
-        with Path(output_path).open("w") as f:
-            f.writelines(json_lines)
+        prompts_request_pieces = memory.get_prompt_request_pieces(labels=memory_label)
 
-        orchestrator.dispose_db_engine()
-        self.logger.debug(f"Successfully wrote {len(conversations)} conversations to {output_path}")
+        conversations = [[item.to_chat_message() for item in group] for conv_id, group in itertools.groupby(prompts_request_pieces, key=lambda x: x.conversation_id)]
+        # Check if we should overwrite existing file with more conversations
+        if os.path.exists(output_path):
+            existing_line_count = 0
+            try:
+                with open(output_path, 'r') as existing_file:
+                    existing_line_count = sum(1 for _ in existing_file)
+
+                # Use the number of prompts to determine if we have more conversations
+                # This is more accurate than using the memory which might have incomplete conversations
+                if len(conversations) > existing_line_count:
+                    self.logger.debug(f"Found more prompts ({len(conversations)}) than existing file lines ({existing_line_count}). Replacing content.")
+                    # Convert to json lines
+                    json_lines = ""
+                    for conversation in conversations:  # each conversation is a List[ChatMessage]
+                        json_lines += json.dumps({"conversation": {"messages": [self._message_to_dict(message) for message in conversation]}}) + "\n"
+                    with Path(output_path).open("w") as f:
+                        f.writelines(json_lines)
+                    self.logger.debug(f"Successfully wrote {len(conversations)-existing_line_count} new conversation(s) to {output_path}")
+                else:
+                    self.logger.debug(f"Existing file has {existing_line_count} lines, new data has {len(conversations)} prompts. Keeping existing file.")
+                    return output_path
+            except Exception as e:
+                self.logger.warning(f"Failed to read existing file {output_path}: {str(e)}")
+        else:
+            self.logger.debug(f"Creating new file: {output_path}")
+            # Convert to json lines
+            json_lines = ""
+            for conversation in conversations:  # each conversation is a List[ChatMessage]
+                json_lines += json.dumps({"conversation": {"messages": [self._message_to_dict(message) for message in conversation]}}) + "\n"
+            with Path(output_path).open("w") as f:
+                f.writelines(json_lines)
+            self.logger.debug(f"Successfully wrote {len(conversations)} conversations to {output_path}")
         return str(output_path)
 
     # Replace with utility function
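The data file is plain JSON lines: one conversation per line, each record shaped as {"conversation": {"messages": [...]}}. A hypothetical reader for that format (load_conversations is illustrative, not part of the SDK):

import json
from pathlib import Path


def load_conversations(data_file):
    # One conversation per non-empty line, as written by
    # _write_pyrit_outputs_to_file: {"conversation": {"messages": [...]}}.
    with Path(data_file).open() as f:
        return [json.loads(line)["conversation"]["messages"] for line in f if line.strip()]

The overwrite check above leans on this one-record-per-line invariant: a write only replaces the file when memory now holds more conversations than lines already flushed, so a later timeout handler cannot clobber a fuller file with a partial result.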
@@ -1379,7 +1411,8 @@ async def _process_attack(
                 progress_bar.update(1)
                 return None
 
-            data_path = self._write_pyrit_outputs_to_file(orchestrator)
+            data_path = self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category.value)
+            orchestrator.dispose_db_engine()
 
             # Store data file in our tracking dictionary
             self.red_team_info[strategy_name][risk_category.value]["data_file"] = data_path
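Note the ordering change here: dispose_db_engine() moved out of _write_pyrit_outputs_to_file and up into _process_attack. The writer can now be called repeatedly from timeout and error handlers against live memory, and the shared DuckDB engine is torn down exactly once, after the final write.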
