Skip to content
5 changes: 2 additions & 3 deletions pyrit/datasets/adv_bench_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

from pyrit.common.path import DATASETS_PATH
from pyrit.datasets.dataset_helper import fetch_examples
from pyrit.models import SeedPromptDataset
from pyrit.models.seed_prompt import SeedPrompt
from pyrit.models import HarmCategory, SeedPrompt, SeedPromptDataset


def fetch_adv_bench_dataset(
Expand Down Expand Up @@ -87,7 +86,7 @@ def fetch_adv_bench_dataset(
data_type="text",
name="AdvBench Dataset [Extended]",
dataset_name="AdvBench Dataset",
harm_categories=item["main_categories"] + item["sub_categories"],
harm_categories=[HarmCategory.parse(cat) for cat in item["main_categories"] + item["sub_categories"]],
description="""AdvBench is a set of 520 harmful behaviors formulated as instructions. This dataset
has been extended to include harm categories for better filtering and analysis. The adversary's goal
is instead to find a single attack string that will cause the model to generate any response that
Expand Down
12 changes: 7 additions & 5 deletions pyrit/datasets/aya_redteaming_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from typing import List, Literal, Optional

from pyrit.datasets.dataset_helper import fetch_examples
from pyrit.models import SeedPromptDataset
from pyrit.models.seed_prompt import SeedPrompt
from pyrit.models import HarmCategory, SeedPrompt, SeedPromptDataset


def fetch_aya_redteaming_dataset(
Expand Down Expand Up @@ -75,19 +74,22 @@ def fetch_aya_redteaming_dataset(
data_home=data_home,
)

parsed_filter_categories = [HarmCategory.parse(c) for c in harm_categories] if harm_categories else None

seed_prompts = []

for example in examples:
categories = ast.literal_eval(example["harm_category"])
if harm_categories is None or any(cat in categories for cat in harm_categories):
raw_categories = ast.literal_eval(example["harm_category"])
parsed_categories = [HarmCategory.parse(c) for c in raw_categories]
if parsed_filter_categories is None or any(cat in parsed_categories for cat in parsed_filter_categories):
if harm_scope is None or example["global_or_local"] == harm_scope:
seed_prompts.append(
SeedPrompt(
value=example["prompt"],
data_type="text",
name="Aya Red-teaming Examples",
dataset_name="Aya Red-teaming Examples",
harm_categories=categories,
harm_categories=parsed_categories,
source="https://huggingface.co/datasets/CohereForAI/aya_redteaming",
)
)
Expand Down
5 changes: 2 additions & 3 deletions pyrit/datasets/ccp_sensitive_prompts_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

from datasets import load_dataset

from pyrit.models import SeedPromptDataset
from pyrit.models.seed_prompt import SeedPrompt
from pyrit.models import HarmCategory, SeedPrompt, SeedPromptDataset


def fetch_ccp_sensitive_prompts_dataset() -> SeedPromptDataset:
Expand Down Expand Up @@ -32,7 +31,7 @@ def fetch_ccp_sensitive_prompts_dataset() -> SeedPromptDataset:
data_type="text",
name="",
dataset_name="CCP-sensitive-prompts",
harm_categories=[row["subject"]],
harm_categories=[HarmCategory.parse(row["subject"])],
description=("Prompts covering topics sensitive to the CCP."),
groups=["promptfoo"],
source="https://huggingface.co/datasets/promptfoo/CCP-sensitive-prompts",
Expand Down
5 changes: 2 additions & 3 deletions pyrit/datasets/darkbench_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

from datasets import load_dataset

from pyrit.models import SeedPromptDataset
from pyrit.models.seed_prompt import SeedPrompt
from pyrit.models import HarmCategory, SeedPrompt, SeedPromptDataset


def fetch_darkbench_dataset() -> SeedPromptDataset:
Expand All @@ -28,7 +27,7 @@ def fetch_darkbench_dataset() -> SeedPromptDataset:
data_type="text",
name="",
dataset_name="DarkBench",
harm_categories=[item["Deceptive Pattern"]],
harm_categories=[HarmCategory.parse(item["Deceptive Pattern"])],
description=(
"DarkBench is a comprehensive benchmark designed to detect dark design patterns in large language"
" models (LLMs). Dark patterns are manipulative techniques that influence user behavior, often against"
Expand Down
5 changes: 2 additions & 3 deletions pyrit/datasets/decoding_trust_stereotypes_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

from pyrit.common.path import DATASETS_PATH
from pyrit.datasets.dataset_helper import FILE_TYPE_HANDLERS, fetch_examples
from pyrit.models import SeedPromptDataset
from pyrit.models.seed_prompt import SeedPrompt
from pyrit.models import HarmCategory, SeedPrompt, SeedPromptDataset


def fetch_decoding_trust_stereotypes_dataset(
Expand Down Expand Up @@ -105,7 +104,7 @@ def fetch_decoding_trust_stereotypes_dataset(
data_type="text",
name="DecodingTrust Stereotypes Examples",
dataset_name="DecodingTrust Stereotypes Examples",
harm_categories=["stereotypes"],
harm_categories=[HarmCategory.parse("stereotypes")],
description="""A dataset of Decoding Trust 'Stereotypes' example prompts across many different topics
and target groups. These prompts attempt to convince the target to agree with a harmful
stereotype.""",
Expand Down
5 changes: 2 additions & 3 deletions pyrit/datasets/equitymedqa_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

from datasets import load_dataset

from pyrit.models import SeedPromptDataset
from pyrit.models.seed_prompt import SeedPrompt
from pyrit.models import HarmCategory, SeedPrompt, SeedPromptDataset

# These are the available subsets of the EquityMedQA dataset
DATA_SUBSETS: list[str] = [
Expand Down Expand Up @@ -89,7 +88,7 @@ def fetch_equitymedqa_dataset_unique_values(
name="katielink/EquityMedQA",
dataset_name="katielink/EquityMedQA",
description="This dataset contains prompts used to assess medical biases in AI systems",
harm_categories=["health_bias"],
harm_categories=[HarmCategory.parse("health_bias")],
source="https://huggingface.co/datasets/katielink/EquityMedQA",
)
for prompt in prompts
Expand Down
5 changes: 2 additions & 3 deletions pyrit/datasets/forbidden_questions_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

from datasets import load_dataset

from pyrit.models import SeedPromptDataset
from pyrit.models.seed_prompt import SeedPrompt
from pyrit.models import HarmCategory, SeedPrompt, SeedPromptDataset


def fetch_forbidden_questions_dataset() -> SeedPromptDataset:
Expand All @@ -28,7 +27,7 @@ def fetch_forbidden_questions_dataset() -> SeedPromptDataset:
name="TrustAIRLab/forbidden_question_set",
dataset_name="TrustAIRLab/forbidden_question_set",
authors=authors,
harm_categories=item["content_policy_name"],
harm_categories=[HarmCategory.parse(item["content_policy_name"])],
source="https://huggingface.co/datasets/TrustAIRLab/forbidden_question_set",
description="""This is the Forbidden Question Set dataset proposed in the ACM CCS 2024 paper
"Do Anything Now'': Characterizing and Evaluating In-The-Wild Jailbreak Prompts on Large Language Models.
Expand Down
6 changes: 2 additions & 4 deletions pyrit/datasets/multilingual_vulnerability_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@

import pandas as pd

from pyrit.models import SeedPromptDataset
from pyrit.models.seed_prompt import SeedPrompt

from pyrit.models import HarmCategory, SeedPrompt, SeedPromptDataset

def fetch_multilingual_vulnerability_dataset() -> SeedPromptDataset:
"""
Expand All @@ -24,7 +22,7 @@ def fetch_multilingual_vulnerability_dataset() -> SeedPromptDataset:
data_type="text",
name=str(row["id"]),
dataset_name="Multilingual-Vulnerability",
harm_categories=[row["type"]],
harm_categories=[HarmCategory.parse(row["type"])],
description="Dataset from 'A Framework to Assess Multilingual Vulnerabilities of LLMs'. "
"Multilingual prompts demonstrating LLM vulnerabilities, labeled by type. "
"Paper: https://arxiv.org/pdf/2503.13081",
Expand Down
20 changes: 13 additions & 7 deletions pyrit/datasets/red_team_social_bias_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

from datasets import load_dataset

from pyrit.models import SeedPromptDataset
from pyrit.models.seed_prompt import SeedPrompt
from pyrit.models import HarmCategory, SeedPrompt, SeedPromptDataset


def fetch_red_team_social_bias_dataset() -> SeedPromptDataset:
Expand Down Expand Up @@ -60,14 +59,21 @@ def fetch_red_team_social_bias_dataset() -> SeedPromptDataset:
if prompt_type is None:
continue

raw_categories = item.get("categorization", [])
if isinstance(raw_categories, str):
raw_categories = [raw_categories]

harm_categories = []
for cat in raw_categories:
try:
harm_categories.append(HarmCategory.parse(cat))
except Exception:
harm_categories.append(HarmCategory.OTHER)

# Dictionary of metadata for the current prompt
prompt_metadata = {
**common_metadata,
"harm_categories": (
[item["categorization"]]
if not isinstance(item.get("categorization"), list)
else item.get("categorization", [])
),
"harm_categories": harm_categories,
"groups": [item.get("organization", "")],
"metadata": {
"prompt_type": prompt_type,
Expand Down
5 changes: 2 additions & 3 deletions pyrit/datasets/seclists_bias_testing_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
import pycountry

from pyrit.datasets.dataset_helper import FILE_TYPE_HANDLERS, fetch_examples
from pyrit.models import SeedPromptDataset
from pyrit.models.seed_prompt import SeedPrompt
from pyrit.models import HarmCategory, SeedPrompt, SeedPromptDataset


def fetch_seclists_bias_testing_dataset(
Expand Down Expand Up @@ -95,7 +94,7 @@ def fetch_seclists_bias_testing_dataset(
data_type="text",
name="SecLists Bias Testing Examples",
dataset_name="SecLists Bias Testing Examples",
harm_categories=["bias_testing"],
harm_categories=[HarmCategory.REPRESENTATIONAL],
description="A dataset of SecLists AI LLM Bias Testing examples with placeholders replaced.",
)
for example in filled_examples
Expand Down
5 changes: 2 additions & 3 deletions pyrit/datasets/sosbench_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

from datasets import load_dataset

from pyrit.models import SeedPromptDataset
from pyrit.models.seed_prompt import SeedPrompt
from pyrit.models import HarmCategory, SeedPrompt, SeedPromptDataset


def fetch_sosbench_dataset() -> SeedPromptDataset:
Expand All @@ -27,7 +26,7 @@ def fetch_sosbench_dataset() -> SeedPromptDataset:
data_type="text",
name="",
dataset_name="SOSBench",
harm_categories=[item["subject"]],
harm_categories=[HarmCategory.parse(item["subject"])],
description=(
"SOSBench is a regulation-grounded, hazard-focused benchmark encompassing "
"six high-risk scientific domains: chemistry, biology, medicine, pharmacology, "
Expand Down
37 changes: 21 additions & 16 deletions pyrit/datasets/xstest_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
from typing import Literal, Optional

from pyrit.datasets.dataset_helper import FILE_TYPE_HANDLERS, fetch_examples
from pyrit.models import SeedPromptDataset
from pyrit.models.seed_prompt import SeedPrompt
from pyrit.models import HarmCategory, SeedPrompt, SeedPromptDataset


def fetch_xstest_dataset(
Expand Down Expand Up @@ -41,21 +40,27 @@ def fetch_xstest_dataset(
# Fetch the examples using the provided `fetch_examples` function
examples = fetch_examples(source, source_type, cache, data_home)

# Extract prompts, harm categories, and other relevant data from the fetched examples
prompts = [example["prompt"] for example in examples]
harm_categories = [example["note"] for example in examples]

seed_prompts = [
SeedPrompt(
value=example,
data_type="text",
name="XSTest Examples",
dataset_name="XSTest Examples",
harm_categories=harm_categories,
description="A dataset of XSTest examples containing various categories such as violence, drugs, etc.",
seed_prompts = []

for example in examples:
prompt_text = example["prompt"]
note = example.get("note", "")

try:
harm_category = HarmCategory.parse(note)
except Exception:
harm_category = HarmCategory.OTHER

seed_prompts.append(
SeedPrompt(
value=prompt_text,
data_type="text",
name="XSTest Examples",
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove the `name` argument, or set it to an empty string.

dataset_name="XSTest Examples",
harm_categories=[harm_category],
description="A dataset of XSTest examples containing various categories such as violence, drugs, etc.",
)
)
for example in prompts
]

seed_prompt_dataset = SeedPromptDataset(prompts=seed_prompts)

Expand Down
11 changes: 6 additions & 5 deletions pyrit/memory/memory_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
SeedPromptEntry,
)
from pyrit.models import (
AttackResult,
HarmCategory,
ChatMessage,
DataTypeSerializer,
PromptRequestPiece,
Expand All @@ -43,7 +45,6 @@
group_conversation_request_pieces_by_sequence,
sort_request_pieces,
)
from pyrit.models.attack_result import AttackResult

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -585,7 +586,7 @@ def get_seed_prompts(
value_sha256: Optional[Sequence[str]] = None,
dataset_name: Optional[str] = None,
data_types: Optional[Sequence[str]] = None,
harm_categories: Optional[Sequence[str]] = None,
harm_categories: Optional[Sequence[HarmCategory]] = None,
added_by: Optional[str] = None,
authors: Optional[Sequence[str]] = None,
groups: Optional[Sequence[str]] = None,
Expand All @@ -602,7 +603,7 @@ def get_seed_prompts(
dataset_name (str): The dataset name to match. If None, all dataset names are considered.
data_types (Optional[Sequence[str]], Optional): List of data types to filter seed prompts by
(e.g., text, image_path).
harm_categories (Sequence[str]): A list of harm categories to filter by. If None,
harm_categories (Sequence[HarmCategory]): A list of harm categories to filter by. If None,
all harm categories are considered.
Specifying multiple harm categories returns only prompts that are marked with all harm categories.
added_by (str): The user who added the prompts.
Expand Down Expand Up @@ -794,7 +795,7 @@ def get_seed_prompt_groups(
value_sha256: Optional[Sequence[str]] = None,
dataset_name: Optional[str] = None,
data_types: Optional[Sequence[str]] = None,
harm_categories: Optional[Sequence[str]] = None,
harm_categories: Optional[Sequence[HarmCategory]] = None,
added_by: Optional[str] = None,
authors: Optional[Sequence[str]] = None,
groups: Optional[Sequence[str]] = None,
Expand All @@ -807,7 +808,7 @@ def get_seed_prompt_groups(
dataset_name (Optional[str], Optional): Name of the dataset to filter seed prompts.
data_types (Optional[Sequence[str]], Optional): List of data types to filter seed prompts by
(e.g., text, image_path).
harm_categories (Optional[Sequence[str]], Optional): List of harm categories to filter seed prompts by.
harm_categories (Optional[Sequence[HarmCategory]], Optional): List of harm categories to filter seed prompts by.
added_by (Optional[str], Optional): The user who added the seed prompt groups to filter by.
authors (Optional[Sequence[str]], Optional): List of authors to filter seed prompt groups by.
groups (Optional[Sequence[str]], Optional): List of groups to filter seed prompt groups by.
Expand Down
5 changes: 3 additions & 2 deletions pyrit/memory/memory_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from pyrit.common.utils import to_sha256
from pyrit.models import PromptDataType, PromptRequestPiece, Score, SeedPrompt
from pyrit.models.attack_result import AttackOutcome, AttackResult
from pyrit.models.harm_category import HarmCategory


class Base(DeclarativeBase):
Expand Down Expand Up @@ -281,7 +282,7 @@ class SeedPromptEntry(Base):
value_sha256 (str): The SHA256 hash of the value of the seed prompt data.
data_type (PromptDataType): The data type of the seed prompt.
dataset_name (str): The name of the dataset the seed prompt belongs to.
harm_categories (List[str]): The harm categories associated with the seed prompt.
harm_categories (List[HarmCategory]): The harm categories associated with the seed prompt.
description (str): The description of the seed prompt.
authors (List[str]): The authors of the seed prompt.
groups (List[str]): The groups involved in authoring the seed prompt (if any).
Expand Down Expand Up @@ -310,7 +311,7 @@ class SeedPromptEntry(Base):
data_type: Mapped[PromptDataType] = mapped_column(String, nullable=False)
name = mapped_column(String, nullable=True)
dataset_name = mapped_column(String, nullable=True)
harm_categories: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True)
harm_categories: Mapped[Optional[List[HarmCategory]]] = mapped_column(JSON, nullable=True)
description = mapped_column(String, nullable=True)
authors: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True)
groups: Mapped[Optional[List[str]]] = mapped_column(JSON, nullable=True)
Expand Down
2 changes: 2 additions & 0 deletions pyrit/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
data_serializer_factory,
)
from pyrit.models.embeddings import EmbeddingData, EmbeddingResponse, EmbeddingSupport, EmbeddingUsageInformation
from pyrit.models.harm_category import HarmCategory
from pyrit.models.identifiers import Identifier
from pyrit.models.literals import ChatMessageRole, PromptDataType, PromptResponseError
from pyrit.models.prompt_request_response import (
Expand Down Expand Up @@ -55,6 +56,7 @@
"EmbeddingUsageInformation",
"ErrorDataTypeSerializer",
"group_conversation_request_pieces_by_sequence",
"HarmCategory",
"Identifier",
"ImagePathDataTypeSerializer",
"sort_request_pieces",
Expand Down
Loading
Loading