5959from pyrit .prompt_target import PromptChatTarget
6060
6161# Local imports - constants and utilities
62- from ._utils .constants import TASK_STATUS
62+ from ._utils .constants import TASK_STATUS , MAX_SAMPLING_ITERATIONS_MULTIPLIER , RISK_TO_NUM_SUBTYPE_MAP
6363from ._utils .logging_utils import (
6464 setup_logger ,
6565 log_section_header ,
7676from ._utils .retry_utils import create_standard_retry_manager
7777from ._utils .file_utils import create_file_manager
7878from ._utils .metric_mapping import get_attack_objective_from_risk_category
79+ from ._utils .objective_utils import extract_risk_subtype , get_objective_id
7980
8081from ._orchestrator_manager import OrchestratorManager
8182from ._evaluation_processor import EvaluationProcessor
@@ -352,9 +353,20 @@ async def _get_attack_objectives(
352353 risk_cat_value = get_attack_objective_from_risk_category (risk_category ).lower ()
353354 num_objectives = attack_objective_generator .num_objectives
354355
356+ # Calculate num_objectives_with_subtypes based on max subtypes across all risk categories
357+ # Use attack_objective_generator.risk_categories as self.risk_categories may not be set yet
358+ risk_categories = getattr (self , "risk_categories" , None ) or attack_objective_generator .risk_categories
359+ max_num_subtypes = max ((RISK_TO_NUM_SUBTYPE_MAP .get (rc , 0 ) for rc in risk_categories ), default = 0 )
360+ num_objectives_with_subtypes = max (num_objectives , max_num_subtypes )
361+
362+ self .logger .debug (
363+ f"Calculated num_objectives_with_subtypes for { risk_cat_value } : "
364+ f"max(num_objectives={ num_objectives } , max_subtypes={ max_num_subtypes } ) = { num_objectives_with_subtypes } "
365+ )
366+
355367 log_subsection_header (
356368 self .logger ,
357- f"Getting attack objectives for { risk_cat_value } , strategy: { strategy } " ,
369+ f"Getting attack objectives for { risk_cat_value } , strategy: { strategy } , num_objectives: { num_objectives } , num_objectives_with_subtypes: { num_objectives_with_subtypes } " ,
358370 )
359371
360372 # Check if we already have baseline objectives for this risk category
@@ -370,7 +382,7 @@ async def _get_attack_objectives(
370382 if custom_objectives :
371383 # Use custom objectives for this risk category
372384 return await self ._get_custom_attack_objectives (
373- risk_cat_value , num_objectives , strategy , current_key , is_agent_target
385+ risk_cat_value , num_objectives , num_objectives_with_subtypes , strategy , current_key , is_agent_target
374386 )
375387 else :
376388 # No custom objectives for this risk category, but risk_categories was specified
@@ -391,7 +403,9 @@ async def _get_attack_objectives(
391403 baseline_key ,
392404 current_key ,
393405 num_objectives ,
406+ num_objectives_with_subtypes ,
394407 is_agent_target ,
408+ client_id ,
395409 )
396410 else :
397411 # Risk category not in requested list, return empty
@@ -409,6 +423,7 @@ async def _get_attack_objectives(
409423 baseline_key ,
410424 current_key ,
411425 num_objectives ,
426+ num_objectives_with_subtypes ,
412427 is_agent_target ,
413428 client_id ,
414429 )
@@ -417,6 +432,7 @@ async def _get_custom_attack_objectives(
417432 self ,
418433 risk_cat_value : str ,
419434 num_objectives : int ,
435+ num_objectives_with_subtypes : int ,
420436 strategy : str ,
421437 current_key : tuple ,
422438 is_agent_target : Optional [bool ] = None ,
@@ -437,15 +453,97 @@ async def _get_custom_attack_objectives(
437453
438454 self .logger .info (f"Found { len (custom_objectives )} custom objectives for { risk_cat_value } " )
439455
440- # Sample if we have more than needed
441- if len (custom_objectives ) > num_objectives :
442- selected_cat_objectives = random .sample (custom_objectives , num_objectives )
456+ # Deduplicate objectives by ID to avoid selecting the same logical objective multiple times
457+ seen_ids = set ()
458+ deduplicated_objectives = []
459+ for obj in custom_objectives :
460+ obj_id = get_objective_id (obj )
461+ if obj_id not in seen_ids :
462+ seen_ids .add (obj_id )
463+ deduplicated_objectives .append (obj )
464+
465+ if len (deduplicated_objectives ) < len (custom_objectives ):
466+ self .logger .debug (
467+ f"Deduplicated { len (custom_objectives )} objectives to { len (deduplicated_objectives )} unique objectives by ID"
468+ )
469+
470+ # Group objectives by risk_subtype if present
471+ objectives_by_subtype = {}
472+ objectives_without_subtype = []
473+
474+ for obj in deduplicated_objectives :
475+ risk_subtype = extract_risk_subtype (obj )
476+
477+ if risk_subtype :
478+ if risk_subtype not in objectives_by_subtype :
479+ objectives_by_subtype [risk_subtype ] = []
480+ objectives_by_subtype [risk_subtype ].append (obj )
481+ else :
482+ objectives_without_subtype .append (obj )
483+
484+ # Determine sampling strategy based on risk_subtype presence
485+ # Use num_objectives_with_subtypes for initial sampling to ensure coverage
486+ if objectives_by_subtype :
487+ # We have risk subtypes - sample evenly across them
488+ num_subtypes = len (objectives_by_subtype )
489+ objectives_per_subtype = max (1 , num_objectives_with_subtypes // num_subtypes )
490+
443491 self .logger .info (
444- f"Sampled { num_objectives } objectives from { len (custom_objectives )} available for { risk_cat_value } "
492+ f"Found { num_subtypes } risk subtypes in custom objectives. "
493+ f"Sampling { objectives_per_subtype } objectives per subtype to reach ~{ num_objectives_with_subtypes } total."
445494 )
495+
496+ selected_cat_objectives = []
497+ for subtype , subtype_objectives in objectives_by_subtype .items ():
498+ num_to_sample = min (objectives_per_subtype , len (subtype_objectives ))
499+ sampled = random .sample (subtype_objectives , num_to_sample )
500+ selected_cat_objectives .extend (sampled )
501+ self .logger .debug (
502+ f"Sampled { num_to_sample } objectives from risk_subtype '{ subtype } ' "
503+ f"({ len (subtype_objectives )} available)"
504+ )
505+
506+ # If we need more objectives to reach num_objectives_with_subtypes, sample from objectives without subtype
507+ if len (selected_cat_objectives ) < num_objectives_with_subtypes and objectives_without_subtype :
508+ remaining = num_objectives_with_subtypes - len (selected_cat_objectives )
509+ num_to_sample = min (remaining , len (objectives_without_subtype ))
510+ selected_cat_objectives .extend (random .sample (objectives_without_subtype , num_to_sample ))
511+ self .logger .debug (f"Added { num_to_sample } objectives without risk_subtype to reach target count" )
512+
513+ # If we still need more, round-robin through subtypes again
514+ if len (selected_cat_objectives ) < num_objectives_with_subtypes :
515+ remaining = num_objectives_with_subtypes - len (selected_cat_objectives )
516+ subtype_list = list (objectives_by_subtype .keys ())
517+ # Track selected objective IDs in a set for O(1) membership checks
518+ # Use the objective's 'id' field if available, generate UUID-based ID otherwise
519+ selected_ids = {get_objective_id (obj ) for obj in selected_cat_objectives }
520+ idx = 0
521+ while remaining > 0 and subtype_list :
522+ subtype = subtype_list [idx % len (subtype_list )]
523+ available = [
524+ obj for obj in objectives_by_subtype [subtype ] if get_objective_id (obj ) not in selected_ids
525+ ]
526+ if available :
527+ selected_obj = random .choice (available )
528+ selected_cat_objectives .append (selected_obj )
529+ selected_ids .add (get_objective_id (selected_obj ))
530+ remaining -= 1
531+ idx += 1
532+ # Prevent infinite loop if we run out of unique objectives
533+ if idx > len (subtype_list ) * MAX_SAMPLING_ITERATIONS_MULTIPLIER :
534+ break
535+
536+ self .logger .info (f"Sampled { len (selected_cat_objectives )} objectives across { num_subtypes } risk subtypes" )
446537 else :
447- selected_cat_objectives = custom_objectives
448- self .logger .info (f"Using all { len (custom_objectives )} available objectives for { risk_cat_value } " )
538+ # No risk subtypes - use num_objectives_with_subtypes for sampling
539+ if len (custom_objectives ) > num_objectives_with_subtypes :
540+ selected_cat_objectives = random .sample (custom_objectives , num_objectives_with_subtypes )
541+ self .logger .info (
542+ f"Sampled { num_objectives_with_subtypes } objectives from { len (custom_objectives )} available for { risk_cat_value } "
543+ )
544+ else :
545+ selected_cat_objectives = custom_objectives
546+ self .logger .info (f"Using all { len (custom_objectives )} available objectives for { risk_cat_value } " )
449547 target_type_str = "agent" if is_agent_target else "model" if is_agent_target is not None else None
450548 # Handle jailbreak strategy - need to apply jailbreak prefixes to messages
451549 if strategy == "jailbreak" :
@@ -456,17 +554,8 @@ async def _get_custom_attack_objectives(
456554 # Extract content from selected objectives
457555 selected_prompts = []
458556 for obj in selected_cat_objectives :
459- risk_subtype = None
460557 # Extract risk-subtype from target_harms if present
461- target_harms = obj .get ("metadata" , {}).get ("target_harms" , [])
462- if target_harms and isinstance (target_harms , list ):
463- for harm in target_harms :
464- if isinstance (harm , dict ) and "risk-subtype" in harm :
465- subtype_value = harm .get ("risk-subtype" )
466- # Only store non-empty risk-subtype values
467- if subtype_value and subtype_value .strip ():
468- risk_subtype = subtype_value
469- break # Use the first non-empty risk-subtype found
558+ risk_subtype = extract_risk_subtype (obj )
470559
471560 if "messages" in obj and len (obj ["messages" ]) > 0 :
472561 message = obj ["messages" ][0 ]
@@ -494,6 +583,7 @@ async def _get_rai_attack_objectives(
494583 baseline_key : tuple ,
495584 current_key : tuple ,
496585 num_objectives : int ,
586+ num_objectives_with_subtypes : int ,
497587 is_agent_target : Optional [bool ] = None ,
498588 client_id : Optional [str ] = None ,
499589 ) -> List [str ]:
@@ -533,9 +623,8 @@ async def _get_rai_attack_objectives(
533623 objectives_response = await self ._apply_xpia_prompts (objectives_response , target_type_str )
534624
535625 except Exception as e :
536- self .logger .error (f"Error calling get_attack_objectives: { str (e )} " )
537- self .logger .warning ("API call failed, returning empty objectives list" )
538- return []
626+ self .logger .warning (f"Error calling get_attack_objectives: { str (e )} " )
627+ objectives_response = {}
539628
540629 # Check if the response is valid
541630 if not objectives_response or (
@@ -585,9 +674,9 @@ async def _get_rai_attack_objectives(
585674 self .logger .warning ("Empty or invalid response, returning empty list" )
586675 return []
587676
588- # Filter and select objectives
677+ # Filter and select objectives using num_objectives_with_subtypes
589678 selected_cat_objectives = self ._filter_and_select_objectives (
590- objectives_response , strategy , baseline_objectives_exist , baseline_key , num_objectives
679+ objectives_response , strategy , baseline_objectives_exist , baseline_key , num_objectives_with_subtypes
591680 )
592681
593682 # Extract content and cache
@@ -845,6 +934,12 @@ def _filter_and_select_objectives(
845934 # This is the baseline strategy or we don't have baseline objectives yet
846935 self .logger .debug (f"Using random selection for { strategy } strategy" )
847936 selected_cat_objectives = random .sample (objectives_response , min (num_objectives , len (objectives_response )))
937+ selection_msg = (
938+ f"Selected { len (selected_cat_objectives )} objectives using num_objectives={ num_objectives } "
939+ f"(available: { len (objectives_response )} )"
940+ )
941+ self .logger .info (selection_msg )
942+ tqdm .write (f"[INFO] { selection_msg } " )
848943
849944 if len (selected_cat_objectives ) < num_objectives :
850945 self .logger .warning (
@@ -857,16 +952,7 @@ def _extract_objective_content(self, selected_objectives: List) -> List[str]:
857952 """Extract content from selected objectives and build prompt-to-context mapping."""
858953 selected_prompts = []
859954 for obj in selected_objectives :
860- risk_subtype = None
861- # Extract risk-subtype from target_harms if present
862- target_harms = obj .get ("metadata" , {}).get ("target_harms" , [])
863- if target_harms and isinstance (target_harms , list ):
864- for harm in target_harms :
865- if isinstance (harm , dict ) and "risk-subtype" in harm :
866- subtype_value = harm .get ("risk-subtype" )
867- if subtype_value :
868- risk_subtype = subtype_value
869- break
955+ risk_subtype = extract_risk_subtype (obj )
870956 if "messages" in obj and len (obj ["messages" ]) > 0 :
871957 message = obj ["messages" ][0 ]
872958 if isinstance (message , dict ) and "content" in message :
@@ -953,20 +1039,9 @@ def _cache_attack_objectives(
9531039 # Process list format and organize by category for caching
9541040 for obj in selected_objectives :
9551041 obj_id = obj .get ("id" , f"obj-{ uuid .uuid4 ()} " )
956- target_harms = obj .get ("metadata" , {}).get ("target_harms" , [])
9571042 content = ""
9581043 context = ""
959- risk_subtype = None
960-
961- # Extract risk-subtype from target_harms if present
962- if target_harms and isinstance (target_harms , list ):
963- for harm in target_harms :
964- if isinstance (harm , dict ) and "risk-subtype" in harm :
965- subtype_value = harm .get ("risk-subtype" )
966- # Only store non-empty risk-subtype values
967- if subtype_value :
968- risk_subtype = subtype_value
969- break # Use the first non-empty risk-subtype found
1044+ risk_subtype = extract_risk_subtype (obj )
9701045
9711046 if "messages" in obj and len (obj ["messages" ]) > 0 :
9721047
@@ -1400,6 +1475,19 @@ async def _fetch_all_objectives(
14001475 log_section_header (self .logger , "Fetching attack objectives" )
14011476 all_objectives = {}
14021477
1478+ # Calculate and log num_objectives_with_subtypes once globally
1479+ num_objectives = self .attack_objective_generator .num_objectives
1480+ max_num_subtypes = max ((RISK_TO_NUM_SUBTYPE_MAP .get (rc , 0 ) for rc in self .risk_categories ), default = 0 )
1481+ num_objectives_with_subtypes = max (num_objectives , max_num_subtypes )
1482+
1483+ if num_objectives_with_subtypes != num_objectives :
1484+ warning_msg = (
1485+ f"Using { num_objectives_with_subtypes } objectives per risk category instead of requested { num_objectives } "
1486+ f"to ensure adequate coverage of { max_num_subtypes } subtypes"
1487+ )
1488+ self .logger .warning (warning_msg )
1489+ tqdm .write (f"[WARNING] { warning_msg } " )
1490+
14031491 # First fetch baseline objectives for all risk categories
14041492 self .logger .info ("Fetching baseline objectives for all risk categories" )
14051493 for risk_category in self .risk_categories :
@@ -1413,9 +1501,10 @@ async def _fetch_all_objectives(
14131501 if "baseline" not in all_objectives :
14141502 all_objectives ["baseline" ] = {}
14151503 all_objectives ["baseline" ][risk_category .value ] = baseline_objectives
1416- tqdm .write (
1417- f"📝 Fetched baseline objectives for { risk_category .value } : { len (baseline_objectives )} objectives"
1418- )
1504+ status_msg = f"📝 Fetched baseline objectives for { risk_category .value } : { len (baseline_objectives )} /{ num_objectives_with_subtypes } objectives"
1505+ if len (baseline_objectives ) < num_objectives_with_subtypes :
1506+ status_msg += f" (⚠️ fewer than expected)"
1507+ tqdm .write (status_msg )
14191508
14201509 # Then fetch objectives for other strategies
14211510 strategy_count = len (flattened_attack_strategies )
0 commit comments