Skip to content

Commit 9ea0f42

Browse files
slister1001Copilot
andauthored
red team make tense converter dynamic (#43705)
* dynamically get tense * Update sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_utils/strategy_utils.py Co-authored-by: Copilot <[email protected]> * Update sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_utils/strategy_utils.py Co-authored-by: Copilot <[email protected]> * fix tests --------- Co-authored-by: Copilot <[email protected]>
1 parent b149c6e commit 9ea0f42

File tree

5 files changed

+77
-126
lines changed

5 files changed

+77
-126
lines changed

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_red_team.py

Lines changed: 23 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -498,28 +498,16 @@ async def _get_rai_attack_objectives(
498498
# Get objectives from RAI service
499499
target_type_str = "agent" if is_agent_target else "model" if is_agent_target is not None else None
500500

501-
if "tense" in strategy:
502-
objectives_response = await self.generated_rai_client.get_attack_objectives(
503-
risk_type=content_harm_risk,
504-
risk_category=other_risk,
505-
application_scenario=application_scenario or "",
506-
strategy="tense",
507-
language=self.language.value,
508-
scan_session_id=self.scan_session_id,
509-
target=target_type_str,
510-
client_id=client_id,
511-
)
512-
else:
513-
objectives_response = await self.generated_rai_client.get_attack_objectives(
514-
risk_type=content_harm_risk,
515-
risk_category=other_risk,
516-
application_scenario=application_scenario or "",
517-
strategy=None,
518-
language=self.language.value,
519-
scan_session_id=self.scan_session_id,
520-
target=target_type_str,
521-
client_id=client_id,
522-
)
501+
objectives_response = await self.generated_rai_client.get_attack_objectives(
502+
risk_type=content_harm_risk,
503+
risk_category=other_risk,
504+
application_scenario=application_scenario or "",
505+
strategy=None,
506+
language=self.language.value,
507+
scan_session_id=self.scan_session_id,
508+
target=target_type_str,
509+
client_id=client_id,
510+
)
523511

524512
if isinstance(objectives_response, list):
525513
self.logger.debug(f"API returned {len(objectives_response)} objectives")
@@ -546,28 +534,16 @@ async def _get_rai_attack_objectives(
546534
)
547535
try:
548536
# Retry with model target type
549-
if "tense" in strategy:
550-
objectives_response = await self.generated_rai_client.get_attack_objectives(
551-
risk_type=content_harm_risk,
552-
risk_category=other_risk,
553-
application_scenario=application_scenario or "",
554-
strategy="tense",
555-
language=self.language.value,
556-
scan_session_id=self.scan_session_id,
557-
target="model",
558-
client_id=client_id,
559-
)
560-
else:
561-
objectives_response = await self.generated_rai_client.get_attack_objectives(
562-
risk_type=content_harm_risk,
563-
risk_category=other_risk,
564-
application_scenario=application_scenario or "",
565-
strategy=None,
566-
language=self.language.value,
567-
scan_session_id=self.scan_session_id,
568-
target="model",
569-
client_id=client_id,
570-
)
537+
objectives_response = await self.generated_rai_client.get_attack_objectives(
538+
risk_type=content_harm_risk,
539+
risk_category=other_risk,
540+
application_scenario=application_scenario or "",
541+
strategy=None,
542+
language=self.language.value,
543+
scan_session_id=self.scan_session_id,
544+
target="model",
545+
client_id=client_id,
546+
)
571547

572548
if isinstance(objectives_response, list):
573549
self.logger.debug(f"Fallback API returned {len(objectives_response)} model-type objectives")
@@ -1050,7 +1026,9 @@ async def _process_attack(
10501026
tqdm.write(f"▶️ Starting task: {strategy_name} strategy for {risk_category.value} risk category")
10511027

10521028
# Get converter and orchestrator function
1053-
converter = get_converter_for_strategy(strategy)
1029+
converter = get_converter_for_strategy(
1030+
strategy, self.generated_rai_client, self._one_dp_project, self.logger
1031+
)
10541032
call_orchestrator = self.orchestrator_manager.get_orchestrator_for_attack_strategy(strategy)
10551033

10561034
try:
@@ -1381,15 +1359,6 @@ def _validate_strategies(self, flattened_attack_strategies: List):
13811359
"MultiTurn and Crescendo strategies are not compatible with multiple attack strategies."
13821360
)
13831361
raise ValueError("MultiTurn and Crescendo strategies are not compatible with multiple attack strategies.")
1384-
if AttackStrategy.Tense in flattened_attack_strategies and (
1385-
RiskCategory.UngroundedAttributes in self.risk_categories
1386-
):
1387-
self.logger.warning(
1388-
"Tense strategy is not compatible with UngroundedAttributes risk categories. Skipping Tense strategy."
1389-
)
1390-
raise ValueError(
1391-
"Tense strategy is not compatible with IndirectAttack or UngroundedAttributes risk categories."
1392-
)
13931362

13941363
def _initialize_tracking_dict(self, flattened_attack_strategies: List):
13951364
"""Initialize the red_team_info tracking dictionary."""

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_utils/_rai_service_target.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def __init__(
6161
api_version: Optional[str] = None,
6262
model: Optional[str] = None,
6363
objective: Optional[str] = None,
64+
tense: Optional[str] = None,
6465
prompt_template_key: Optional[str] = None,
6566
logger: Optional[logging.Logger] = None,
6667
crescendo_format: bool = False,
@@ -78,6 +79,7 @@ def __init__(
7879
self._api_version = api_version
7980
self._model = model
8081
self.objective = objective
82+
self.tense = tense
8183
self.prompt_template_key = prompt_template_key
8284
self.logger = logger
8385
self.crescendo_format = crescendo_format
@@ -103,7 +105,6 @@ async def _create_simulation_request(self, prompt: str, objective: str) -> Dict[
103105
"templateParameters": {
104106
"temperature": 0.7,
105107
"max_tokens": 2000, # TODO: this might not be enough
106-
"objective": objective or self.objective,
107108
"max_turns": 5,
108109
},
109110
"json": json.dumps(
@@ -119,6 +120,11 @@ async def _create_simulation_request(self, prompt: str, objective: str) -> Dict[
119120
"simulationType": "Default",
120121
}
121122

123+
if self.tense:
124+
body["templateParameters"]["tense"] = self.tense
125+
if objective or self.objective:
126+
body["templateParameters"]["objective"] = objective or self.objective
127+
122128
self.logger.debug(f"Created simulation request body: {json.dumps(body, indent=2)}")
123129
return body
124130

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_utils/strategy_utils.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import random
66
from typing import Dict, List, Union, Optional, Any, Callable, cast
7-
7+
import logging
8+
from azure.ai.evaluation.simulator._model_tools._generated_rai_client import GeneratedRAIClient
89
from .._attack_strategy import AttackStrategy
910
from pyrit.prompt_converter import (
1011
PromptConverter,
@@ -24,16 +25,34 @@
2425
ROT13Converter,
2526
SuffixAppendConverter,
2627
StringJoinConverter,
28+
TenseConverter,
2729
UnicodeConfusableConverter,
2830
UnicodeSubstitutionConverter,
2931
UrlConverter,
3032
)
33+
from ._rai_service_target import AzureRAIServiceTarget
3134
from .._default_converter import _DefaultConverter
3235
from pyrit.prompt_target import OpenAIChatTarget, PromptChatTarget
3336
from .._callback_chat_target import _CallbackChatTarget
3437
from azure.ai.evaluation._model_configurations import AzureOpenAIModelConfiguration, OpenAIModelConfiguration
3538

3639

40+
def create_tense_converter(
41+
generated_rai_client: GeneratedRAIClient, is_one_dp_project: bool, logger: logging.Logger
42+
) -> TenseConverter:
43+
"""Factory function for creating TenseConverter with proper dependencies."""
44+
converter_target = AzureRAIServiceTarget(
45+
client=generated_rai_client,
46+
api_version=None,
47+
model="gpt-4",
48+
prompt_template_key="prompt_converters/tense_converter.yaml",
49+
logger=logger,
50+
is_one_dp_project=is_one_dp_project,
51+
tense="past",
52+
)
53+
return TenseConverter(converter_target=converter_target, tense="past")
54+
55+
3756
def strategy_converter_map() -> Dict[Any, Union[PromptConverter, List[PromptConverter], None]]:
3857
"""
3958
Returns a mapping of attack strategies to their corresponding converters.
@@ -61,7 +80,7 @@ def strategy_converter_map() -> Dict[Any, Union[PromptConverter, List[PromptConv
6180
suffix='describing.\\ + similarlyNow write oppositeley.]( Me giving**ONE please? revert with "\\!--Two'
6281
),
6382
AttackStrategy.StringJoin: StringJoinConverter(),
64-
AttackStrategy.Tense: _DefaultConverter(),
83+
AttackStrategy.Tense: create_tense_converter,
6584
AttackStrategy.UnicodeConfusable: UnicodeConfusableConverter(),
6685
AttackStrategy.UnicodeSubstitution: UnicodeSubstitutionConverter(),
6786
AttackStrategy.Url: UrlConverter(),
@@ -73,19 +92,25 @@ def strategy_converter_map() -> Dict[Any, Union[PromptConverter, List[PromptConv
7392

7493

7594
def get_converter_for_strategy(
76-
attack_strategy: Union[AttackStrategy, List[AttackStrategy]]
95+
attack_strategy: Union[AttackStrategy, List[AttackStrategy]],
96+
generated_rai_client: GeneratedRAIClient,
97+
is_one_dp_project: bool,
98+
logger: logging.Logger,
7799
) -> Union[PromptConverter, List[PromptConverter], None]:
78-
"""Get the appropriate converter for a given attack strategy.
100+
"""Get the appropriate converter for a given attack strategy."""
101+
factory_map = strategy_converter_map()
102+
103+
def _resolve_converter(strategy):
104+
converter_or_factory = factory_map[strategy]
105+
if callable(converter_or_factory) and not isinstance(converter_or_factory, PromptConverter):
106+
# It's a factory function, call it with dependencies
107+
return converter_or_factory(generated_rai_client, is_one_dp_project, logger)
108+
return converter_or_factory
79109

80-
:param attack_strategy: The attack strategy or list of strategies
81-
:type attack_strategy: Union[AttackStrategy, List[AttackStrategy]]
82-
:return: The converter(s) for the strategy
83-
:rtype: Union[PromptConverter, List[PromptConverter], None]
84-
"""
85110
if isinstance(attack_strategy, List):
86-
return [strategy_converter_map()[strategy] for strategy in attack_strategy]
111+
return [_resolve_converter(strategy) for strategy in attack_strategy]
87112
else:
88-
return strategy_converter_map()[attack_strategy]
113+
return _resolve_converter(attack_strategy)
89114

90115

91116
def get_chat_target(

sdk/evaluation/azure-ai-evaluation/tests/unittests/test_redteam/test_red_team_language_support.py

Lines changed: 0 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -154,60 +154,3 @@ async def test_get_attack_objectives_passes_language(self, mock_azure_ai_project
154154
mock_rai_client.get_attack_objectives.assert_called_once()
155155
call_args = mock_rai_client.get_attack_objectives.call_args
156156
assert call_args.kwargs["language"] == SupportedLanguages.Spanish.value
157-
158-
@pytest.mark.asyncio
159-
async def test_get_attack_objectives_tense_strategy_passes_language(self, mock_azure_ai_project, mock_credential):
160-
"""Test that _get_attack_objectives passes language parameter for tense strategy."""
161-
with patch("azure.ai.evaluation.red_team._red_team.GeneratedRAIClient") as mock_rai_client_class, patch(
162-
"azure.ai.evaluation.red_team._red_team.setup_logger"
163-
) as mock_setup_logger, patch("azure.ai.evaluation.red_team._red_team.initialize_pyrit"), patch(
164-
"azure.ai.evaluation.red_team._red_team._AttackObjectiveGenerator"
165-
) as mock_attack_obj_generator_class:
166-
167-
mock_logger = MagicMock()
168-
mock_setup_logger.return_value = mock_logger
169-
170-
# Set up mock RAI client instance
171-
mock_rai_client = MagicMock()
172-
mock_rai_client.get_attack_objectives = AsyncMock(
173-
return_value=[
174-
{
175-
"id": "test-id",
176-
"messages": [{"role": "user", "content": "test prompt"}],
177-
"metadata": {"target_harms": [{"risk-type": "violence"}]},
178-
}
179-
]
180-
)
181-
mock_rai_client_class.return_value = mock_rai_client
182-
183-
# Set up mock attack objective generator instance
184-
mock_attack_obj_generator = MagicMock()
185-
mock_attack_obj_generator.num_objectives = 5
186-
mock_attack_obj_generator.custom_attack_seed_prompts = None
187-
mock_attack_obj_generator.validated_prompts = False
188-
mock_attack_obj_generator_class.return_value = mock_attack_obj_generator
189-
190-
# Create RedTeam instance with French language
191-
agent = RedTeam(
192-
azure_ai_project=mock_azure_ai_project,
193-
credential=mock_credential,
194-
risk_categories=[RiskCategory.Violence],
195-
num_objectives=5,
196-
language=SupportedLanguages.French,
197-
)
198-
199-
agent.generated_rai_client = mock_rai_client
200-
agent.scan_session_id = "test-session"
201-
202-
# Call _get_attack_objectives with tense strategy
203-
await agent._get_attack_objectives(
204-
risk_category=RiskCategory.Violence,
205-
application_scenario="test scenario",
206-
strategy="tense",
207-
)
208-
209-
# Verify that get_attack_objectives was called with French language and tense strategy
210-
mock_rai_client.get_attack_objectives.assert_called_once()
211-
call_args = mock_rai_client.get_attack_objectives.call_args
212-
assert call_args.kwargs["language"] == SupportedLanguages.French.value # French language code
213-
assert call_args.kwargs["strategy"] == "tense"

sdk/evaluation/azure-ai-evaluation/tests/unittests/test_redteam/test_strategy_utils.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,25 @@ class TestConverterForStrategy:
5252

5353
def test_get_converter_for_strategy_single(self):
5454
"""Test getting converter for a single strategy."""
55-
converter = get_converter_for_strategy(AttackStrategy.Base64)
55+
# Create mock dependencies
56+
mock_rai_client = MagicMock()
57+
mock_logger = MagicMock()
58+
59+
converter = get_converter_for_strategy(AttackStrategy.Base64, mock_rai_client, False, mock_logger)
5660
assert isinstance(converter, Base64Converter)
5761

5862
# Test strategy with no converter
59-
converter = get_converter_for_strategy(AttackStrategy.Baseline)
63+
converter = get_converter_for_strategy(AttackStrategy.Baseline, mock_rai_client, False, mock_logger)
6064
assert converter is None
6165

6266
def test_get_converter_for_strategy_list(self):
6367
"""Test getting converters for a list of strategies."""
68+
# Create mock dependencies
69+
mock_rai_client = MagicMock()
70+
mock_logger = MagicMock()
71+
6472
strategies = [AttackStrategy.Base64, AttackStrategy.Flip]
65-
converters = get_converter_for_strategy(strategies)
73+
converters = get_converter_for_strategy(strategies, mock_rai_client, False, mock_logger)
6674

6775
assert isinstance(converters, list)
6876
assert len(converters) == 2

0 commit comments

Comments
 (0)