 from pyrit.common import initialize_pyrit, DUCK_DB
 from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget
 from pyrit.models import ChatMessage
+from pyrit.memory import CentralMemory
 from pyrit.orchestrator.single_turn.prompt_sending_orchestrator import PromptSendingOrchestrator
 from pyrit.orchestrator import Orchestrator
 from pyrit.exceptions import PyritException
@@ -667,6 +668,17 @@ async def _prompt_sending_orchestrator(
             # Use a batched approach for send_prompts_async to prevent overwhelming
             # the model with too many concurrent requests
             batch_size = min(len(all_prompts), 3)  # Process 3 prompts at a time max
+
+            # Initialize output path for memory labelling
+            base_path = str(uuid.uuid4())
+
+            # If scan output directory exists, place the file there
+            if hasattr(self, 'scan_output_dir') and self.scan_output_dir:
+                output_path = os.path.join(self.scan_output_dir, f"{base_path}{DATA_EXT}")
+            else:
+                output_path = f"{base_path}{DATA_EXT}"
+
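+            # Register the data file up front so timeout/error handlers can flush partial results to it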
+            self.red_team_info[strategy_name][risk_category]["data_file"] = output_path
 
             # Process prompts concurrently within each batch
             if len(all_prompts) > batch_size:
@@ -681,8 +693,8 @@ async def _prompt_sending_orchestrator(
                     try:
                         # Use wait_for to implement a timeout
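+                        # Label prompts with the output file path and batch number so their
+                        # results can be pulled back from central memory when writing outputs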
                         await asyncio.wait_for(
-                            orchestrator.send_prompts_async(prompt_list=batch),
+                            orchestrator.send_prompts_async(prompt_list=batch, memory_labels={"risk_strategy_path": output_path, "batch": batch_idx + 1}),
                             timeout=timeout  # Use provided timeout
                         )
                         batch_duration = (datetime.now() - batch_start_time).total_seconds()
                         self.logger.debug(f"Successfully processed batch {batch_idx + 1} for {strategy_name}/{risk_category} in {batch_duration:.2f} seconds")
@@ -699,12 +711,14 @@ async def _prompt_sending_orchestrator(
                         batch_task_key = f"{strategy_name}_{risk_category}_batch_{batch_idx + 1}"
                         self.task_statuses[batch_task_key] = TASK_STATUS["TIMEOUT"]
                         self.red_team_info[strategy_name][risk_category]["status"] = TASK_STATUS["INCOMPLETE"]
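+                        # Flush whatever completed before the timeout so partial results are kept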
+                        self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category, batch_idx=batch_idx + 1)
                         # Continue with partial results rather than failing completely
                         continue
                     except Exception as e:
                         log_error(self.logger, f"Error processing batch {batch_idx + 1}", e, f"{strategy_name}/{risk_category}")
                         self.logger.debug(f"ERROR: Strategy {strategy_name}, Risk {risk_category}, Batch {batch_idx + 1}: {str(e)}")
                         self.red_team_info[strategy_name][risk_category]["status"] = TASK_STATUS["INCOMPLETE"]
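+                        # Persist any prompts that completed before the error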
+                        self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category, batch_idx=batch_idx + 1)
                         # Continue with other batches even if one fails
                         continue
             else:
@@ -713,7 +727,7 @@ async def _prompt_sending_orchestrator(
                 batch_start_time = datetime.now()
                 try:
                     await asyncio.wait_for(
-                        orchestrator.send_prompts_async(prompt_list=all_prompts),
+                        orchestrator.send_prompts_async(prompt_list=all_prompts, memory_labels={"risk_strategy_path": output_path, "batch": 1}),
                         timeout=timeout  # Use provided timeout
                     )
                     batch_duration = (datetime.now() - batch_start_time).total_seconds()
@@ -725,10 +739,12 @@ async def _prompt_sending_orchestrator(
                     single_batch_task_key = f"{strategy_name}_{risk_category}_single_batch"
                     self.task_statuses[single_batch_task_key] = TASK_STATUS["TIMEOUT"]
                     self.red_team_info[strategy_name][risk_category]["status"] = TASK_STATUS["INCOMPLETE"]
+                    self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category, batch_idx=1)
                 except Exception as e:
                     log_error(self.logger, "Error processing prompts", e, f"{strategy_name}/{risk_category}")
                     self.logger.debug(f"ERROR: Strategy {strategy_name}, Risk {risk_category}: {str(e)}")
                     self.red_team_info[strategy_name][risk_category]["status"] = TASK_STATUS["INCOMPLETE"]
+                    self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category, batch_idx=1)
 
             self.task_statuses[task_key] = TASK_STATUS["COMPLETED"]
             return orchestrator
@@ -739,39 +755,55 @@ async def _prompt_sending_orchestrator(
             self.task_statuses[task_key] = TASK_STATUS["FAILED"]
             raise
 
-    def _write_pyrit_outputs_to_file(self, orchestrator: Orchestrator) -> str:
+    def _write_pyrit_outputs_to_file(self, *, orchestrator: Orchestrator, strategy_name: str, risk_category: str, batch_idx: Optional[int] = None) -> str:
         """Write PyRIT outputs to a file with a name based on orchestrator, converter, and risk category.
 
         :param orchestrator: The orchestrator that generated the outputs
         :type orchestrator: Orchestrator
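+        :param strategy_name: The name of the attack strategy used
+        :type strategy_name: str
+        :param risk_category: The risk category being evaluated
+        :type risk_category: str
+        :param batch_idx: Index of the batch that triggered this write, if batched
+        :type batch_idx: Optional[int]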
         :return: Path to the output file
         :rtype: Union[str, os.PathLike]
         """
-        base_path = str(uuid.uuid4())
-
-        # If scan output directory exists, place the file there
-        if hasattr(self, 'scan_output_dir') and self.scan_output_dir:
-            output_path = os.path.join(self.scan_output_dir, f"{base_path}{DATA_EXT}")
-        else:
-            output_path = f"{base_path}{DATA_EXT}"
-
+        output_path = self.red_team_info[strategy_name][risk_category]["data_file"]
         self.logger.debug(f"Writing PyRIT outputs to file: {output_path}")
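+        # Prompts were tagged with the "risk_strategy_path" memory label when sent;
+        # query PyRIT's central memory for that label to recover them here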
+        memory = CentralMemory.get_memory_instance()
 
-        memory = orchestrator.get_memory()
-
-        # Get conversations as a List[List[ChatMessage]]
-        conversations = [[item.to_chat_message() for item in group] for conv_id, group in itertools.groupby(memory, key=lambda x: x.conversation_id)]
-
-        # Convert to json lines
-        json_lines = ""
-        for conversation in conversations:  # each conversation is a List[ChatMessage]
-            json_lines += json.dumps({"conversation": {"messages": [self._message_to_dict(message) for message in conversation]}}) + "\n"
+        memory_label = {"risk_strategy_path": output_path}
 
-        with Path(output_path).open("w") as f:
-            f.writelines(json_lines)
+        prompts_request_pieces = memory.get_prompt_request_pieces(labels=memory_label)
 
-        orchestrator.dispose_db_engine()
-        self.logger.debug(f"Successfully wrote {len(conversations)} conversations to {output_path}")
+        conversations = [[item.to_chat_message() for item in group] for conv_id, group in itertools.groupby(prompts_request_pieces, key=lambda x: x.conversation_id)]
+        # Check if we should overwrite the existing file with more conversations
+        if os.path.exists(output_path):
+            existing_line_count = 0
+            try:
+                with open(output_path, 'r') as existing_file:
+                    existing_line_count = sum(1 for _ in existing_file)
+
+                # Use the number of prompts to determine if we have more conversations.
+                # This is more accurate than using the memory, which might hold incomplete conversations
+                if len(conversations) > existing_line_count:
+                    self.logger.debug(f"Found more prompts ({len(conversations)}) than existing file lines ({existing_line_count}). Replacing content.")
+                    # Convert to json lines
+                    json_lines = ""
+                    for conversation in conversations:  # each conversation is a List[ChatMessage]
+                        json_lines += json.dumps({"conversation": {"messages": [self._message_to_dict(message) for message in conversation]}}) + "\n"
+                    with Path(output_path).open("w") as f:
+                        f.writelines(json_lines)
+                    self.logger.debug(f"Successfully wrote {len(conversations) - existing_line_count} new conversation(s) to {output_path}")
+                else:
+                    self.logger.debug(f"Existing file has {existing_line_count} lines, new data has {len(conversations)} prompts. Keeping existing file.")
+                    return output_path
+            except Exception as e:
+                self.logger.warning(f"Failed to read existing file {output_path}: {str(e)}")
+        else:
+            self.logger.debug(f"Creating new file: {output_path}")
+            # Convert to json lines
+            json_lines = ""
+            for conversation in conversations:  # each conversation is a List[ChatMessage]
+                json_lines += json.dumps({"conversation": {"messages": [self._message_to_dict(message) for message in conversation]}}) + "\n"
+            with Path(output_path).open("w") as f:
+                f.writelines(json_lines)
+            self.logger.debug(f"Successfully wrote {len(conversations)} conversations to {output_path}")
         return str(output_path)
 
     # Replace with utility function
@@ -1379,7 +1411,8 @@ async def _process_attack(
                 progress_bar.update(1)
                 return None
 
-            data_path = self._write_pyrit_outputs_to_file(orchestrator)
+            data_path = self._write_pyrit_outputs_to_file(orchestrator=orchestrator, strategy_name=strategy_name, risk_category=risk_category.value)
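+            # Dispose of the DB engine only after this orchestrator's outputs have been written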
+            orchestrator.dispose_db_engine()
 
             # Store data file in our tracking dictionary
             self.red_team_info[strategy_name][risk_category.value]["data_file"] = data_path