2 changes: 2 additions & 0 deletions pyrit/datasets/__init__.py
@@ -33,6 +33,7 @@
    fetch_jbb_behaviors_by_harm_category,
    fetch_jbb_behaviors_by_jbb_category,
)
from pyrit.datasets.fetch_jailbreakv_28k_dataset import fetch_jailbreakv_28k_dataset


__all__ = [
@@ -64,4 +65,5 @@
"fetch_jbb_behaviors_dataset",
"fetch_jbb_behaviors_by_harm_category",
"fetch_jbb_behaviors_by_jbb_category",
"fetch_jailbreakv_28k_dataset",
]
116 changes: 116 additions & 0 deletions pyrit/datasets/fetch_jailbreakv_28k_dataset.py
@@ -0,0 +1,116 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
from typing import List, Literal, Optional

from datasets import load_dataset

from pyrit.models import SeedPrompt, SeedPromptDataset

logger = logging.getLogger(__name__)

HarmLiteral = Literal[
"Unethical Behavior",
"Economic Harm",
"Hate Speech",
"Government Decision",
"Physical Harm",
"Fraud",
"Political Sensitivity",
"Malware",
"Illegal Activity",
"Bias",
"Violence",
"Animal Abuse",
"Tailored Unlicensed Advice",
"Privacy Violation",
"Health Consultation",
"Child Abuse Content",
]


def fetch_jailbreakv_28k_dataset(
    *,
    data_home: Optional[str] = None,
    split: Literal["JailBreakV_28K", "mini_JailBreakV_28K"] = "mini_JailBreakV_28K",
    text_field: Literal["jailbreak_query", "redteam_query"] = "redteam_query",
    harm_categories: Optional[List[HarmLiteral]] = None,
) -> SeedPromptDataset:
"""
Fetch examples from the JailBreakV 28k Dataset with optional filtering and create a SeedPromptDataset.
Args:
data_home: Directory used as cache_dir in call to HF to store cached data. Defaults to None.
split (str): The split of the dataset to fetch. Defaults to "mini_JailBreakV_28K".
Options are "JailBreakV_28K" and "mini_JailBreakV_28K".
text_field (str): The field to use as the prompt text. Defaults to "redteam_query".
harm_categories: List of harm categories to filter the examples.
Defaults to None, which means all categories are included.
Otherwise, only prompts with at least one matching category are included.
Returns:
SeedPromptDataset: A SeedPromptDataset containing the filtered examples.
Note:
For more information and access to the original dataset and related materials, visit:
https://huggingface.co/datasets/JailbreakV-28K/JailBreakV-28k/blob/main/README.md \n
Related paper: https://arxiv.org/abs/2404.03027 \n
The dataset license: mit
Warning:
Due to the nature of these prompts, it may be advisable to consult your relevant legal
department before testing them with LLMs to ensure compliance and reduce potential risks.
"""

    source = "JailbreakV-28K/JailBreakV-28k"

    try:
        logger.info(f"Loading JailBreakV-28k dataset from {source}")

        # Normalize the harm categories to match PyRIT harm category conventions
        harm_categories_normalized = (
            None if not harm_categories else [_normalize_policy(policy) for policy in harm_categories]
        )

        # Load the dataset from Hugging Face
        data = load_dataset(source, "JailBreakV_28K", cache_dir=data_home)

        dataset_split = data[split]

        seed_prompts = []

        # Define common metadata that will be used across all seed prompts
        common_metadata = {
            "dataset_name": "JailbreakV-28K",
            "authors": ["Weidi Luo", "Siyuan Ma", "Xiaogeng Liu", "Chaowei Xiao"],
            "description": (
                "Benchmark for Assessing the Robustness of Large Language Models against Jailbreak Attacks."
            ),
            "source": source,
            "data_type": "text",
            "name": "JailBreakV-28K",
        }

        for item in dataset_split:
            policy = _normalize_policy(item.get("policy", ""))
            # Skip the item if the user requested a policy filter and the item's policy does not match
            if harm_categories_normalized and policy not in harm_categories_normalized:
                continue
            seed_prompt = SeedPrompt(
                value=item.get(text_field, ""),
                harm_categories=[policy],
                **common_metadata,  # type: ignore[arg-type]
            )
            seed_prompts.append(seed_prompt)
        seed_prompt_dataset = SeedPromptDataset(prompts=seed_prompts)
        return seed_prompt_dataset

    except Exception as e:
        logger.error(f"Failed to load JailBreakV-28K dataset: {str(e)}")
        raise Exception(f"Error loading JailBreakV-28K dataset: {str(e)}") from e


def _normalize_policy(policy: str) -> str:
    """Normalize a human-readable policy name into a machine-friendly harm category string."""
    return policy.strip().lower().replace(" ", "_").replace("-", "_")
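
For reference, a minimal usage sketch of the new fetcher (it assumes network access to Hugging Face and a current pyrit install; the harm category names below are examples taken from the HarmLiteral list above):

from pyrit.datasets import fetch_jailbreakv_28k_dataset

# Fetch the mini split and keep only prompts tagged as Fraud or Malware.
# Category names are normalized to snake_case on the returned prompts
# (e.g. "Economic Harm" -> "economic_harm").
dataset = fetch_jailbreakv_28k_dataset(
    split="mini_JailBreakV_28K",
    text_field="redteam_query",
    harm_categories=["Fraud", "Malware"],
)
print(len(dataset.prompts), dataset.prompts[0].harm_categories)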
57 changes: 57 additions & 0 deletions tests/unit/datasets/test_fetch_jailbreakv_28k.py
@@ -0,0 +1,57 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from unittest.mock import patch

import pytest

from pyrit.datasets.fetch_jailbreakv_28k_dataset import fetch_jailbreakv_28k_dataset
from pyrit.models import SeedPrompt, SeedPromptDataset


class TestFetchJailbreakv28kDataset:
"""Test suite for the fetch_jailbreakv_28k_dataset function."""

@pytest.mark.parametrize("text_field", [None, "jailbreak_query"])
@pytest.mark.parametrize("harm_categories", [None, ["Economic Harm"]])
@patch("pyrit.datasets.fetch_jailbreakv_28k_dataset.load_dataset")
def test_fetch_jailbreakv_28k_dataset_success(self, mock_load_dataset, text_field, harm_categories):
# Mock dataset response
mock_dataset = {
"mini_JailBreakV_28K": [
{
"redteam_query": "test query 1",
"jailbreak_query": "jailbreak: test query 1",
"policy": "Economic Harm",
},
{
"redteam_query": "test query 2",
"jailbreak_query": "jailbreak: test query 2",
"policy": "Government Decision",
},
{
"redteam_query": "test query 3",
"jailbreak_query": "jailbreak: test query 3",
"policy": "Fraud",
},
]
}
mock_load_dataset.return_value = mock_dataset

# Call the function
result = fetch_jailbreakv_28k_dataset(text_field=text_field, harm_categories=harm_categories)

        # Assertions
        assert isinstance(result, SeedPromptDataset)
        if harm_categories is None:
            assert len(result.prompts) == 3
        else:
            assert len(result.prompts) == 1
            assert result.prompts[0].harm_categories == ["economic_harm"]
        assert all(isinstance(prompt, SeedPrompt) for prompt in result.prompts)
        if text_field == "jailbreak_query":
            assert all("jailbreak" in prompt.value for prompt in result.prompts)
        else:
            assert all("jailbreak" not in prompt.value for prompt in result.prompts)