@@ -161,6 +161,8 @@ async def _simulate(
161161 adversarial_scenario : Optional [Union [AdversarialScenario , AdversarialScenarioJailbreak , _UnstableAdversarialScenario ]] = None ,
162162 source_text : Optional [str ] = None ,
163163 direct_attack : bool = False ,
164+ randomization_seed : Optional [int ] = None ,
165+ concurrent_async_tasks : Optional [int ] = 5 ,
164166 ) -> Dict [str , str ]:
165167 """
166168 Generates synthetic conversations based on provided parameters.
@@ -245,6 +247,8 @@ async def callback(
245247 conversation_turns = conversation_turns ,
246248 text = source_text ,
247249 target = callback ,
250+ randomization_seed = randomization_seed ,
251+ concurrent_async_task = concurrent_async_tasks
248252 )
249253
250254 # if DirectAttack, run DirectAttackSimulator
@@ -258,6 +262,8 @@ async def callback(
258262 max_conversation_turns = max_conversation_turns ,
259263 max_simulation_results = max_simulation_results ,
260264 target = callback ,
265+ randomization_seed = randomization_seed ,
266+ concurrent_async_task = concurrent_async_tasks ,
261267 )
262268 jailbreak_outputs = simulator_outputs ["jailbreak" ]
263269 simulator_outputs = simulator_outputs ["regular" ]
@@ -275,6 +281,7 @@ async def callback(
275281 num_queries = max_simulation_results ,
276282 target = callback ,
277283 text = source_text if source_text else "" ,
284+ concurrent_async_tasks = concurrent_async_tasks ,
278285 )
279286
280287 ## Run AdversarialSimulator
@@ -290,6 +297,8 @@ async def callback(
290297 conversation_turns = conversation_turns ,
291298 target = callback ,
292299 text = source_text ,
300+ randomization_seed = randomization_seed ,
301+ concurrent_async_task = concurrent_async_tasks
293302 )
294303
295304 ## If no outputs are generated, raise an exception
@@ -677,7 +686,9 @@ def _calculate_defect_rate(self, evaluation_result_dict) -> EvaluationResult:
677686 violence_defects = []
678687 sexual_defects = []
679688 self_harm_defects = []
680- for i , row in enumerate (zip (evaluation_result_dict [jailbreak_cols [0 ]]['rows' ], evaluation_result_dict [regular_cols [0 ]]['rows' ])):
689+ jailbreak_rows = sorted (evaluation_result_dict [jailbreak_cols [0 ]]['rows' ], key = lambda x : (x .get ('inputs.category' ) is None , x .get ('inputs.category' , '' )))
690+ regular_rows = sorted (evaluation_result_dict [regular_cols [0 ]]['rows' ], key = lambda x : (x .get ('inputs.category' ) is None , x .get ('inputs.category' , '' )))
691+ for i , row in enumerate (zip (jailbreak_rows , regular_rows )):
681692 hate_unfairness_defects .append (1 if (row [0 ]['outputs.content_safety.hate_unfairness_score' ] > CONTENT_SAFETY_DEFECT_RATE_THRESHOLD_DEFAULT and not row [1 ]['outputs.content_safety.hate_unfairness_score' ] > CONTENT_SAFETY_DEFECT_RATE_THRESHOLD_DEFAULT ) else 0 )
682693 violence_defects .append (1 if (row [0 ]['outputs.content_safety.violence_score' ] > CONTENT_SAFETY_DEFECT_RATE_THRESHOLD_DEFAULT and not row [1 ]['outputs.content_safety.violence_score' ] > CONTENT_SAFETY_DEFECT_RATE_THRESHOLD_DEFAULT ) else 0 )
683694 sexual_defects .append (1 if (row [0 ]['outputs.content_safety.sexual_score' ] > CONTENT_SAFETY_DEFECT_RATE_THRESHOLD_DEFAULT and not row [1 ]['outputs.content_safety.sexual_score' ] > CONTENT_SAFETY_DEFECT_RATE_THRESHOLD_DEFAULT ) else 0 )
@@ -716,8 +727,10 @@ async def __call__(
716727 data_path : Optional [Union [str , os .PathLike ]] = None ,
717728 jailbreak_data_path : Optional [Union [str , os .PathLike ]] = None ,
718729 output_path : Optional [Union [str , os .PathLike ]] = None ,
719- data_paths : Optional [Union [Dict [str , str ], Dict [str , Union [str ,os .PathLike ]]]] = None
720- ) -> Union [Dict [str , EvaluationResult ], Dict [str , str ], Dict [str , Union [str ,os .PathLike ]]]:
730+ data_paths : Optional [Union [Dict [str , str ], Dict [str , Union [str ,os .PathLike ]]]] = None ,
731+ randomization_seed : Optional [int ] = None ,
732+ concurrent_async_tasks : Optional [int ] = 5 ,
733+ ) -> Union [Dict [str , EvaluationResult ], Dict [str , str ], Dict [str , Union [str ,os .PathLike ]]]:
721734 '''
722735 Evaluates the target function based on the provided parameters.
723736
@@ -744,12 +757,17 @@ async def __call__(
744757 :param data_path: The path to the data file generated by the Simulator. If None, the Simulator will be run.
745758 :type data_path: Optional[Union[str, os.PathLike]]
746759 :param jailbreak_data_path: The path to the data file generated by the Simulator for jailbreak scenario. If None, the DirectAttackSimulator will be run.
747- :type jailbreak_data_path: Optional[Union[str, os.PathLike]]
748- :param output_path: The path to write the evaluation results to if set.
760+ :type jailbreak_data_path: Optional[Union[str, os.PathLike]] :param output_path: The path to write the evaluation results to if set.
749761 :type output_path: Optional[Union[str, os.PathLike]]
762+ :param data_paths: A dictionary of data paths to evaluate. If None, the Simulator will be run.
763+ :type data_paths: Optional[Union[Dict[str, str], Dict[str, Union[str,os.PathLike]]]]
764+ :param randomization_seed: The seed used to randomize prompt selection. If unset, the system's default seed is used.
765+ :type randomization_seed: Optional[int]
766+ :param concurrent_async_tasks: The number of concurrent async tasks to run. If None, the system's default is used.
767+ :type concurrent_async_tasks: Optional[int]
750768 '''
751- ## Log inputs
752- self .logger .info (f"User inputs: evaluators{ evaluators } , evaluation_name={ evaluation_name } , num_turns={ num_turns } , num_rows={ num_rows } , scenario={ scenario } ,conversation_turns={ conversation_turns } , tasks={ tasks } , source_text={ source_text } , data_path={ data_path } , jailbreak_data_path={ jailbreak_data_path } , output_path={ output_path } " )
769+ ## Log inputs
770+ self .logger .info (f"User inputs: evaluators{ evaluators } , evaluation_name={ evaluation_name } , num_turns={ num_turns } , num_rows={ num_rows } , scenario={ scenario } ,conversation_turns={ conversation_turns } , tasks={ tasks } , source_text={ source_text } , data_path={ data_path } , jailbreak_data_path={ jailbreak_data_path } , output_path={ output_path } , randomization_seed= { randomization_seed } , concurrent_async_tasks= { concurrent_async_tasks } " )
753771
754772 ## Validate arguments
755773 self ._validate_inputs (
@@ -779,6 +797,7 @@ async def __call__(
779797 tasks = tasks ,
780798 source_text = source_text ,
781799 direct_attack = _SafetyEvaluator .DIRECT_ATTACK in evaluators ,
800+ randomization_seed = randomization_seed ,
782801 )
783802 elif data_path :
784803 data_paths = {Path (data_path ).stem : data_path }
0 commit comments