38
38
from ._utils ._rai_service_eval_chat_target import RAIServiceEvalChatTarget
39
39
from ._utils .constants import DATA_EXT , TASK_STATUS
40
40
from ._utils .logging_utils import log_strategy_start , log_error
41
+ from ._utils .formatting_utils import write_pyrit_outputs_to_file
41
42
42
43
43
44
def network_retry_decorator (retry_config , logger , strategy_name , risk_category_name , prompt_idx = None ):
@@ -169,6 +170,7 @@ async def _prompt_sending_orchestrator(
169
170
timeout : int = 120 ,
170
171
red_team_info : Dict = None ,
171
172
task_statuses : Dict = None ,
173
+ prompt_to_context : Dict [str , str ] = None ,
172
174
) -> Orchestrator :
173
175
"""Send prompts via the PromptSendingOrchestrator.
174
176
@@ -224,9 +226,6 @@ async def _prompt_sending_orchestrator(
224
226
task_statuses [task_key ] = TASK_STATUS ["COMPLETED" ]
225
227
return orchestrator
226
228
227
- # Debug log the first few characters of each prompt
228
- self .logger .debug (f"First prompt (truncated): { all_prompts [0 ][:50 ]} ..." )
229
-
230
229
# Initialize output path for memory labelling
231
230
base_path = str (uuid .uuid4 ())
232
231
@@ -313,6 +312,7 @@ async def _multi_turn_orchestrator(
313
312
timeout : int = 120 ,
314
313
red_team_info : Dict = None ,
315
314
task_statuses : Dict = None ,
315
+ prompt_to_context : Dict [str , str ] = None ,
316
316
) -> Orchestrator :
317
317
"""Send prompts via the RedTeamingOrchestrator (multi-turn orchestrator).
318
318
@@ -381,6 +381,7 @@ async def _multi_turn_orchestrator(
381
381
for prompt_idx , prompt in enumerate (all_prompts ):
382
382
prompt_start_time = datetime .now ()
383
383
self .logger .debug (f"Processing prompt { prompt_idx + 1 } /{ len (all_prompts )} " )
384
+ context = prompt_to_context .get (prompt , None ) if prompt_to_context else None
384
385
try :
385
386
azure_rai_service_scorer = AzureRAIServiceTrueFalseScorer (
386
387
client = self .generated_rai_client ,
@@ -390,6 +391,7 @@ async def _multi_turn_orchestrator(
390
391
credential = self .credential ,
391
392
risk_category = risk_category ,
392
393
azure_ai_project = self .azure_ai_project ,
394
+ context = context ,
393
395
)
394
396
395
397
azure_rai_service_target = AzureRAIServiceTarget (
@@ -411,9 +413,6 @@ async def _multi_turn_orchestrator(
411
413
use_score_as_feedback = False ,
412
414
)
413
415
414
- # Debug log the first few characters of the current prompt
415
- self .logger .debug (f"Current prompt (truncated): { prompt [:50 ]} ..." )
416
-
417
416
try :
418
417
# Create retry-enabled function using the reusable decorator
419
418
@network_retry_decorator (
@@ -423,10 +422,7 @@ async def send_prompt_with_retry():
423
422
return await asyncio .wait_for (
424
423
orchestrator .run_attack_async (
425
424
objective = prompt ,
426
- memory_labels = {
427
- "risk_strategy_path" : output_path ,
428
- "batch" : 1 ,
429
- },
425
+ memory_labels = {"risk_strategy_path" : output_path , "batch" : 1 , "context" : context },
430
426
),
431
427
timeout = calculated_timeout ,
432
428
)
@@ -438,6 +434,13 @@ async def send_prompt_with_retry():
438
434
f"Successfully processed prompt { prompt_idx + 1 } for { strategy_name } /{ risk_category_name } in { prompt_duration :.2f} seconds"
439
435
)
440
436
437
+ # Write outputs to file after each prompt is processed
438
+ write_pyrit_outputs_to_file (
439
+ output_path = output_path ,
440
+ logger = self .logger ,
441
+ prompt_to_context = prompt_to_context ,
442
+ )
443
+
441
444
# Print progress to console
442
445
if prompt_idx < len (all_prompts ) - 1 : # Don't print for the last prompt
443
446
print (
@@ -492,6 +495,7 @@ async def _crescendo_orchestrator(
492
495
timeout : int = 120 ,
493
496
red_team_info : Dict = None ,
494
497
task_statuses : Dict = None ,
498
+ prompt_to_context : Dict [str , str ] = None ,
495
499
) -> Orchestrator :
496
500
"""Send prompts via the CrescendoOrchestrator with optimized performance.
497
501
@@ -542,12 +546,14 @@ async def _crescendo_orchestrator(
542
546
for prompt_idx , prompt in enumerate (all_prompts ):
543
547
prompt_start_time = datetime .now ()
544
548
self .logger .debug (f"Processing prompt { prompt_idx + 1 } /{ len (all_prompts )} " )
549
+ context = prompt_to_context .get (prompt , None ) if prompt_to_context else None
545
550
try :
546
551
red_llm_scoring_target = RAIServiceEvalChatTarget (
547
552
logger = self .logger ,
548
553
credential = self .credential ,
549
554
risk_category = risk_category ,
550
555
azure_ai_project = self .azure_ai_project ,
556
+ context = context ,
551
557
)
552
558
553
559
azure_rai_service_target = AzureRAIServiceTarget (
@@ -577,11 +583,9 @@ async def _crescendo_orchestrator(
577
583
credential = self .credential ,
578
584
risk_category = risk_category ,
579
585
azure_ai_project = self .azure_ai_project ,
586
+ context = context ,
580
587
)
581
588
582
- # Debug log the first few characters of the current prompt
583
- self .logger .debug (f"Current prompt (truncated): { prompt [:50 ]} ..." )
584
-
585
589
try :
586
590
# Create retry-enabled function using the reusable decorator
587
591
@network_retry_decorator (
@@ -594,6 +598,7 @@ async def send_prompt_with_retry():
594
598
memory_labels = {
595
599
"risk_strategy_path" : output_path ,
596
600
"batch" : prompt_idx + 1 ,
601
+ "context" : context ,
597
602
},
598
603
),
599
604
timeout = calculated_timeout ,
@@ -606,6 +611,13 @@ async def send_prompt_with_retry():
606
611
f"Successfully processed prompt { prompt_idx + 1 } for { strategy_name } /{ risk_category_name } in { prompt_duration :.2f} seconds"
607
612
)
608
613
614
+ # Write outputs to file after each prompt is processed
615
+ write_pyrit_outputs_to_file (
616
+ output_path = output_path ,
617
+ logger = self .logger ,
618
+ prompt_to_context = prompt_to_context ,
619
+ )
620
+
609
621
# Print progress to console
610
622
if prompt_idx < len (all_prompts ) - 1 : # Don't print for the last prompt
611
623
print (
0 commit comments