From e3f23b5d5e1daea22e7ba0c745e49d8c3006b94f Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 21 Nov 2025 23:26:49 +0000 Subject: [PATCH 1/2] Basic implementation added --- pyrit/scenario/dataset/__init__.py | 10 ++++++ pyrit/scenario/dataset/load_utils.py | 36 +++++++++++++++++++ .../scenario/scenarios/airt/cyber_scenario.py | 8 ++--- pyrit/scenario/scenarios/encoding_scenario.py | 16 +++++---- pyrit/scenario/scenarios/foundry_scenario.py | 10 ++++-- tests/unit/scenarios/test_dataset_utils.py | 5 +++ 6 files changed, 72 insertions(+), 13 deletions(-) create mode 100644 pyrit/scenario/dataset/__init__.py create mode 100644 pyrit/scenario/dataset/load_utils.py create mode 100644 tests/unit/scenarios/test_dataset_utils.py diff --git a/pyrit/scenario/dataset/__init__.py b/pyrit/scenario/dataset/__init__.py new file mode 100644 index 000000000..9c489e1e5 --- /dev/null +++ b/pyrit/scenario/dataset/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Dataset classes for scenario data loading.""" + +from pyrit.scenario.dataset.load_utils import ScenarioDatasetUtils + +__all__ = [ + "ScenarioDatasetUtils" +] diff --git a/pyrit/scenario/dataset/load_utils.py b/pyrit/scenario/dataset/load_utils.py new file mode 100644 index 000000000..096e8543c --- /dev/null +++ b/pyrit/scenario/dataset/load_utils.py @@ -0,0 +1,36 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from pathlib import Path +from typing import List +from pyrit.models import SeedDataset +from pyrit.common.path import DATASETS_PATH, SCORER_CONFIG_PATH +from pyrit.datasets.harmbench_dataset import fetch_harmbench_dataset + + +class ScenarioDatasetUtils: + """ + Set of dataset loading utilities for Scenario class. + """ + @classmethod + def seed_dataset_to_list_str(cls, dataset: Path) -> List[str]: + seed_prompts: List[str] = [] + seed_prompts.extend(SeedDataset.from_yaml_file(dataset).get_values()) + return seed_prompts + + @classmethod + def get_seed_dataset(cls, which: str) -> SeedDataset: + """ + Get SeedDataset from shorthand string. + Args: + which (str): Which SeedDataset. + Returns: + SeedDataset: Desired dataset. + Raises: + ValueError: If dataset not found. + """ + match which: + case "harmbench": + return fetch_harmbench_dataset() + case _: + raise ValueError(f"Error: unknown dataset `{which}` provided.") \ No newline at end of file diff --git a/pyrit/scenario/scenarios/airt/cyber_scenario.py b/pyrit/scenario/scenarios/airt/cyber_scenario.py index 924f94941..d289502e6 100644 --- a/pyrit/scenario/scenarios/airt/cyber_scenario.py +++ b/pyrit/scenario/scenarios/airt/cyber_scenario.py @@ -22,6 +22,7 @@ ScenarioCompositeStrategy, ScenarioStrategy, ) +from pyrit.scenario.dataset import ScenarioDatasetUtils from pyrit.score import ( SelfAskRefusalScorer, SelfAskTrueFalseScorer, @@ -172,10 +173,9 @@ def _get_default_dataset(self) -> list[str]: Returns: list[str]: List of seed prompt strings to be encoded and tested. """ - seed_prompts: List[str] = [] - malware_path = pathlib.Path(DATASETS_PATH) / "seed_prompts" - seed_prompts.extend(SeedDataset.from_yaml_file(malware_path / "malware.prompt").get_values()) - return seed_prompts + return ScenarioDatasetUtils.seed_dataset_to_list_str( + pathlib.Path(DATASETS_PATH) / "seed_prompts" / "malware.prompt" + ) async def _get_atomic_attack_from_strategy_async(self, strategy: str) -> AtomicAttack: """ diff --git a/pyrit/scenario/scenarios/encoding_scenario.py b/pyrit/scenario/scenarios/encoding_scenario.py index f7540d940..cc210da07 100644 --- a/pyrit/scenario/scenarios/encoding_scenario.py +++ b/pyrit/scenario/scenarios/encoding_scenario.py @@ -12,7 +12,7 @@ AttackScoringConfig, ) from pyrit.executor.attack.single_turn.prompt_sending import PromptSendingAttack -from pyrit.models import SeedDataset, SeedGroup +from pyrit.models import SeedGroup from pyrit.models.seed_prompt import SeedPrompt from pyrit.prompt_converter import ( AsciiSmugglerConverter, @@ -39,6 +39,7 @@ ScenarioCompositeStrategy, ScenarioStrategy, ) +from pyrit.scenario.dataset import ScenarioDatasetUtils from pyrit.score import TrueFalseScorer from pyrit.score.true_false.decoding_scorer import DecodingScorer @@ -171,12 +172,13 @@ def _get_default_dataset(self) -> list[str]: Returns: list[str]: List of seed prompt strings to be encoded and tested. """ - seed_prompts: list[str] = [] - garak_path = pathlib.Path(DATASETS_PATH) / "seed_prompts" / "garak" - seed_prompts.extend(SeedDataset.from_yaml_file(garak_path / "slur_terms_en.prompt").get_values()) - seed_prompts.extend(SeedDataset.from_yaml_file(garak_path / "web_html_js.prompt").get_values()) - - return seed_prompts + garak_slurs = ScenarioDatasetUtils.seed_dataset_to_list_str( + pathlib.Path(DATASETS_PATH) / "seed_prompts" / "garak" / "slur_terms_en.prompt" + ) + garak_html = ScenarioDatasetUtils.seed_dataset_to_list_str( + pathlib.Path(DATASETS_PATH) / "seed_prompts" / "garak" / "web_html_js.prompt" + ) + return garak_slurs + garak_html async def _get_atomic_attacks_async(self) -> List[AtomicAttack]: """ diff --git a/pyrit/scenario/scenarios/foundry_scenario.py b/pyrit/scenario/scenarios/foundry_scenario.py index 4e9b34e33..2a405e571 100644 --- a/pyrit/scenario/scenarios/foundry_scenario.py +++ b/pyrit/scenario/scenarios/foundry_scenario.py @@ -14,7 +14,6 @@ from typing import List, Optional, Sequence, Type, TypeVar from pyrit.common import apply_defaults -from pyrit.datasets.harmbench_dataset import fetch_harmbench_dataset from pyrit.datasets.text_jailbreak import TextJailBreak from pyrit.executor.attack.core.attack_config import ( AttackAdversarialConfig, @@ -62,6 +61,7 @@ ScenarioCompositeStrategy, ScenarioStrategy, ) +from pyrit.scenario.dataset import ScenarioDatasetUtils from pyrit.score import ( AzureContentFilterScorer, FloatScaleThresholdScorer, @@ -265,7 +265,13 @@ def __init__( objectives if objectives else list( - fetch_harmbench_dataset().get_random_values( + ScenarioDatasetUtils.get_seed_dataset("harmbench").get_random_values( + number=4, harm_categories=["harmful", "harassment_bullying"] + ) + ) + + else list( + .get_random_values( number=4, harm_categories=["harmful", "harassment_bullying"] ) ) diff --git a/tests/unit/scenarios/test_dataset_utils.py b/tests/unit/scenarios/test_dataset_utils.py new file mode 100644 index 000000000..208c086f5 --- /dev/null +++ b/tests/unit/scenarios/test_dataset_utils.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Tests for the scenarios.ScenarioDatasetUtils class.""" + From ee8927fcb9197b9fc77a4cf510ec53623f4d18d6 Mon Sep 17 00:00:00 2001 From: Victor Valbuena Date: Fri, 21 Nov 2025 23:31:16 +0000 Subject: [PATCH 2/2] Duplicate else clause removed --- pyrit/scenario/scenarios/foundry_scenario.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/pyrit/scenario/scenarios/foundry_scenario.py b/pyrit/scenario/scenarios/foundry_scenario.py index 2a405e571..bf33ec255 100644 --- a/pyrit/scenario/scenarios/foundry_scenario.py +++ b/pyrit/scenario/scenarios/foundry_scenario.py @@ -269,12 +269,6 @@ def __init__( number=4, harm_categories=["harmful", "harassment_bullying"] ) ) - - else list( - .get_random_values( - number=4, harm_categories=["harmful", "harassment_bullying"] - ) - ) ) super().__init__(