Skip to content

Commit 585a302

Browse files
authored
redteam bug fixes for model usage computation and auth (#43633)
1 parent 69b0510 commit 585a302

File tree

2 files changed

+26
-10
lines changed

2 files changed

+26
-10
lines changed

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_red_team.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,7 @@ async def _get_attack_objectives(
325325
application_scenario: Optional[str] = None,
326326
strategy: Optional[str] = None,
327327
is_agent_target: Optional[bool] = None,
328+
client_id: Optional[str] = None,
328329
) -> List[str]:
329330
"""Get attack objectives from the RAI client for a specific risk category or from a custom dataset.
330331
@@ -407,6 +408,7 @@ async def _get_attack_objectives(
407408
current_key,
408409
num_objectives,
409410
is_agent_target,
411+
client_id,
410412
)
411413

412414
async def _get_custom_attack_objectives(
@@ -469,6 +471,7 @@ async def _get_rai_attack_objectives(
469471
current_key: tuple,
470472
num_objectives: int,
471473
is_agent_target: Optional[bool] = None,
474+
client_id: Optional[str] = None,
472475
) -> List[str]:
473476
"""Get attack objectives from the RAI service."""
474477
content_harm_risk = None
@@ -495,6 +498,7 @@ async def _get_rai_attack_objectives(
495498
language=self.language.value,
496499
scan_session_id=self.scan_session_id,
497500
target=target_type_str,
501+
client_id=client_id,
498502
)
499503
else:
500504
objectives_response = await self.generated_rai_client.get_attack_objectives(
@@ -505,6 +509,7 @@ async def _get_rai_attack_objectives(
505509
language=self.language.value,
506510
scan_session_id=self.scan_session_id,
507511
target=target_type_str,
512+
client_id=client_id,
508513
)
509514

510515
if isinstance(objectives_response, list):
@@ -539,6 +544,7 @@ async def _get_rai_attack_objectives(
539544
language=self.language.value,
540545
scan_session_id=self.scan_session_id,
541546
target="model",
547+
client_id=client_id,
542548
)
543549
else:
544550
objectives_response = await self.generated_rai_client.get_attack_objectives(
@@ -549,6 +555,7 @@ async def _get_rai_attack_objectives(
549555
language=self.language.value,
550556
scan_session_id=self.scan_session_id,
551557
target="model",
558+
client_id=client_id,
552559
)
553560

554561
if isinstance(objectives_response, list):
@@ -1022,6 +1029,8 @@ async def scan(
10221029
self._app_insights_configuration = _app_insights_configuration
10231030
self.taxonomy_risk_categories = taxonomy_risk_categories or {}
10241031
is_agent_target: Optional[bool] = kwargs.get("is_agent_target", False)
1032+
client_id: Optional[str] = kwargs.get("client_id")
1033+
10251034
with UserAgentSingleton().add_useragent_product(user_agent):
10261035
# Initialize scan
10271036
self._initialize_scan(scan_name, application_scenario)
@@ -1112,7 +1121,7 @@ async def scan(
11121121

11131122
# Fetch attack objectives
11141123
all_objectives = await self._fetch_all_objectives(
1115-
flattened_attack_strategies, application_scenario, is_agent_target
1124+
flattened_attack_strategies, application_scenario, is_agent_target, client_id
11161125
)
11171126

11181127
chat_target = get_chat_target(target)
@@ -1228,7 +1237,11 @@ def _initialize_tracking_dict(self, flattened_attack_strategies: List):
12281237
}
12291238

12301239
async def _fetch_all_objectives(
1231-
self, flattened_attack_strategies: List, application_scenario: str, is_agent_target: bool
1240+
self,
1241+
flattened_attack_strategies: List,
1242+
application_scenario: str,
1243+
is_agent_target: bool,
1244+
client_id: Optional[str] = None,
12321245
) -> Dict:
12331246
"""Fetch all attack objectives for all strategies and risk categories."""
12341247
log_section_header(self.logger, "Fetching attack objectives")
@@ -1242,6 +1255,7 @@ async def _fetch_all_objectives(
12421255
application_scenario=application_scenario,
12431256
strategy="baseline",
12441257
is_agent_target=is_agent_target,
1258+
client_id=client_id,
12451259
)
12461260
if "baseline" not in all_objectives:
12471261
all_objectives["baseline"] = {}
@@ -1266,6 +1280,7 @@ async def _fetch_all_objectives(
12661280
application_scenario=application_scenario,
12671281
strategy=strategy_name,
12681282
is_agent_target=is_agent_target,
1283+
client_id=client_id,
12691284
)
12701285
all_objectives[strategy_name][risk_category.value] = objectives
12711286

sdk/evaluation/azure-ai-evaluation/azure/ai/evaluation/red_team/_result_processor.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,7 +1231,6 @@ def _compute_per_model_usage(output_items: List[Dict[str, Any]]) -> List[Dict[st
12311231
"""
12321232
# Track usage by model name
12331233
model_usage: Dict[str, Dict[str, int]] = {}
1234-
12351234
for item in output_items:
12361235
if not isinstance(item, dict):
12371236
continue
@@ -1254,10 +1253,11 @@ def _compute_per_model_usage(output_items: List[Dict[str, Any]]) -> List[Dict[st
12541253
}
12551254

12561255
model_usage[model_name]["invocation_count"] += 1
1257-
model_usage[model_name]["prompt_tokens"] += usage.get("prompt_tokens", 0)
1258-
model_usage[model_name]["completion_tokens"] += usage.get("completion_tokens", 0)
1259-
model_usage[model_name]["total_tokens"] += usage.get("total_tokens", 0)
1260-
model_usage[model_name]["cached_tokens"] += usage.get("cached_tokens", 0)
1256+
# Convert to int to handle cases where values come as strings
1257+
model_usage[model_name]["prompt_tokens"] += int(usage.get("prompt_tokens", 0) or 0)
1258+
model_usage[model_name]["completion_tokens"] += int(usage.get("completion_tokens", 0) or 0)
1259+
model_usage[model_name]["total_tokens"] += int(usage.get("total_tokens", 0) or 0)
1260+
model_usage[model_name]["cached_tokens"] += int(usage.get("cached_tokens", 0) or 0)
12611261

12621262
# Always aggregate evaluator usage from results (separate from target usage)
12631263
results_list = item.get("results", [])
@@ -1286,9 +1286,10 @@ def _compute_per_model_usage(output_items: List[Dict[str, Any]]) -> List[Dict[st
12861286

12871287
if prompt_tokens or completion_tokens:
12881288
model_usage[model_name]["invocation_count"] += 1
1289-
model_usage[model_name]["prompt_tokens"] += prompt_tokens
1290-
model_usage[model_name]["completion_tokens"] += completion_tokens
1291-
model_usage[model_name]["total_tokens"] += prompt_tokens + completion_tokens
1289+
# Convert to int to handle cases where values come as strings
1290+
model_usage[model_name]["prompt_tokens"] += int(prompt_tokens or 0)
1291+
model_usage[model_name]["completion_tokens"] += int(completion_tokens or 0)
1292+
model_usage[model_name]["total_tokens"] += int(prompt_tokens or 0) + int(completion_tokens or 0)
12921293

12931294
if not model_usage:
12941295
return []

0 commit comments

Comments
 (0)