diff --git a/pyrit/scenario/dataset/__init__.py b/pyrit/scenario/dataset/__init__.py
new file mode 100644
index 000000000..9c489e1e5
--- /dev/null
+++ b/pyrit/scenario/dataset/__init__.py
@@ -0,0 +1,10 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""Dataset classes for scenario data loading."""
+
+from pyrit.scenario.dataset.load_utils import ScenarioDatasetUtils
+
+__all__ = [
+    "ScenarioDatasetUtils",
+]
diff --git a/pyrit/scenario/dataset/load_utils.py b/pyrit/scenario/dataset/load_utils.py
new file mode 100644
index 000000000..096e8543c
--- /dev/null
+++ b/pyrit/scenario/dataset/load_utils.py
@@ -0,0 +1,50 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""Dataset loading helpers shared by scenario implementations."""
+
+from pathlib import Path
+from typing import List
+
+from pyrit.common.path import DATASETS_PATH, SCORER_CONFIG_PATH
+from pyrit.datasets.harmbench_dataset import fetch_harmbench_dataset
+from pyrit.models import SeedDataset
+
+
+class ScenarioDatasetUtils:
+    """
+    Set of dataset loading utilities for Scenario class.
+    """
+
+    @classmethod
+    def seed_dataset_to_list_str(cls, dataset: Path) -> List[str]:
+        """
+        Load a seed dataset YAML file and return its values as a list of strings.
+
+        Args:
+            dataset (Path): Path to the seed dataset YAML file.
+
+        Returns:
+            List[str]: A new list holding the values of the loaded dataset.
+        """
+        return list(SeedDataset.from_yaml_file(dataset).get_values())
+
+    @classmethod
+    def get_seed_dataset(cls, which: str) -> SeedDataset:
+        """
+        Get SeedDataset from shorthand string.
+
+        Args:
+            which (str): Which SeedDataset. Currently only "harmbench" is recognized.
+
+        Returns:
+            SeedDataset: Desired dataset.
+
+        Raises:
+            ValueError: If dataset not found.
+        """
+        match which:
+            case "harmbench":
+                return fetch_harmbench_dataset()
+            case _:
+                raise ValueError(f"Error: unknown dataset `{which}` provided.")
diff --git a/pyrit/scenario/scenarios/airt/cyber_scenario.py b/pyrit/scenario/scenarios/airt/cyber_scenario.py
index 924f94941..d289502e6 100644
--- a/pyrit/scenario/scenarios/airt/cyber_scenario.py
+++ b/pyrit/scenario/scenarios/airt/cyber_scenario.py
@@ -22,6 +22,7 @@
     ScenarioCompositeStrategy,
     ScenarioStrategy,
 )
+from pyrit.scenario.dataset import ScenarioDatasetUtils
 from pyrit.score import (
     SelfAskRefusalScorer,
     SelfAskTrueFalseScorer,
@@ -172,10 +173,9 @@ def _get_default_dataset(self) -> list[str]:
         Returns:
             list[str]: List of seed prompt strings to be encoded and tested.
         """
-        seed_prompts: List[str] = []
-        malware_path = pathlib.Path(DATASETS_PATH) / "seed_prompts"
-        seed_prompts.extend(SeedDataset.from_yaml_file(malware_path / "malware.prompt").get_values())
-        return seed_prompts
+        return ScenarioDatasetUtils.seed_dataset_to_list_str(
+            pathlib.Path(DATASETS_PATH) / "seed_prompts" / "malware.prompt"
+        )
 
     async def _get_atomic_attack_from_strategy_async(self, strategy: str) -> AtomicAttack:
         """
diff --git a/pyrit/scenario/scenarios/encoding_scenario.py b/pyrit/scenario/scenarios/encoding_scenario.py
index f7540d940..cc210da07 100644
--- a/pyrit/scenario/scenarios/encoding_scenario.py
+++ b/pyrit/scenario/scenarios/encoding_scenario.py
@@ -12,7 +12,7 @@
     AttackScoringConfig,
 )
 from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack
-from pyrit.models import SeedDataset, SeedGroup
+from pyrit.models import SeedGroup
 from pyrit.models.seed_prompt import SeedPrompt
 from pyrit.prompt_converter import (
     AsciiSmugglerConverter,
@@ -39,6 +39,7 @@
     ScenarioCompositeStrategy,
     ScenarioStrategy,
 )
+from pyrit.scenario.dataset import ScenarioDatasetUtils
 from pyrit.score import TrueFalseScorer
 from pyrit.score.true_false.decoding_scorer import DecodingScorer
 
@@ -171,12 +172,10 @@ def _get_default_dataset(self) -> list[str]:
         Returns:
             list[str]: List of seed prompt strings to be encoded and tested.
         """
-        seed_prompts: list[str] = []
         garak_path = pathlib.Path(DATASETS_PATH) / "seed_prompts" / "garak"
-        seed_prompts.extend(SeedDataset.from_yaml_file(garak_path / "slur_terms_en.prompt").get_values())
-        seed_prompts.extend(SeedDataset.from_yaml_file(garak_path / "web_html_js.prompt").get_values())
-
-        return seed_prompts
+        garak_slurs = ScenarioDatasetUtils.seed_dataset_to_list_str(garak_path / "slur_terms_en.prompt")
+        garak_html = ScenarioDatasetUtils.seed_dataset_to_list_str(garak_path / "web_html_js.prompt")
+        return garak_slurs + garak_html
 
     async def _get_atomic_attacks_async(self) -> List[AtomicAttack]:
         """
diff --git a/pyrit/scenario/scenarios/foundry_scenario.py b/pyrit/scenario/scenarios/foundry_scenario.py
index 4e9b34e33..bf33ec255 100644
--- a/pyrit/scenario/scenarios/foundry_scenario.py
+++ b/pyrit/scenario/scenarios/foundry_scenario.py
@@ -14,7 +14,6 @@
 from typing import List, Optional, Sequence, Type, TypeVar
 
 from pyrit.common import apply_defaults
-from pyrit.datasets.harmbench_dataset import fetch_harmbench_dataset
 from pyrit.datasets.text_jailbreak import TextJailBreak
 from pyrit.executor.attack.core.attack_config import (
     AttackAdversarialConfig,
@@ -62,6 +61,7 @@
     ScenarioCompositeStrategy,
     ScenarioStrategy,
 )
+from pyrit.scenario.dataset import ScenarioDatasetUtils
 from pyrit.score import (
     AzureContentFilterScorer,
     FloatScaleThresholdScorer,
@@ -265,7 +265,7 @@ def __init__(
             objectives
             if objectives
             else list(
-                fetch_harmbench_dataset().get_random_values(
+                ScenarioDatasetUtils.get_seed_dataset("harmbench").get_random_values(
                     number=4, harm_categories=["harmful", "harassment_bullying"]
                 )
             )
diff --git a/tests/unit/scenarios/test_dataset_utils.py b/tests/unit/scenarios/test_dataset_utils.py
new file mode 100644
index 000000000..208c086f5
--- /dev/null
+++ b/tests/unit/scenarios/test_dataset_utils.py
@@ -0,0 +1,37 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+"""Tests for the scenarios.ScenarioDatasetUtils class."""
+
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+
+from pyrit.scenario.dataset import ScenarioDatasetUtils
+
+
+def test_seed_dataset_to_list_str_returns_loaded_values():
+    """The values of the loaded dataset come back as a plain list."""
+    path = Path("example.prompt")
+    with patch("pyrit.scenario.dataset.load_utils.SeedDataset") as seed_dataset_cls:
+        seed_dataset_cls.from_yaml_file.return_value.get_values.return_value = ("first", "second")
+        result = ScenarioDatasetUtils.seed_dataset_to_list_str(path)
+
+    seed_dataset_cls.from_yaml_file.assert_called_once_with(path)
+    assert result == ["first", "second"]
+
+
+def test_get_seed_dataset_harmbench_uses_harmbench_fetcher():
+    """The "harmbench" shorthand returns whatever the HarmBench fetcher returns."""
+    with patch("pyrit.scenario.dataset.load_utils.fetch_harmbench_dataset") as fetch:
+        result = ScenarioDatasetUtils.get_seed_dataset("harmbench")
+
+    fetch.assert_called_once_with()
+    assert result is fetch.return_value
+
+
+def test_get_seed_dataset_unknown_name_raises_value_error():
+    """An unrecognized shorthand is rejected with a ValueError naming it."""
+    with pytest.raises(ValueError, match="unknown dataset `not-a-dataset`"):
+        ScenarioDatasetUtils.get_seed_dataset("not-a-dataset")