Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions pyrit/scenario/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Dataset classes for scenario data loading."""

from pyrit.scenario.dataset.load_utils import ScenarioDatasetUtils

__all__ = [
    "ScenarioDatasetUtils",
]
36 changes: 36 additions & 0 deletions pyrit/scenario/dataset/load_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from pathlib import Path
from typing import List
from pyrit.models import SeedDataset
from pyrit.common.path import DATASETS_PATH, SCORER_CONFIG_PATH
from pyrit.datasets.harmbench_dataset import fetch_harmbench_dataset


class ScenarioDatasetUtils:
"""
Set of dataset loading utilities for Scenario class.
"""
@classmethod
def seed_dataset_to_list_str(cls, dataset: Path) -> List[str]:
seed_prompts: List[str] = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm kind of surprised we're using these as plain strings. It loses all the metadata. That means we lose harm categories, for example. How will one query for the results?

seed_prompts.extend(SeedDataset.from_yaml_file(dataset).get_values())
return seed_prompts

@classmethod
def get_seed_dataset(cls, which: str) -> SeedDataset:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

`which` is not a common parameter naming choice; `name` seems preferable.

"""
Get SeedDataset from shorthand string.
Args:
which (str): Which SeedDataset.
Returns:
SeedDataset: Desired dataset.
Raises:
ValueError: If dataset not found.
"""
match which:
case "harmbench":
return fetch_harmbench_dataset()
case _:
raise ValueError(f"Error: unknown dataset `{which}` provided.")
8 changes: 4 additions & 4 deletions pyrit/scenario/scenarios/airt/cyber_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ScenarioCompositeStrategy,
ScenarioStrategy,
)
from pyrit.scenario.dataset import ScenarioDatasetUtils
from pyrit.score import (
SelfAskRefusalScorer,
SelfAskTrueFalseScorer,
Expand Down Expand Up @@ -172,10 +173,9 @@ def _get_default_dataset(self) -> list[str]:
Returns:
list[str]: List of seed prompt strings to be encoded and tested.
"""
seed_prompts: List[str] = []
malware_path = pathlib.Path(DATASETS_PATH) / "seed_prompts"
seed_prompts.extend(SeedDataset.from_yaml_file(malware_path / "malware.prompt").get_values())
return seed_prompts
return ScenarioDatasetUtils.seed_dataset_to_list_str(
pathlib.Path(DATASETS_PATH) / "seed_prompts" / "malware.prompt"
)

async def _get_atomic_attack_from_strategy_async(self, strategy: str) -> AtomicAttack:
"""
Expand Down
16 changes: 9 additions & 7 deletions pyrit/scenario/scenarios/encoding_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
AttackScoringConfig,
)
from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack
from pyrit.models import SeedDataset, SeedGroup
from pyrit.models import SeedGroup
from pyrit.models.seed_prompt import SeedPrompt
from pyrit.prompt_converter import (
AsciiSmugglerConverter,
Expand All @@ -39,6 +39,7 @@
ScenarioCompositeStrategy,
ScenarioStrategy,
)
from pyrit.scenario.dataset import ScenarioDatasetUtils
from pyrit.score import TrueFalseScorer
from pyrit.score.true_false.decoding_scorer import DecodingScorer

Expand Down Expand Up @@ -171,12 +172,13 @@ def _get_default_dataset(self) -> list[str]:
Returns:
list[str]: List of seed prompt strings to be encoded and tested.
"""
seed_prompts: list[str] = []
garak_path = pathlib.Path(DATASETS_PATH) / "seed_prompts" / "garak"
seed_prompts.extend(SeedDataset.from_yaml_file(garak_path / "slur_terms_en.prompt").get_values())
seed_prompts.extend(SeedDataset.from_yaml_file(garak_path / "web_html_js.prompt").get_values())

return seed_prompts
garak_slurs = ScenarioDatasetUtils.seed_dataset_to_list_str(
pathlib.Path(DATASETS_PATH) / "seed_prompts" / "garak" / "slur_terms_en.prompt"
)
garak_html = ScenarioDatasetUtils.seed_dataset_to_list_str(
pathlib.Path(DATASETS_PATH) / "seed_prompts" / "garak" / "web_html_js.prompt"
)
return garak_slurs + garak_html

async def _get_atomic_attacks_async(self) -> List[AtomicAttack]:
"""
Expand Down
4 changes: 2 additions & 2 deletions pyrit/scenario/scenarios/foundry_scenario.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from typing import List, Optional, Sequence, Type, TypeVar

from pyrit.common import apply_defaults
from pyrit.datasets.harmbench_dataset import fetch_harmbench_dataset
from pyrit.datasets.text_jailbreak import TextJailBreak
from pyrit.executor.attack.core.attack_config import (
AttackAdversarialConfig,
Expand Down Expand Up @@ -62,6 +61,7 @@
ScenarioCompositeStrategy,
ScenarioStrategy,
)
from pyrit.scenario.dataset import ScenarioDatasetUtils
from pyrit.score import (
AzureContentFilterScorer,
FloatScaleThresholdScorer,
Expand Down Expand Up @@ -265,7 +265,7 @@ def __init__(
objectives
if objectives
else list(
fetch_harmbench_dataset().get_random_values(
ScenarioDatasetUtils.get_seed_dataset("harmbench").get_random_values(
number=4, harm_categories=["harmful", "harassment_bullying"]
)
)
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/scenarios/test_dataset_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Tests for the scenarios.ScenarioDatasetUtils class."""

Loading