Skip to content

Commit bc802d1

Browse files
Ivan Evtimov and facebook-github-bot
authored and committed
Skip tasks with insufficient samples for pass@k instead of raising exception (#113)
Summary: When computing pass@k metrics with k > 1, tasks that have fewer than k samples are now gracefully skipped with a warning log message instead of raising a ValueError that would terminate the entire results processing.

Changes:
- Replace ValueError with a warning log when n_samples < k for a task
- Add logging module import and logger instance
- Collect skipped groups and log them with full context (dataset, agent, attack, task_id, and sample count)
- Add check for empty DataFrame after filtering in aggregate_results
- Update docstrings to reflect new behavior (Note instead of Raises)
- Also includes refactoring: remove job_name from group_cols to allow aggregating across multiple runs of the same experiment
- Add generic variant_name support alongside legacy template_short_name

Differential Revision: D92393526
Pulled By: evtimovi
1 parent 9389c9a commit bc802d1

File tree

1 file changed

+36
-14
lines changed

1 file changed

+36
-14
lines changed

src/prompt_siren/results.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import itertools
1111
import json
12+
import logging
1213
import sys
1314
from enum import auto
1415
from pathlib import Path
@@ -51,6 +52,8 @@ class Format(StrEnum):
5152
# Note: job_name is not included to allow grouping across jobs with the same agent/attack config
5253
_ALL_GROUP_COLS = ["dataset", "agent_type", "agent_name", "attack_type"]
5354

55+
logger = logging.getLogger(__name__)
56+
5457

5558
def estimate_pass_at_k(num_samples: int | list[int], num_correct: list[int], k: int) -> np.ndarray:
5659
"""Estimates pass@k of each problem and returns them in an array.
@@ -117,11 +120,12 @@ def _parse_index_entry(line: str, job_config: JobConfig) -> dict[str, Any]:
117120
row["attack_type"] = attack_type
118121
row["attack_config"] = attack_config
119122

120-
# For template_string attacks, append the template_short_name
121-
if attack_type == "template_string" and attack_config:
122-
template_short_name = attack_config.get("template_short_name")
123-
if template_short_name:
124-
row["attack_type"] = f"template_string_{template_short_name}"
123+
# Check for variant_name (generic) or template_short_name (legacy for template_string)
124+
variant_name = attack_config.get("variant_name") if attack_config else None
125+
if not variant_name and attack_type == "template_string" and attack_config:
126+
variant_name = attack_config.get("template_short_name")
127+
if variant_name:
128+
row["attack_type"] = f"{attack_type}_{variant_name}"
125129
else:
126130
row["attack_type"] = "benign"
127131
row["attack_config"] = None
@@ -225,8 +229,9 @@ def _group_by_task(df: pd.DataFrame, k: int = 1) -> pd.DataFrame:
225229
return df
226230

227231
# Group by configuration and task
228-
# Include dataset_suite and job_name to disambiguate tasks from different jobs
229-
group_cols = [*_ALL_GROUP_COLS, "dataset_suite", "job_name", "task_id"]
232+
# Note: job_name is NOT included to allow aggregating across multiple runs
233+
# of the same experiment (e.g., for pass@k computation)
234+
group_cols = [*_ALL_GROUP_COLS, "dataset_suite", "task_id"]
230235

231236
if k == 1:
232237
# Original behavior: average across timestamps
@@ -241,15 +246,17 @@ def _group_by_task(df: pd.DataFrame, k: int = 1) -> pd.DataFrame:
241246

242247
# For k > 1: compute pass@k metric
243248
results = []
249+
skipped_groups: list[dict[str, Any]] = []
244250
for group_key, group in df.groupby(group_cols):
245251
n_samples = len(group)
246252

247-
# Error if we don't have enough samples
253+
# Skip groups with insufficient samples and log a warning
248254
if n_samples < k:
249-
task_id = group["task_id"].iloc[0]
250-
raise ValueError(
251-
f"Task '{task_id}' has only {n_samples} samples but k={k}. Need at least k samples to compute pass@{k}."
252-
)
255+
key_tuple = group_key if isinstance(group_key, tuple) else (group_key,)
256+
group_info = dict(zip(group_cols, key_tuple, strict=True))
257+
group_info["n_samples"] = n_samples
258+
skipped_groups.append(group_info)
259+
continue
253260

254261
# Count number of correct samples (score = 1.0)
255262
n_benign_correct = (group["benign_score"] == 1.0).sum()
@@ -272,6 +279,17 @@ def _group_by_task(df: pd.DataFrame, k: int = 1) -> pd.DataFrame:
272279
result_row["n_samples"] = n_samples
273280
results.append(result_row)
274281

282+
# Log warnings for skipped groups
283+
if skipped_groups:
284+
for group_info in skipped_groups:
285+
n_samples = group_info.pop("n_samples")
286+
# Format group identifiers: dataset, agent_type, agent_name, attack_type, task_id
287+
group_str = ", ".join(f"{key}={value}" for key, value in group_info.items())
288+
logger.warning(
289+
f"Skipping group ({group_str}): has only {n_samples} samples but k={k}. "
290+
f"Need at least k samples to compute pass@{k}."
291+
)
292+
275293
return pd.DataFrame(results)
276294

277295

@@ -306,8 +324,8 @@ def aggregate_results(
306324
benign_pass@k, attack_pass@k, n_tasks, avg_n_samples
307325
(aggregates across dataset_suite's and job_name variations)
308326
309-
Raises:
310-
ValueError: If any task has fewer than k samples when k > 1
327+
Note:
328+
Groups with fewer than k samples are excluded and a warning is logged.
311329
"""
312330
# Convert single k to list for uniform handling
313331
k_values = [k] if isinstance(k, int) else k
@@ -341,6 +359,10 @@ def aggregate_results(
341359
# Stage 1: Always group by task (computing pass@k)
342360
df = _group_by_task(df, k=k_value)
343361

362+
# If all groups were filtered out due to insufficient samples, return empty DataFrame
363+
if df.empty:
364+
return pd.DataFrame()
365+
344366
# Determine score column names based on k
345367
benign_col = f"benign_pass@{k_value}"
346368
attack_col = f"attack_pass@{k_value}"

0 commit comments

Comments (0)