diff --git a/doc/_toc.yml b/doc/_toc.yml index a6239c2b9..037b5ad93 100644 --- a/doc/_toc.yml +++ b/doc/_toc.yml @@ -131,6 +131,10 @@ chapters: - file: code/scenarios/0_scenarios sections: - file: code/scenarios/1_configuring_scenarios + - file: code/registry/0_registry + sections: + - file: code/registry/1_class_registry + - file: code/registry/2_instance_registry - file: code/front_end/0_front_end sections: - file: code/front_end/1_pyrit_scan diff --git a/doc/code/registry/0_registry.md b/doc/code/registry/0_registry.md new file mode 100644 index 000000000..38786a3c3 --- /dev/null +++ b/doc/code/registry/0_registry.md @@ -0,0 +1,56 @@ +# Registry + +Registries in PyRIT provide a centralized way to discover, manage, and access components. They support lazy loading, singleton access, and metadata introspection. + +## Why Registries? + +- **Discovery**: Automatically find available components (scenarios, scorers, etc.) +- **Consistency**: Access components through a uniform API +- **Metadata**: Inspect what's available without instantiating everything +- **Extensibility**: Register custom components alongside built-in ones + +## Two Types of Registries + +PyRIT has two registry patterns for different use cases: + +| Type | Stores | Use Case | +|------|--------|----------| +| **Class Registry** | Classes (Type[T]) | Components instantiated with user-provided parameters | +| **Instance Registry** | Pre-configured instances | Components requiring complex setup before use | + +## Common API (RegistryProtocol) + +Both registry types implement `RegistryProtocol`, sharing a consistent interface: + +| Method | Description | +|--------|-------------| +| `get_registry_singleton()` | Get the singleton registry instance | +| `get_names()` | List all registered names | +| `list_metadata()` | Get descriptive metadata for all items | +| `reset_instance()` | Reset the singleton (useful for testing) | + +This protocol enables writing code that works with any registry type: + 
+```python +from pyrit.registry import RegistryProtocol + +def show_registry_contents(registry: RegistryProtocol) -> None: + for name in registry.get_names(): + print(name) +``` + + +## Key Differences Between Class and Instance Registries + +| Aspect | Class Registry | Instance Registry | +|--------|----------------|-------------------| +| Stores | Classes (Type[T]) | Instances (T) | +| Registration | Automatic discovery | Explicit via `register_instance()` | +| Returns | Class to instantiate | Ready-to-use instance | +| Instantiation | Caller provides parameters | Pre-configured by initializer | +| When to use | Self-contained components with deferred configuration | Components requiring constructor parameters or compositional setup | + +## See Also + +- [Class Registries](1_class_registry.ipynb) - ScenarioRegistry, InitializerRegistry +- [Instance Registries](2_instance_registry.ipynb) - ScorerRegistry diff --git a/doc/code/registry/1_class_registry.ipynb b/doc/code/registry/1_class_registry.ipynb new file mode 100644 index 000000000..f3651034c --- /dev/null +++ b/doc/code/registry/1_class_registry.ipynb @@ -0,0 +1,256 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "## Listing Available Classes\n", + "\n", + "Use `get_names()` to see what's available, or `list_metadata()` for detailed information." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Available scenarios: ['content_harms', 'cyber', 'encoding', 'foundry', 'scam']...\n", + "\n", + "content_harms:\n", + " Class: ContentHarms\n", + " Description: Content Harms Scenario implementation for PyRIT. This scenario contains various ...\n", + "\n", + "cyber:\n", + " Class: Cyber\n", + " Description: Cyber scenario implementation for PyRIT. 
This scenario tests how willing models ...\n" + ] + } + ], + "source": [ + "from pyrit.registry import ScenarioRegistry\n", + "\n", + "registry = ScenarioRegistry.get_registry_singleton()\n", + "\n", + "# Get all registered names\n", + "names = registry.get_names()\n", + "print(f\"Available scenarios: {names[:5]}...\") # Show first 5\n", + "\n", + "# Get detailed metadata\n", + "metadata = registry.list_metadata()\n", + "for item in metadata[:2]: # Show first 2\n", + " print(f\"\\n{item.name}:\")\n", + " print(f\" Class: {item.class_name}\")\n", + " print(f\" Description: {item.description[:80]}...\")" + ] + }, + { + "cell_type": "markdown", + "id": "2", + "metadata": {}, + "source": [ + "## Getting a Class\n", + "\n", + "Use `get_class()` to retrieve a class by name. This returns the class itself, not an instance." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Got class: \n", + "Class name: Encoding\n" + ] + } + ], + "source": [ + "# Get a scenario class\n", + "\n", + "scenario_class = registry.get_class(\"encoding\")\n", + "\n", + "print(f\"Got class: {scenario_class}\")\n", + "print(f\"Class name: {scenario_class.__name__}\")" + ] + }, + { + "cell_type": "markdown", + "id": "4", + "metadata": {}, + "source": [ + "## Creating Instances\n", + "\n", + "Once you have a class, instantiate it with your parameters. You can also use `create_instance()` as a shortcut." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found default environment files: ['C:\\\\Users\\\\rlundeen\\\\.pyrit\\\\.env', 'C:\\\\Users\\\\rlundeen\\\\.pyrit\\\\.env.local']\n", + "Loaded environment file: C:\\Users\\rlundeen\\.pyrit\\.env\n", + "Loaded environment file: C:\\Users\\rlundeen\\.pyrit\\.env.local\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Loading datasets - this can take a few minutes: 100%|██████████| 45/45 [00:00<00:00, 70.03dataset/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Scenarios can be instantiated with your target and parameters\n" + ] + } + ], + "source": [ + "from pyrit.prompt_target import OpenAIChatTarget\n", + "from pyrit.setup import IN_MEMORY, initialize_pyrit_async\n", + "from pyrit.setup.initializers import LoadDefaultDatasets\n", + "\n", + "await initialize_pyrit_async(memory_db_type=IN_MEMORY, initializers=[LoadDefaultDatasets()]) # type: ignore\n", + "target = OpenAIChatTarget()\n", + "\n", + "# Option 1: Get class then instantiate\n", + "encoding_class = registry.get_class(\"encoding\")\n", + "scenario = encoding_class() # type: ignore\n", + "\n", + "# Pass dataset configuration to initialize_async\n", + "await scenario.initialize_async(objective_target=target)\n", + "\n", + "# Option 2: Use create_instance() shortcut\n", + "# scenario = registry.create_instance(\"encoding\", objective_target=my_target, ...)\n", + "\n", + "print(\"Scenarios can be instantiated with your target and parameters\")" + ] + }, + { + "cell_type": "markdown", + "id": "6", + "metadata": {}, + "source": [ + "## Checking Registration\n", + "\n", + "Registries support standard Python container operations." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "'encoding' registered: True\n", + "'nonexistent' registered: False\n", + "Total scenarios: 5\n", + " - content_harms\n", + " - cyber\n", + " - encoding\n" + ] + } + ], + "source": [ + "# Check if a name is registered\n", + "print(f\"'encoding' registered: {'encoding' in registry}\")\n", + "print(f\"'nonexistent' registered: {'nonexistent' in registry}\")\n", + "\n", + "# Get count of registered classes\n", + "print(f\"Total scenarios: {len(registry)}\")\n", + "\n", + "# Iterate over names\n", + "for name in list(registry)[:3]:\n", + " print(f\" - {name}\")" + ] + }, + { + "cell_type": "markdown", + "id": "8", + "metadata": {}, + "source": [ + "## Using different registries\n", + "\n", + "There can be multiple registries. Below is doing a similar thing with the `InitializerRegistry`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Available initializers: ['airt', 'load_default_datasets', 'objective_list', 'openai_objective_target', 'simple']...\n", + "\n", + "airt:\n", + " Class: AIRTInitializer\n", + " Description: AI Red Team setup with Azure OpenAI converters, composite harm/objective scorers...\n", + "\n", + "load_default_datasets:\n", + " Class: LoadDefaultDatasets\n", + " Description: This configuration uses the DatasetLoader to load default datasets into memory.\n", + "...\n" + ] + } + ], + "source": [ + "from pyrit.registry import InitializerRegistry\n", + "\n", + "registry = InitializerRegistry.get_registry_singleton()\n", + "\n", + "# Get all registered names\n", + "names = registry.get_names()\n", + "print(f\"Available initializers: {names[:5]}...\") # Show first 5\n", + "\n", + "# Get detailed metadata\n", + "metadata = registry.list_metadata()\n", + "for 
item in metadata[:2]: # Show first 2\n", + " print(f\"\\n{item.name}:\")\n", + " print(f\" Class: {item.class_name}\")\n", + " print(f\" Description: {item.description[:80]}...\")" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/code/registry/1_class_registry.py b/doc/code/registry/1_class_registry.py new file mode 100644 index 000000000..2f9b2f3bb --- /dev/null +++ b/doc/code/registry/1_class_registry.py @@ -0,0 +1,110 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.18.1 +# kernelspec: +# display_name: pyrit (3.13.5) +# language: python +# name: python3 +# --- + +# %% [markdown] +# ## Listing Available Classes +# +# Use `get_names()` to see what's available, or `list_metadata()` for detailed information. + +# %% +from pyrit.registry import ScenarioRegistry + +registry = ScenarioRegistry.get_registry_singleton() + +# Get all registered names +names = registry.get_names() +print(f"Available scenarios: {names[:5]}...") # Show first 5 + +# Get detailed metadata +metadata = registry.list_metadata() +for item in metadata[:2]: # Show first 2 + print(f"\n{item.name}:") + print(f" Class: {item.class_name}") + print(f" Description: {item.description[:80]}...") + +# %% [markdown] +# ## Getting a Class +# +# Use `get_class()` to retrieve a class by name. This returns the class itself, not an instance. + +# %% +# Get a scenario class + +scenario_class = registry.get_class("encoding") + +print(f"Got class: {scenario_class}") +print(f"Class name: {scenario_class.__name__}") + +# %% [markdown] +# ## Creating Instances +# +# Once you have a class, instantiate it with your parameters. 
You can also use `create_instance()` as a shortcut. + +# %% +from pyrit.prompt_target import OpenAIChatTarget +from pyrit.setup import IN_MEMORY, initialize_pyrit_async +from pyrit.setup.initializers import LoadDefaultDatasets + +await initialize_pyrit_async(memory_db_type=IN_MEMORY, initializers=[LoadDefaultDatasets()]) # type: ignore +target = OpenAIChatTarget() + +# Option 1: Get class then instantiate +encoding_class = registry.get_class("encoding") +scenario = encoding_class() # type: ignore + +# Pass dataset configuration to initialize_async +await scenario.initialize_async(objective_target=target) + +# Option 2: Use create_instance() shortcut +# scenario = registry.create_instance("encoding", objective_target=my_target, ...) + +print("Scenarios can be instantiated with your target and parameters") + +# %% [markdown] +# ## Checking Registration +# +# Registries support standard Python container operations. + +# %% +# Check if a name is registered +print(f"'encoding' registered: {'encoding' in registry}") +print(f"'nonexistent' registered: {'nonexistent' in registry}") + +# Get count of registered classes +print(f"Total scenarios: {len(registry)}") + +# Iterate over names +for name in list(registry)[:3]: + print(f" - {name}") + +# %% [markdown] +# ## Using different registries +# +# There can be multiple registries. Below is doing a similar thing with the `InitializerRegistry`. 
+ +# %% +from pyrit.registry import InitializerRegistry + +registry = InitializerRegistry.get_registry_singleton() + +# Get all registered names +names = registry.get_names() +print(f"Available initializers: {names[:5]}...") # Show first 5 + +# Get detailed metadata +metadata = registry.list_metadata() +for item in metadata[:2]: # Show first 2 + print(f"\n{item.name}:") + print(f" Class: {item.class_name}") + print(f" Description: {item.description[:80]}...") diff --git a/doc/code/registry/2_instance_registry.ipynb b/doc/code/registry/2_instance_registry.ipynb new file mode 100644 index 000000000..6db8de7af --- /dev/null +++ b/doc/code/registry/2_instance_registry.ipynb @@ -0,0 +1,211 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0", + "metadata": {}, + "source": [ + "## Why Instance Registries?\n", + "\n", + "Some components need configuration that can't easily be passed at instantiation time. For example, scorers often need:\n", + "- A configured `chat_target` for LLM-based scoring\n", + "- Specific prompt templates\n", + "- Other dependencies\n", + "\n", + "Instance registries let initializers register fully-configured instances that are ready to use." + ] + }, + { + "cell_type": "markdown", + "id": "1", + "metadata": {}, + "source": [ + "## Listing Available Instances\n", + "\n", + "Use `get_names()` to see registered instances, or `list_metadata()` for details." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found default environment files: ['C:\\\\Users\\\\rlundeen\\\\.pyrit\\\\.env', 'C:\\\\Users\\\\rlundeen\\\\.pyrit\\\\.env.local']\n", + "Loaded environment file: C:\\Users\\rlundeen\\.pyrit\\.env\n", + "Loaded environment file: C:\\Users\\rlundeen\\.pyrit\\.env.local\n", + "Registered scorers: ['self_ask_refusal_d9007ba2']\n" + ] + } + ], + "source": [ + "from pyrit.prompt_target import OpenAIChatTarget\n", + "from pyrit.registry import ScorerRegistry\n", + "from pyrit.score import SelfAskRefusalScorer\n", + "from pyrit.setup import IN_MEMORY, initialize_pyrit_async\n", + "\n", + "await initialize_pyrit_async(memory_db_type=IN_MEMORY) # type: ignore\n", + "\n", + "# Get the registry singleton\n", + "registry = ScorerRegistry.get_registry_singleton()\n", + "\n", + "# Register a scorer instance for demonstration\n", + "chat_target = OpenAIChatTarget()\n", + "refusal_scorer = SelfAskRefusalScorer(chat_target=chat_target)\n", + "registry.register_instance(refusal_scorer)\n", + "\n", + "# List what's available\n", + "names = registry.get_names()\n", + "print(f\"Registered scorers: {names}\")" + ] + }, + { + "cell_type": "markdown", + "id": "3", + "metadata": {}, + "source": [ + "## Getting an Instance\n", + "\n", + "Use `get()` to retrieve a pre-configured instance by name. The instance is ready to use immediately." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Retrieved scorer: \n", + "Scorer type: SelfAskRefusalScorer\n" + ] + } + ], + "source": [ + "# Get the first registered scorer\n", + "if names:\n", + " scorer_name = names[0]\n", + " scorer = registry.get(scorer_name)\n", + " print(f\"Retrieved scorer: {scorer}\")\n", + " print(f\"Scorer type: {type(scorer).__name__}\")" + ] + }, + { + "cell_type": "markdown", + "id": "5", + "metadata": {}, + "source": [ + "## Inspecting Metadata\n", + "\n", + "Scorer metadata includes the scorer type and identifier for tracking." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "self_ask_refusal_d9007ba2:\n", + " Class: SelfAskRefusalScorer\n", + " Type: true_false\n", + " Description: A self-ask scorer that detects refusal in AI responses. 
This...\n", + "\n", + "\u001b[1m 📊 Scorer Information\u001b[0m\n", + "\u001b[37m ▸ Scorer Identifier\u001b[0m\n", + "\u001b[36m • Scorer Type: SelfAskRefusalScorer\u001b[0m\n", + "\u001b[36m • Target Model: gpt-40\u001b[0m\n", + "\u001b[36m • Temperature: None\u001b[0m\n", + "\u001b[36m • Score Aggregator: OR_\u001b[0m\n", + "\n", + "\u001b[37m ▸ Performance Metrics\u001b[0m\n", + "\u001b[33m Official evaluation has not been run yet for this specific configuration\u001b[0m\n" + ] + } + ], + "source": [ + "from pyrit.score import ConsoleScorerPrinter\n", + "\n", + "# Get metadata for all registered scorers\n", + "metadata = registry.list_metadata()\n", + "for item in metadata:\n", + " print(f\"\\n{item.name}:\")\n", + " print(f\" Class: {item.class_name}\")\n", + " print(f\" Type: {item.scorer_type}\")\n", + " print(f\" Description: {item.description[:60]}...\")\n", + "\n", + " ConsoleScorerPrinter().print_objective_scorer(scorer_identifier=item.scorer_identifier)" + ] + }, + { + "cell_type": "markdown", + "id": "7", + "metadata": {}, + "source": [ + "## Filtering\n", + "\n", + "Use `list_metadata()` with `include_filters` and `exclude_filters` dictionaries to filter scorers by any metadata property. `include_filters` requires ALL criteria to match (AND logic). `exclude_filters` excludes items matching ANY criteria. Filters use exact match for simple types and membership check for list types." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True/False scorers: ['self_ask_refusal_d9007ba2']\n", + "Refusal scorers: ['self_ask_refusal_d9007ba2']\n", + "True/False refusal scorers: ['self_ask_refusal_d9007ba2']\n" + ] + } + ], + "source": [ + "# Filter by scorer_type (based on isinstance check against TrueFalseScorer/FloatScaleScorer)\n", + "true_false_scorers = registry.list_metadata(include_filters={\"scorer_type\": \"true_false\"})\n", + "print(f\"True/False scorers: {[m.name for m in true_false_scorers]}\")\n", + "\n", + "# Filter by class_name\n", + "refusal_scorers = registry.list_metadata(include_filters={\"class_name\": \"SelfAskRefusalScorer\"})\n", + "print(f\"Refusal scorers: {[m.name for m in refusal_scorers]}\")\n", + "\n", + "# Combine multiple filters (AND logic)\n", + "specific_scorers = registry.list_metadata(\n", + " include_filters={\"scorer_type\": \"true_false\", \"class_name\": \"SelfAskRefusalScorer\"}\n", + ")\n", + "print(f\"True/False refusal scorers: {[m.name for m in specific_scorers]}\")" + ] + } + ], + "metadata": { + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/code/registry/2_instance_registry.py b/doc/code/registry/2_instance_registry.py new file mode 100644 index 000000000..5026beb28 --- /dev/null +++ b/doc/code/registry/2_instance_registry.py @@ -0,0 +1,93 @@ +# --- +# jupyter: +# jupytext: +# text_representation: +# extension: .py +# format_name: percent +# format_version: '1.3' +# jupytext_version: 1.18.1 +# --- + +# %% [markdown] +# ## Why Instance Registries? 
+# +# Some components need configuration that can't easily be passed at instantiation time. For example, scorers often need: +# - A configured `chat_target` for LLM-based scoring +# - Specific prompt templates +# - Other dependencies +# +# Instance registries let initializers register fully-configured instances that are ready to use. + +# %% [markdown] +# ## Listing Available Instances +# +# Use `get_names()` to see registered instances, or `list_metadata()` for details. + +# %% +from pyrit.prompt_target import OpenAIChatTarget +from pyrit.registry import ScorerRegistry +from pyrit.score import SelfAskRefusalScorer +from pyrit.setup import IN_MEMORY, initialize_pyrit_async + +await initialize_pyrit_async(memory_db_type=IN_MEMORY) # type: ignore + +# Get the registry singleton +registry = ScorerRegistry.get_registry_singleton() + +# Register a scorer instance for demonstration +chat_target = OpenAIChatTarget() +refusal_scorer = SelfAskRefusalScorer(chat_target=chat_target) +registry.register_instance(refusal_scorer) + +# List what's available +names = registry.get_names() +print(f"Registered scorers: {names}") + +# %% [markdown] +# ## Getting an Instance +# +# Use `get()` to retrieve a pre-configured instance by name. The instance is ready to use immediately. + +# %% +# Get the first registered scorer +if names: + scorer_name = names[0] + scorer = registry.get(scorer_name) + print(f"Retrieved scorer: {scorer}") + print(f"Scorer type: {type(scorer).__name__}") + +# %% [markdown] +# ## Inspecting Metadata +# +# Scorer metadata includes the scorer type and identifier for tracking. 
+ +# %% +from pyrit.score import ConsoleScorerPrinter + +# Get metadata for all registered scorers +metadata = registry.list_metadata() +for item in metadata: + print(f"\n{item.name}:") + print(f" Class: {item.class_name}") + print(f" Type: {item.scorer_type}") + print(f" Description: {item.description[:60]}...") + + ConsoleScorerPrinter().print_objective_scorer(scorer_identifier=item.scorer_identifier) + +# %% [markdown] +# ## Filtering +# +# Use `list_metadata()` with `include_filters` and `exclude_filters` dictionaries to filter scorers by any metadata property. `include_filters` requires ALL criteria to match (AND logic). `exclude_filters` excludes items matching ANY criteria. Filters use exact match for simple types and membership check for list types. + +# %% +# Filter by scorer_type (based on isinstance check against TrueFalseScorer/FloatScaleScorer) +true_false_scorers = registry.list_metadata(include_filters={"scorer_type": "true_false"}) +print(f"True/False scorers: {[m.name for m in true_false_scorers]}") + +# Filter by class_name +refusal_scorers = registry.list_metadata(include_filters={"class_name": "SelfAskRefusalScorer"}) +print(f"Refusal scorers: {[m.name for m in refusal_scorers]}") + +# Combine multiple filters (AND logic) +specific_scorers = registry.list_metadata(include_filters={"scorer_type": "true_false", "class_name": "SelfAskRefusalScorer"}) +print(f"True/False refusal scorers: {[m.name for m in specific_scorers]}") diff --git a/pyrit/cli/frontend_core.py b/pyrit/cli/frontend_core.py index d6bdd0650..2d9279258 100644 --- a/pyrit/cli/frontend_core.py +++ b/pyrit/cli/frontend_core.py @@ -19,7 +19,7 @@ import logging import sys from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence, TypedDict +from typing import TYPE_CHECKING, Any, Callable, Optional, Sequence try: import termcolor # type: ignore @@ -39,9 +39,13 @@ def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: i if 
TYPE_CHECKING: - from pyrit.cli.initializer_registry import InitializerInfo, InitializerRegistry - from pyrit.cli.scenario_registry import ScenarioRegistry from pyrit.models.scenario_result import ScenarioResult + from pyrit.registry import ( + InitializerMetadata, + InitializerRegistry, + ScenarioMetadata, + ScenarioRegistry, + ) logger = logging.getLogger(__name__) @@ -51,19 +55,6 @@ def cprint(text: str, color: str = None, attrs: list = None) -> None: # type: i AZURE_SQL = "AzureSQL" -class ScenarioInfo(TypedDict): - """Type definition for scenario information dictionary.""" - - name: str - class_name: str - description: str - default_strategy: str - all_strategies: list[str] - aggregate_strategies: list[str] - default_datasets: list[str] - max_dataset_size: Optional[int] - - class FrontendCore: """ Shared context for PyRIT operations. @@ -114,8 +105,7 @@ async def initialize_async(self) -> None: if self._initialized: return - from pyrit.cli.initializer_registry import InitializerRegistry - from pyrit.cli.scenario_registry import ScenarioRegistry + from pyrit.registry import InitializerRegistry, ScenarioRegistry from pyrit.setup import initialize_pyrit_async # Initialize PyRIT without initializers (they run per-scenario) @@ -126,8 +116,8 @@ async def initialize_async(self) -> None: env_files=self._env_files, ) - # Load registries - self._scenario_registry = ScenarioRegistry() + # Load registries (use singleton pattern for shared access) + self._scenario_registry = ScenarioRegistry.get_registry_singleton() if self._initialization_scripts: print("Discovering user scenarios...") sys.stdout.flush() @@ -168,43 +158,43 @@ def initializer_registry(self) -> "InitializerRegistry": return self._initializer_registry -async def list_scenarios_async(*, context: FrontendCore) -> list[ScenarioInfo]: +async def list_scenarios_async(*, context: FrontendCore) -> list[ScenarioMetadata]: """ - List all available scenarios. + List metadata for all available scenarios. 
Args: context: PyRIT context with loaded registries. Returns: - List of scenario info dictionaries. + List of scenario metadata objects describing each scenario class. """ if not context._initialized: await context.initialize_async() - return context.scenario_registry.list_scenarios() + return context.scenario_registry.list_metadata() async def list_initializers_async( *, context: FrontendCore, discovery_path: Optional[Path] = None -) -> "Sequence[InitializerInfo]": +) -> "Sequence[InitializerMetadata]": """ - List all available initializers. + List metadata for all available initializers. Args: context: PyRIT context with loaded registries. discovery_path: Optional path to discover initializers from. Returns: - Sequence of initializer info dictionaries. + Sequence of initializer metadata objects describing each initializer class. """ if discovery_path: - from pyrit.cli.initializer_registry import InitializerRegistry + from pyrit.registry import InitializerRegistry registry = InitializerRegistry(discovery_path=discovery_path) - return registry.list_initializers() + return registry.list_metadata() if not context._initialized: await context.initialize_async() - return context.initializer_registry.list_initializers() + return context.initializer_registry.list_metadata() async def run_scenario_async( @@ -263,7 +253,7 @@ async def run_scenario_async( initializer_instances = [] for name in context._initializer_names: - initializer_class = context.initializer_registry.get_initializer_class(name=name) + initializer_class = context.initializer_registry.get_class(name) initializer_instances.append(initializer_class()) # Re-initialize PyRIT with the scenario-specific initializers @@ -276,10 +266,10 @@ async def run_scenario_async( ) # Get scenario class - scenario_class = context.scenario_registry.get_scenario(scenario_name) + scenario_class = context.scenario_registry.get_class(scenario_name) if scenario_class is None: - available = ", 
".join(context.scenario_registry.get_scenario_names()) + available = ", ".join(context.scenario_registry.get_names()) raise ValueError(f"Scenario '{scenario_name}' not found.\nAvailable scenarios: {available}") # Build initialization kwargs (these go to initialize_async, not __init__) @@ -387,39 +377,39 @@ def _print_header(*, text: str) -> None: print(f"\n {text}") -def format_scenario_info(*, scenario_info: ScenarioInfo) -> None: +def format_scenario_metadata(*, scenario_metadata: ScenarioMetadata) -> None: """ - Print formatted information about a scenario. + Print formatted information about a scenario class. Args: - scenario_info: Dictionary containing scenario information. + scenario_metadata: Dataclass containing scenario metadata. """ - _print_header(text=scenario_info["name"]) - print(f" Class: {scenario_info['class_name']}") + _print_header(text=scenario_metadata.name) + print(f" Class: {scenario_metadata.class_name}") - description = scenario_info.get("description", "") + description = scenario_metadata.description if description: print(" Description:") print(_format_wrapped_text(text=description, indent=" ")) - if scenario_info.get("aggregate_strategies"): - agg_strategies = scenario_info["aggregate_strategies"] + if scenario_metadata.aggregate_strategies: + agg_strategies = scenario_metadata.aggregate_strategies print(" Aggregate Strategies:") formatted = _format_wrapped_text(text=", ".join(agg_strategies), indent=" - ") print(formatted) - if scenario_info.get("all_strategies"): - strategies = scenario_info["all_strategies"] + if scenario_metadata.all_strategies: + strategies = scenario_metadata.all_strategies print(f" Available Strategies ({len(strategies)}):") formatted = _format_wrapped_text(text=", ".join(strategies), indent=" ") print(formatted) - if scenario_info.get("default_strategy"): - print(f" Default Strategy: {scenario_info['default_strategy']}") + if scenario_metadata.default_strategy: + print(f" Default Strategy: 
{scenario_metadata.default_strategy}") - if scenario_info.get("default_datasets"): - datasets = scenario_info["default_datasets"] - max_size = scenario_info.get("max_dataset_size") + if scenario_metadata.default_datasets: + datasets = scenario_metadata.default_datasets + max_size = scenario_metadata.max_dataset_size if datasets: size_suffix = f", max {max_size} per dataset" if max_size else "" print(f" Default Datasets ({len(datasets)}{size_suffix}):") @@ -429,28 +419,28 @@ def format_scenario_info(*, scenario_info: ScenarioInfo) -> None: print(" Default Datasets: None") -def format_initializer_info(*, initializer_info: "InitializerInfo") -> None: +def format_initializer_metadata(*, initializer_metadata: "InitializerMetadata") -> None: """ - Print formatted information about an initializer. + Print formatted information about an initializer class. Args: - initializer_info: Dictionary containing initializer information. + initializer_metadata: Dataclass containing initializer metadata. """ - _print_header(text=initializer_info["name"]) - print(f" Class: {initializer_info['class_name']}") - print(f" Name: {initializer_info['initializer_name']}") - print(f" Execution Order: {initializer_info['execution_order']}") + _print_header(text=initializer_metadata.name) + print(f" Class: {initializer_metadata.class_name}") + print(f" Name: {initializer_metadata.initializer_name}") + print(f" Execution Order: {initializer_metadata.execution_order}") - if initializer_info.get("required_env_vars"): + if initializer_metadata.required_env_vars: print(" Required Environment Variables:") - for env_var in initializer_info["required_env_vars"]: + for env_var in initializer_metadata.required_env_vars: print(f" - {env_var}") else: print(" Required Environment Variables: None") - if initializer_info.get("description"): + if initializer_metadata.description: print(" Description:") - print(_format_wrapped_text(text=initializer_info["description"], indent=" ")) + 
print(_format_wrapped_text(text=initializer_metadata.description, indent=" ")) def validate_database(*, database: str) -> str: @@ -605,7 +595,7 @@ def resolve_initialization_scripts(script_paths: list[str]) -> list[Path]: Raises: FileNotFoundError: If a script path does not exist. """ - from pyrit.cli.initializer_registry import InitializerRegistry + from pyrit.registry import InitializerRegistry return InitializerRegistry.resolve_script_paths(script_paths=script_paths) @@ -709,8 +699,8 @@ async def print_scenarios_list_async(*, context: FrontendCore) -> int: print("\nAvailable Scenarios:") print("=" * 80) - for scenario_info in scenarios: - format_scenario_info(scenario_info=scenario_info) + for scenario_metadata in scenarios: + format_scenario_metadata(scenario_metadata=scenario_metadata) print("\n" + "=" * 80) print(f"\nTotal scenarios: {len(scenarios)}") return 0 @@ -735,8 +725,8 @@ async def print_initializers_list_async(*, context: FrontendCore, discovery_path print("\nAvailable Initializers:") print("=" * 80) - for initializer_info in initializers: - format_initializer_info(initializer_info=initializer_info) + for initializer_metadata in initializers: + format_initializer_metadata(initializer_metadata=initializer_metadata) print("\n" + "=" * 80) print(f"\nTotal initializers: {len(initializers)}") return 0 diff --git a/pyrit/cli/initializer_registry.py b/pyrit/cli/initializer_registry.py deleted file mode 100644 index 950e46510..000000000 --- a/pyrit/cli/initializer_registry.py +++ /dev/null @@ -1,330 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from __future__ import annotations - -""" -Initializer registry for discovering and cataloging PyRIT initializers. - -This module provides functionality to discover all available PyRITInitializer subclasses. - -PERFORMANCE OPTIMIZATION: -This module uses lazy imports and direct path computation to minimize import overhead: - -1. 
Lazy Imports via TYPE_CHECKING: PyRITInitializer is only imported for type checking, - not at runtime. Runtime imports happen inside methods when actually needed. - -2. Direct Path Computation: Computes PYRIT_PATH directly using __file__ instead of importing - from pyrit.common.path, avoiding loading of the pyrit package. -""" - -import importlib.util -import inspect -import logging -from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, TypedDict - -# Compute PYRIT_PATH directly to avoid importing pyrit package -# (which triggers heavy imports from __init__.py) -PYRIT_PATH = Path(__file__).parent.parent.resolve() - -if TYPE_CHECKING: - from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer - -logger = logging.getLogger(__name__) - - -class InitializerInfo(TypedDict): - """Type definition for initializer information dictionary.""" - - name: str - class_name: str - initializer_name: str - description: str - required_env_vars: list[str] - execution_order: int - - -class InitializerRegistry: - """ - Registry for discovering and managing available initializers. - - This class discovers all PyRITInitializer subclasses from the - pyrit/setup/initializers directory structure. - - Initializers are identified by their filename (e.g., "objective_target", "simple"). - The directory structure is used for organization but not exposed to users. - """ - - def __init__(self, *, discovery_path: Path | None = None) -> None: - """ - Initialize the initializer registry. - - Args: - discovery_path (Path | None): The path to discover initializers from. - If None, defaults to pyrit/setup/initializers (discovers all). - To discover only scenarios, pass pyrit/setup/initializers/scenarios. 
- """ - self._initializers: Dict[str, InitializerInfo] = {} - self._initializer_paths: Dict[str, Path] = {} # Track file paths for collision detection - self._initializer_metadata: Optional[List[InitializerInfo]] = None - - if discovery_path is None: - discovery_path = Path(PYRIT_PATH) / "setup" / "initializers" - - self._discovery_path = discovery_path - - self._discover_initializers() - - def _discover_initializers(self) -> None: - """ - Discover all initializers from the specified discovery path. - - This method recursively walks the directory tree and registers - any PyRITInitializer subclasses found. Initializers are registered - by filename only for simpler user experience. - """ - if not self._discovery_path.exists(): - logger.warning(f"Initializers directory not found: {self._discovery_path}") - return - - # Check if discovery path is a file or directory - if self._discovery_path.is_file(): - self._process_file(file_path=self._discovery_path) - elif self._discovery_path.is_dir(): - # Discover from the specified directory and its subdirectories - self._discover_in_directory(directory=self._discovery_path) - - def _discover_in_directory(self, *, directory: Path) -> None: - """ - Recursively discover PyRIT initializers in a directory. - - Args: - directory (Path): The directory to search for initializer modules. - """ - for item in directory.iterdir(): - if item.is_file() and item.suffix == ".py" and item.stem != "__init__": - self._process_file(file_path=item) - elif item.is_dir() and item.name != "__pycache__": - self._discover_in_directory(directory=item) - - def _process_file(self, *, file_path: Path) -> None: - """ - Process a Python file to extract PyRITInitializer subclasses. - - Args: - file_path (Path): Path to the Python file to process. 
- """ - # Runtime import to avoid loading heavy modules at module level - from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer - - # Calculate module name for import (still needs full path for Python import) - # Convert file path to module path relative to initializers directory - initializers_base = Path(PYRIT_PATH) / "setup" / "initializers" - relative_path = file_path.relative_to(initializers_base) - module_parts = list(relative_path.parts[:-1]) + [relative_path.stem] - module_name = ".".join(module_parts) - - # Use just the filename as the name (e.g., "load_default_datasets") - short_name = file_path.stem - - # Check for name collision - if short_name in self._initializer_paths: - existing_path = self._initializer_paths[short_name] - logger.error( - f"Initializer name collision: '{short_name}' found in both " - f"'{file_path}' and '{existing_path}'. " - f"Initializer filenames must be unique across all directories." - ) - return - - try: - spec = importlib.util.spec_from_file_location(f"pyrit.setup.initializers.{module_name}", file_path) - if not spec or not spec.loader: - return - - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Find all PyRITInitializer subclasses in the module - for attr_name in dir(module): - attr = getattr(module, attr_name) - if inspect.isclass(attr) and issubclass(attr, PyRITInitializer) and attr != PyRITInitializer: - self._try_register_initializer(initializer_class=attr, short_name=short_name, file_path=file_path) - - except Exception as e: - logger.warning(f"Failed to load initializer module {short_name}: {e}") - - def _try_register_initializer( - self, *, initializer_class: type[PyRITInitializer], short_name: str, file_path: Path - ) -> None: - """ - Try to instantiate an initializer and register it. - - Args: - initializer_class (type[PyRITInitializer]): The initializer class to instantiate. - short_name (str): The short name for the initializer (filename without extension). 
- file_path (Path): The path to the file containing the initializer. - """ - try: - instance = initializer_class() - initializer_info: InitializerInfo = { - "name": short_name, - "class_name": initializer_class.__name__, - "initializer_name": instance.name, - "description": instance.description, - "required_env_vars": instance.required_env_vars, - "execution_order": instance.execution_order, - } - self._initializers[short_name] = initializer_info - self._initializer_paths[short_name] = file_path - logger.debug(f"Registered initializer: {short_name} ({initializer_class.__name__})") - - except Exception as e: - logger.warning(f"Failed to instantiate initializer {initializer_class.__name__}: {e}") - - def get_initializer(self, name: str) -> InitializerInfo | None: - """ - Get an initializer by name. - - Args: - name (str): Initializer identifier (e.g., "objective_target", "simple") - - Returns: - InitializerInfo | None: The initializer information, or None if not found. - """ - return self._initializers.get(name) - - def list_initializers(self) -> List[InitializerInfo]: - """ - List all available initializers with their metadata. - - Returns: - List[InitializerInfo]: List of initializer information dictionaries, sorted by - execution order and then by name. - """ - # Return cached metadata if available - if self._initializer_metadata is not None: - return self._initializer_metadata - - # Build from discovered initializers - initializers_list = list(self._initializers.values()) - initializers_list.sort(key=lambda x: (x["execution_order"], x["name"])) - - # Cache for subsequent calls - self._initializer_metadata = initializers_list - - return initializers_list - - def get_initializer_names(self) -> List[str]: - """ - Get a list of all available initializer names. - - Returns: - List[str]: Sorted list of initializer identifiers. 
- """ - return sorted(self._initializers.keys()) - - def resolve_initializer_paths(self, *, initializer_names: list[str]) -> list[Path]: - """ - Resolve initializer names to their file paths. - - Args: - initializer_names (list[str]): List of initializer names to resolve. - - Returns: - list[Path]: List of resolved file paths. - - Raises: - ValueError: If any initializer name is not found or has no file path. - """ - resolved_paths = [] - - for initializer_name in initializer_names: - initializer_info = self.get_initializer(initializer_name) - - if initializer_info is None: - available = ", ".join(sorted(self.get_initializer_names())) - raise ValueError( - f"Built-in initializer '{initializer_name}' not found.\n" - f"Available initializers: {available}\n" - f"Use 'pyrit_scan --list-initializers' to see detailed information." - ) - - initializer_file = self._initializer_paths.get(initializer_name) - if initializer_file is None: - raise ValueError(f"Could not locate file for initializer '{initializer_name}'.") - - resolved_paths.append(initializer_file) - - return resolved_paths - - def get_initializer_class(self, *, name: str) -> type["PyRITInitializer"]: - """ - Get the initializer class by name. - - Args: - name: The initializer name. - - Returns: - The initializer class. - - Raises: - ValueError: If initializer not found. - """ - import importlib.util - - initializer_info = self.get_initializer(name) - if initializer_info is None: - available = ", ".join(sorted(self.get_initializer_names())) - raise ValueError( - f"Initializer '{name}' not found.\n" - f"Available initializers: {available}\n" - f"Use 'pyrit_scan --list-initializers' to see detailed information." 
- ) - - initializer_file = self._initializer_paths.get(name) - if initializer_file is None: - raise ValueError(f"Could not locate file for initializer '{name}'.") - - # Load the module - spec = importlib.util.spec_from_file_location("initializer_module", initializer_file) - if spec is None or spec.loader is None: - raise ValueError(f"Failed to load initializer from {initializer_file}") - - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - - # Get the initializer class - initializer_class = getattr(module, initializer_info["class_name"]) - return initializer_class - - @staticmethod - def resolve_script_paths(*, script_paths: list[str]) -> list[Path]: - """ - Resolve and validate custom script paths. - - Args: - script_paths (list[str]): List of script path strings to resolve. - - Returns: - list[Path]: List of resolved Path objects. - - Raises: - FileNotFoundError: If any script path does not exist. - """ - resolved_paths = [] - - for script in script_paths: - script_path = Path(script) - if not script_path.is_absolute(): - script_path = Path.cwd() / script_path - - if not script_path.exists(): - raise FileNotFoundError( - f"Initialization script not found: {script_path}\n Looked in: {script_path.absolute()}" - ) - - resolved_paths.append(script_path) - - return resolved_paths diff --git a/pyrit/cli/scenario_registry.py b/pyrit/cli/scenario_registry.py deleted file mode 100644 index fd2fcb6de..000000000 --- a/pyrit/cli/scenario_registry.py +++ /dev/null @@ -1,278 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from __future__ import annotations - -""" -Scenario registry for discovering and instantiating PyRIT scenarios. - -This module provides functionality to discover all available Scenario subclasses -from the pyrit.scenario.scenarios module and from user-defined initialization scripts. - -PERFORMANCE OPTIMIZATION: -This module uses lazy imports to minimize overhead during CLI operations: - -1. 
Lazy Imports via TYPE_CHECKING: Heavy dependencies (like Scenario base class) are only - imported for type checking, not at runtime. Runtime imports happen inside methods only - when actually needed. - -2. Direct Path Computation: Computes PYRIT_PATH directly using __file__ instead of importing - from pyrit.common.path, which would trigger loading the entire pyrit package (including - heavy dependencies like transformers). -""" - -import importlib -import inspect -import logging -import pkgutil -from pathlib import Path -from typing import TYPE_CHECKING, Dict, List, Optional, Type - -# Compute PYRIT_PATH directly to avoid importing pyrit package -# (which triggers heavy imports from __init__.py) -PYRIT_PATH = Path(__file__).parent.parent.resolve() - -# Lazy import to avoid loading heavy scenario modules when just listing scenarios -if TYPE_CHECKING: - from pyrit.cli.frontend_core import ScenarioInfo - from pyrit.scenario.core import Scenario - -logger = logging.getLogger(__name__) - - -class ScenarioRegistry: - """ - Registry for discovering and managing available scenarios. - - This class discovers all Scenario subclasses from: - 1. Built-in scenarios in pyrit.scenario.scenarios module - 2. User-defined scenarios from initialization scripts (set via globals) - - Scenarios are identified by their simple name (e.g., "encoding", "foundry"). - """ - - def __init__(self) -> None: - """Initialize the scenario registry with lazy discovery.""" - self._scenarios: Dict[str, Type[Scenario]] = {} - self._scenario_metadata: Optional[List[ScenarioInfo]] = None - self._discovered = False - - def _ensure_discovered(self) -> None: - """Ensure scenarios have been discovered. Discovers on first call only.""" - if not self._discovered: - self._discover_builtin_scenarios() - self._discovered = True - - def _discover_builtin_scenarios(self) -> None: - """ - Discover all built-in scenarios from pyrit.scenario.scenarios module. 
- - This method dynamically imports all modules in the scenarios package - and registers any Scenario subclasses found. - """ - from pyrit.scenario.core import Scenario - - try: - import pyrit.scenario.scenarios as scenarios_package - - # Get the path to the scenarios package - package_file = scenarios_package.__file__ - if package_file is None: - # Try using __path__ instead - if hasattr(scenarios_package, "__path__"): - package_path = Path(scenarios_package.__path__[0]) - else: - logger.error("Cannot determine scenarios package location") - return - else: - package_path = Path(package_file).parent - - # Iterate through all Python files in the scenarios directory and subdirectories - def discover_modules(base_path: Path, base_module: str, relative_prefix: str = "") -> None: - """Recursively discover modules in the scenarios package and subdirectories.""" - for _, module_name, is_pkg in pkgutil.iter_modules([str(base_path)]): - if module_name.startswith("_"): - continue - - # Build the full module name correctly - if base_module: - full_module_name = f"{base_module}.{module_name}" - else: - full_module_name = f"pyrit.scenario.scenarios.{module_name}" - - # Build the relative path for the scenario name - if relative_prefix: - scenario_name = f"{relative_prefix}.{module_name}" - else: - scenario_name = module_name - - try: - # Import the module - module = importlib.import_module(full_module_name) - - # Only register scenarios if this is a file (not a package) - if not is_pkg: - # Find all Scenario subclasses in the module - for name, obj in inspect.getmembers(module, inspect.isclass): - # Check if it's a Scenario subclass (but not Scenario itself) - if issubclass(obj, Scenario) and obj is not Scenario: - self._scenarios[scenario_name] = obj - logger.debug(f"Registered built-in scenario: {scenario_name} ({obj.__name__})") - - # If it's a package, recursively discover its submodules - if is_pkg: - subpackage_path = base_path / module_name - 
discover_modules(subpackage_path, full_module_name, scenario_name) - - except Exception as e: - logger.warning(f"Failed to load scenario module {full_module_name}: {e}") - - # Start discovery from the scenarios package root - discover_modules(package_path, "") - - except Exception as e: - logger.error(f"Failed to discover built-in scenarios: {e}") - - def discover_user_scenarios(self) -> None: - """ - Discover user-defined scenarios from global variables. - - After initialization scripts are executed, they may define Scenario subclasses - and store them in globals. This method searches for such classes. - - User scenarios will override built-in scenarios with the same name. - """ - from pyrit.scenario.core import Scenario - - try: - # Check the global namespace for Scenario subclasses - import sys - - # Create a snapshot of modules to avoid dictionary changed size during iteration - modules_snapshot = list(sys.modules.items()) - - # Look through all loaded modules for scenario classes - for module_name, module in modules_snapshot: - if module is None or not hasattr(module, "__dict__"): - continue - - # Skip built-in and standard library modules - if module_name.startswith(("builtins", "_", "sys", "os", "importlib")): - continue - - # Look for Scenario subclasses in the module - for name, obj in inspect.getmembers(module, inspect.isclass): - if issubclass(obj, Scenario) and obj is not Scenario: - # Check if this is a user-defined class (not from pyrit.scenario.scenarios) - if not obj.__module__.startswith("pyrit.scenario.scenarios"): - # Convert class name to snake_case for scenario name - scenario_name = self._class_name_to_scenario_name(obj.__name__) - self._scenarios[scenario_name] = obj - logger.info(f"Registered user-defined scenario: {scenario_name} ({obj.__name__})") - - except Exception as e: - # Silently ignore errors during user scenario discovery - # User scenarios are optional and errors here are not critical - logger.debug(f"Failed to discover user 
scenarios: {e}") - - def _class_name_to_scenario_name(self, class_name: str) -> str: - """ - Convert a class name to a scenario identifier. - - Args: - class_name (str): Class name (e.g., "Encoding", "MyCustom") - - Returns: - str: Scenario identifier (e.g., "encoding", "my_custom") - """ - # Remove "Scenario" suffix if present (for backwards compatibility) - if class_name.endswith("Scenario"): - class_name = class_name[:-8] - - # Convert CamelCase to snake_case - import re - - name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", class_name) - name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() - - return name - - def get_scenario(self, name: str) -> Optional[Type[Scenario]]: - """ - Get a scenario class by name. - - Args: - name (str): Scenario identifier (e.g., "encoding", "foundry") - - Returns: - Optional[Type[Scenario]]: The scenario class, or None if not found. - """ - self._ensure_discovered() - return self._scenarios.get(name) - - def list_scenarios(self) -> List[ScenarioInfo]: - """ - List all available scenarios with their metadata. 
- - Returns: - List[ScenarioInfo]: List of scenario information dictionaries containing: - - name: Scenario identifier - - class_name: Class name - - description: Full class docstring - - default_strategy: The default strategy used when none specified - - all_strategies: All available strategy values - - aggregate_strategies: Aggregate strategy values - - default_datasets: Names of default datasets for this scenario - - max_dataset_size: Maximum seed groups per dataset (None if unlimited) - """ - # If we already have metadata, return it - if self._scenario_metadata is not None: - return self._scenario_metadata - - # Discover scenarios and build metadata - self._ensure_discovered() - scenarios_info: List[ScenarioInfo] = [] - - for name, scenario_class in sorted(self._scenarios.items()): - # Extract full docstring as description, clean up whitespace - doc = scenario_class.__doc__ or "" - if doc: - # Normalize whitespace: remove leading/trailing, collapse multiple spaces/newlines - description = " ".join(doc.split()) - else: - description = "No description available" - - # Get the strategy class for this scenario - strategy_class = scenario_class.get_strategy_class() - - dataset_config = scenario_class.default_dataset_config() - default_datasets = dataset_config.get_default_dataset_names() - max_dataset_size = dataset_config.max_dataset_size - - scenarios_info.append( - { - "name": name, - "class_name": scenario_class.__name__, - "description": description, - "default_strategy": scenario_class.get_default_strategy().value, - "all_strategies": [s.value for s in strategy_class.get_all_strategies()], - "aggregate_strategies": [s.value for s in strategy_class.get_aggregate_strategies()], - "default_datasets": default_datasets, - "max_dataset_size": max_dataset_size, - } - ) - - # Cache metadata for subsequent calls - self._scenario_metadata = scenarios_info - - return scenarios_info - - def get_scenario_names(self) -> List[str]: - """ - Get a list of all available 
scenario names. - - Returns: - List[str]: Sorted list of scenario identifiers. - """ - self._ensure_discovered() - return sorted(self._scenarios.keys()) diff --git a/pyrit/registry/__init__.py b/pyrit/registry/__init__.py new file mode 100644 index 000000000..ba0566020 --- /dev/null +++ b/pyrit/registry/__init__.py @@ -0,0 +1,44 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +"""Registry module for PyRIT class and instance registries.""" + +from pyrit.registry.base import RegistryItemMetadata, RegistryProtocol +from pyrit.registry.class_registries import ( + BaseClassRegistry, + ClassEntry, + InitializerMetadata, + InitializerRegistry, + ScenarioMetadata, + ScenarioRegistry, +) +from pyrit.registry.discovery import ( + discover_in_directory, + discover_in_package, + discover_subclasses_in_loaded_modules, +) +from pyrit.registry.instance_registries import ( + BaseInstanceRegistry, + ScorerMetadata, + ScorerRegistry, +) +from pyrit.registry.name_utils import class_name_to_registry_name, registry_name_to_class_name + +__all__ = [ + "BaseClassRegistry", + "BaseInstanceRegistry", + "ClassEntry", + "class_name_to_registry_name", + "discover_in_directory", + "discover_in_package", + "discover_subclasses_in_loaded_modules", + "InitializerMetadata", + "InitializerRegistry", + "RegistryItemMetadata", + "RegistryProtocol", + "registry_name_to_class_name", + "ScenarioMetadata", + "ScenarioRegistry", + "ScorerMetadata", + "ScorerRegistry", +] diff --git a/pyrit/registry/base.py b/pyrit/registry/base.py new file mode 100644 index 000000000..a9c284561 --- /dev/null +++ b/pyrit/registry/base.py @@ -0,0 +1,155 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Shared base types for PyRIT registries. + +This module contains types shared between class registries (which store Type[T]) +and instance registries (which store T instances). 
+""" + +from dataclasses import dataclass +from typing import Any, Dict, Iterator, List, Optional, Protocol, TypeVar, runtime_checkable + +# Type variable for metadata (invariant for Protocol compatibility) +MetadataT = TypeVar("MetadataT") + + +@runtime_checkable +class RegistryProtocol(Protocol[MetadataT]): + """ + Protocol defining the common interface for all registries. + + Both class registries (BaseClassRegistry) and instance registries + (BaseInstanceRegistry) implement this interface, enabling code that + works with either registry type. + + Type Parameters: + MetadataT: The metadata dataclass type (e.g., ScenarioMetadata). + """ + + @classmethod + def get_registry_singleton(cls) -> "RegistryProtocol[MetadataT]": + """Get the singleton instance of this registry.""" + ... + + @classmethod + def reset_instance(cls) -> None: + """Reset the singleton instance.""" + ... + + def get_names(self) -> List[str]: + """Get a sorted list of all registered names.""" + ... + + def list_metadata( + self, + *, + include_filters: Optional[Dict[str, Any]] = None, + exclude_filters: Optional[Dict[str, Any]] = None, + ) -> List[MetadataT]: + """ + List metadata for all registered items, optionally filtered. + + Args: + include_filters: Optional dict of filters that items must match. + Keys are metadata property names, values are the filter criteria. + All filters must match (AND logic). + exclude_filters: Optional dict of filters that items must NOT match. + Keys are metadata property names, values are the filter criteria. + Any matching filter excludes the item. + + Returns: + List of metadata describing each registered item. + """ + ... + + def __contains__(self, name: str) -> bool: + """Check if a name is registered.""" + ... + + def __len__(self) -> int: + """Get the count of registered items.""" + ... + + def __iter__(self) -> Iterator[str]: + """Iterate over registered names.""" + ... 
+ + +@dataclass(frozen=True) +class RegistryItemMetadata: + """ + Base dataclass for registry item metadata. + + This dataclass provides descriptive information about a registered item + (either a class or an instance). It is NOT the item itself - it's a + structured object describing the item. + + All registry-specific metadata types should extend this with additional fields. + """ + + name: str # The snake_case registry name (e.g., "self_ask_refusal") + class_name: str # The actual class name (e.g., "SelfAskRefusalScorer") + description: str # Description from docstring or manual override + + +def _matches_filters( + metadata: Any, + *, + include_filters: Optional[Dict[str, Any]] = None, + exclude_filters: Optional[Dict[str, Any]] = None, +) -> bool: + """ + Check if a metadata object matches all provided filters. + + Supports filtering on any property of the metadata dataclass: + - For simple types (str, int, bool): exact match comparison + - For sequence types (list, tuple): checks if filter value is contained in the sequence + + Items must match ALL include_filters (AND logic) and must NOT match ANY exclude_filters. + + Args: + metadata: The metadata dataclass instance to check. + include_filters: Optional dict of filters that must ALL match. + Keys are metadata property names, values are the filter criteria. + exclude_filters: Optional dict of filters, NONE of which may match. + Keys are metadata property names, values are the filter criteria. + + Returns: + True if all include_filters match and no exclude_filters match, False otherwise.
+ """ + # Check include filters - all must match + if include_filters: + for key, filter_value in include_filters.items(): + if not hasattr(metadata, key): + return False + + actual_value = getattr(metadata, key) + + # Handle sequence types - check if filter value is in the sequence + if isinstance(actual_value, (list, tuple)): + if filter_value not in actual_value: + return False + # Simple exact match for other types + elif actual_value != filter_value: + return False + + # Check exclude filters - none must match + if exclude_filters: + for key, filter_value in exclude_filters.items(): + if not hasattr(metadata, key): + # If the key doesn't exist, it can't match the exclude filter + continue + + actual_value = getattr(metadata, key) + + # Handle sequence types - exclude if filter value is in the sequence + if isinstance(actual_value, (list, tuple)): + if filter_value in actual_value: + return False + # Simple exact match for other types - exclude if it matches + elif actual_value == filter_value: + return False + + return True diff --git a/pyrit/registry/class_registries/__init__.py b/pyrit/registry/class_registries/__init__.py new file mode 100644 index 000000000..b29a6a6ee --- /dev/null +++ b/pyrit/registry/class_registries/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Class registries package. + +This package contains registries that store classes (Type[T]) which can be +instantiated on demand. Examples include ScenarioRegistry and InitializerRegistry. + +For registries that store pre-configured instances, see instance_registries/. 
+""" + +from pyrit.registry.class_registries.base_class_registry import ( + BaseClassRegistry, + ClassEntry, +) +from pyrit.registry.class_registries.initializer_registry import ( + InitializerMetadata, + InitializerRegistry, +) +from pyrit.registry.class_registries.scenario_registry import ( + ScenarioMetadata, + ScenarioRegistry, +) + +__all__ = [ + "BaseClassRegistry", + "ClassEntry", + "ScenarioRegistry", + "ScenarioMetadata", + "InitializerRegistry", + "InitializerMetadata", +] diff --git a/pyrit/registry/class_registries/base_class_registry.py b/pyrit/registry/class_registries/base_class_registry.py new file mode 100644 index 000000000..bd6ae3f27 --- /dev/null +++ b/pyrit/registry/class_registries/base_class_registry.py @@ -0,0 +1,401 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Base class registry for PyRIT. + +This module provides the abstract base class for registries that store classes (Type[T]). +These registries allow on-demand instantiation of registered classes. + +For registries that store pre-configured instances, see instance_registries/. + +Terminology: +- **Metadata**: A dataclass describing a registered class (e.g., ScenarioMetadata) +- **Class**: The actual Python class (Type[T]) that can be instantiated +- **Instance**: A created object of that class +- **ClassEntry**: Internal wrapper holding a class plus optional factory/defaults +""" + +from abc import ABC, abstractmethod +from typing import Callable, Dict, Generic, Iterator, List, Optional, Type, TypeVar + +from pyrit.registry.base import RegistryItemMetadata, RegistryProtocol +from pyrit.registry.name_utils import class_name_to_registry_name + +# Type variable for the registered class type +T = TypeVar("T") +# Type variable for the metadata dataclass +MetadataT = TypeVar("MetadataT") + + +class ClassEntry(Generic[T]): + """ + Internal wrapper for a registered class.
+ + This holds the class itself (Type[T]) along with optional factory + and default parameters for creating instances. + + Note: This is an internal implementation detail. Users interact with + registries via get_class(), create_instance(), and list_metadata(). + + Attributes: + registered_class: The actual Python class (Type[T]). + factory: Optional callable to create instances with custom logic. + default_kwargs: Default keyword arguments for instance creation. + description: Optional description override. + """ + + def __init__( + self, + *, + registered_class: Type[T], + factory: Optional[Callable[..., T]] = None, + default_kwargs: Optional[Dict[str, object]] = None, + description: Optional[str] = None, + ) -> None: + """ + Initialize a class entry. + + Args: + registered_class: The actual Python class (Type[T]). + factory: Optional callable that creates an instance. + default_kwargs: Default keyword arguments for instantiation. + description: Optional description override. + """ + self.registered_class = registered_class + self.factory = factory + self.default_kwargs = default_kwargs or {} + self.description = description + + def create_instance(self, **kwargs: object) -> T: + """ + Create an instance of the registered class. + + Args: + **kwargs: Additional keyword arguments. These override default_kwargs. + + Returns: + An instance of type T. + """ + merged_kwargs = {**self.default_kwargs, **kwargs} + + if self.factory is not None: + return self.factory(**merged_kwargs) + else: + return self.registered_class(**merged_kwargs) + + +class BaseClassRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]): + """ + Abstract base class for registries that store classes (Type[T]). 
+ + This class implements RegistryProtocol and provides the common infrastructure + for class registries including: + - Lazy discovery of classes + - Registration of classes or factory callables + - Metadata caching + - Consistent API: get_class(), get_names(), list_metadata(), create_instance() + - Singleton pattern support via get_registry_singleton() + + Subclasses must implement: + - _discover(): Populate the registry with discovered classes + - _build_metadata(): Build a metadata dataclass for a class + + Type Parameters: + T: The type of classes being registered (e.g., Scenario, PromptTarget). + MetadataT: The dataclass type for metadata (e.g., ScenarioMetadata). + """ + + # Class-level singleton instances, keyed by registry class + _instances: Dict[type, "BaseClassRegistry[object, object]"] = {} + + def __init__(self, *, lazy_discovery: bool = True) -> None: + """ + Initialize the registry. + + Args: + lazy_discovery: If True, discovery is deferred until first access. + If False, discovery runs immediately in constructor. + """ + # Maps registry names to ClassEntry wrappers + self._class_entries: Dict[str, ClassEntry[T]] = {} + self._metadata_cache: Optional[List[MetadataT]] = None + self._discovered = False + self._lazy_discovery = lazy_discovery + + if not lazy_discovery: + self._discover() + self._discovered = True + + @classmethod + def get_registry_singleton(cls) -> "BaseClassRegistry[T, MetadataT]": + """ + Get the singleton instance of this registry. + + Creates the instance on first call with default parameters. + + Returns: + The singleton instance of this registry class. + """ + if cls not in cls._instances: + cls._instances[cls] = cls() # type: ignore[assignment] + return cls._instances[cls] # type: ignore[return-value] + + @classmethod + def reset_instance(cls) -> None: + """ + Reset the singleton instance. + + Useful for testing or when re-discovery is needed.
def _build_base_metadata(self, name: str, entry: ClassEntry[T]) -> RegistryItemMetadata:
    """
    Build the common base metadata for a registered class.

    This helper extracts the fields common to all registries: name,
    class_name, and description. Subclasses can use it when building their
    registry-specific metadata.

    The description is the class docstring with every run of whitespace
    collapsed to a single space. When the class has no usable docstring
    (missing, empty, or whitespace-only), ``entry.description`` is used, and
    if that is also empty the text "No description available".

    Args:
        name: The registry name (snake_case identifier).
        entry: The ClassEntry containing the registered class.

    Returns:
        A RegistryItemMetadata with the common fields.
    """
    registered_class = entry.registered_class

    # Normalise the docstring first so a whitespace-only docstring is treated
    # the same as a missing one instead of yielding an empty description.
    description = " ".join((registered_class.__doc__ or "").split())
    if not description:
        # NOTE(review): entry.description is documented as an "override" but is
        # only consulted when the docstring is unusable — confirm the intended
        # precedence.
        description = entry.description or "No description available"

    return RegistryItemMetadata(
        name=name,
        class_name=registered_class.__name__,
        description=description,
    )
def list_metadata(
    self,
    *,
    include_filters: Optional[Dict[str, object]] = None,
    exclude_filters: Optional[Dict[str, object]] = None,
) -> List[MetadataT]:
    """
    List metadata for all registered classes, optionally filtered.

    Metadata is built once (in registry-name order) and cached on the
    registry; the cache is reused by later calls. The returned list is always
    a new list object, so callers may sort or mutate it without affecting the
    cache.

    Supports filtering on any metadata property:
    - Simple types (str, int, bool): exact match
    - List types: checks if filter value is in the list

    Args:
        include_filters: Optional dict of filters that items must match.
            Keys are metadata property names, values are the filter criteria.
            All filters must match (AND logic).
        exclude_filters: Optional dict of filters that items must NOT match.
            Keys are metadata property names, values are the filter criteria.
            Any matching filter excludes the item.

    Returns:
        List of metadata objects describing each registered class.
        Note: This returns descriptive info, not the classes themselves.
    """
    self._ensure_discovered()

    if self._metadata_cache is None:
        self._metadata_cache = [
            self._build_metadata(name, entry) for name, entry in sorted(self._class_entries.items())
        ]

    if not include_filters and not exclude_filters:
        # Return a shallow copy: handing out the cache itself would let a
        # caller's in-place edits change what every later call returns.
        return list(self._metadata_cache)

    # Imported lazily and only when filtering is actually requested.
    from pyrit.registry.base import _matches_filters

    return [
        m
        for m in self._metadata_cache
        if _matches_filters(m, include_filters=include_filters, exclude_filters=exclude_filters)
    ]
def __iter__(self) -> Iterator[str]:
    """
    Iterate over the registered names in sorted order.

    Discovery is triggered first if it has not run yet. The names are
    snapshotted into a sorted list before iteration starts.

    Returns:
        An iterator over sorted registered names.
    """
    self._ensure_discovered()
    ordered_names = list(self._class_entries)
    ordered_names.sort()
    return iter(ordered_names)
@dataclass(frozen=True)
class InitializerMetadata(RegistryItemMetadata):
    """
    Metadata describing a registered PyRITInitializer class.

    Immutable (frozen) record that extends RegistryItemMetadata with
    initializer-specific fields. InitializerRegistry._build_metadata fills it
    from a freshly constructed initializer instance and falls back to
    placeholder values when that instantiation fails.

    Use get_class() to get the actual class.
    """

    # The initializer instance's own `name`; the registry name is used instead
    # when the initializer could not be instantiated.
    initializer_name: str
    # Tuple copy of the instance's `required_env_vars`; empty when the
    # metadata could not be loaded.
    required_env_vars: tuple[str, ...]
    # The instance's `execution_order`; 100 is the fallback value used when
    # the metadata could not be loaded.
    execution_order: int
def _process_file(self, *, file_path: Path, base_class: type) -> None:
    """
    Load a single Python file and register the initializer classes it defines.

    The file is executed as a module named ``initializer.<stem>``. Every
    concrete (non-abstract) subclass of ``base_class`` that is *defined in
    that file* is passed to ``_register_initializer`` under the file's stem.
    Classes that the file merely imports from elsewhere are ignored, so a
    helper import cannot be registered under this file's name.

    Nothing is loaded when the stem is already registered (an error is
    logged). Any exception raised while loading or registering is logged as a
    warning and swallowed so one bad file does not abort discovery.

    Args:
        file_path: Path to the Python file to process.
        base_class: The PyRITInitializer base class.
    """
    import inspect

    short_name = file_path.stem

    # Check for name collision
    if short_name in self._initializer_paths:
        existing_path = self._initializer_paths[short_name]
        logger.error(
            f"Initializer name collision: '{short_name}' found in both "
            f"'{file_path}' and '{existing_path}'. "
            f"Initializer filenames must be unique across all directories."
        )
        return

    try:
        spec = importlib.util.spec_from_file_location(f"initializer.{short_name}", file_path)
        if not spec or not spec.loader:
            return

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        for attr_name in dir(module):
            attr = getattr(module, attr_name)
            if not inspect.isclass(attr) or attr is base_class or not issubclass(attr, base_class):
                continue
            # Only classes defined in this file count; names imported from
            # other modules carry a different __module__.
            if attr.__module__ != module.__name__:
                continue
            if inspect.isabstract(attr):
                continue
            self._register_initializer(
                short_name=short_name,
                file_path=file_path,
                initializer_class=attr,  # type: ignore[arg-type]
            )

    except Exception as e:
        logger.warning(f"Failed to load initializer module {short_name}: {e}")
def _build_metadata(self, name: str, entry: ClassEntry["PyRITInitializer"]) -> InitializerMetadata:
    """
    Describe a registered initializer class.

    The class is instantiated with no arguments and its ``description``,
    ``name``, ``required_env_vars`` and ``execution_order`` attributes are
    copied into the metadata. If anything in that process raises, a warning
    is logged and placeholder metadata is returned instead.

    Args:
        name: The registry name of the initializer.
        entry: The ClassEntry containing the initializer class.

    Returns:
        InitializerMetadata describing the initializer class.
    """
    target_class = entry.registered_class
    class_name = target_class.__name__

    try:
        probe = target_class()
        details = {
            "description": probe.description,
            "initializer_name": probe.name,
            "required_env_vars": tuple(probe.required_env_vars),
            "execution_order": probe.execution_order,
        }
        metadata = InitializerMetadata(name=name, class_name=class_name, **details)
    except Exception as e:
        logger.warning(f"Failed to get metadata for {name}: {e}")
        # Placeholder values used when the initializer cannot be inspected.
        metadata = InitializerMetadata(
            name=name,
            class_name=class_name,
            description="Error loading initializer metadata",
            initializer_name=name,
            required_env_vars=(),
            execution_order=100,
        )

    return metadata
@staticmethod
def resolve_script_paths(*, script_paths: list[str]) -> list[Path]:
    """
    Turn user-supplied script path strings into validated paths.

    Relative paths are interpreted against the current working directory.
    Paths are checked in the order given, and the first one that does not
    exist aborts the call.

    Args:
        script_paths: List of script path strings to resolve.

    Returns:
        List of Path objects, one per input, in the same order.

    Raises:
        FileNotFoundError: If any script path does not exist.
    """

    def _locate(raw: str) -> Path:
        # Anchor relative inputs at the current working directory.
        candidate = Path(raw)
        if not candidate.is_absolute():
            candidate = Path.cwd() / candidate

        if candidate.exists():
            return candidate

        raise FileNotFoundError(
            f"Initialization script not found: {candidate}\n Looked in: {candidate.absolute()}"
        )

    return [_locate(raw) for raw in script_paths]
@dataclass(frozen=True)
class ScenarioMetadata(RegistryItemMetadata):
    """
    Metadata describing a registered Scenario class.

    Immutable (frozen) record that extends RegistryItemMetadata with
    scenario-specific fields. ScenarioRegistry._build_metadata fills it from
    class-level accessors on the scenario class, its strategy class, and its
    default dataset configuration.

    Use get_class() to get the actual class.
    """

    # `.value` of the scenario class's get_default_strategy().
    default_strategy: str
    # `.value` of every strategy returned by the strategy class's get_all_strategies().
    all_strategies: tuple[str, ...]
    # `.value` of every strategy returned by the strategy class's get_aggregate_strategies().
    aggregate_strategies: tuple[str, ...]
    # Names returned by the default dataset config's get_default_dataset_names().
    default_datasets: tuple[str, ...]
    # `max_dataset_size` of the default dataset config; may be None.
    max_dataset_size: Optional[int]
def discover_user_scenarios(self) -> None:
    """
    Discover user-defined scenarios among the already-loaded modules.

    After initialization scripts are executed, they may have defined Scenario
    subclasses. This method scans the loaded modules for concrete Scenario
    subclasses whose defining module is outside ``pyrit.scenario.scenarios``
    and registers each one under the snake_case form of its class name (with
    a trailing "Scenario" removed).

    User scenarios override built-in scenarios with the same name. To make
    that hold under lazy discovery, built-in discovery is forced first;
    otherwise a later lazy discovery pass would overwrite the user entries.
    The metadata cache is invalidated afterwards so list_metadata() reflects
    the new entries.

    Failures during the scan are logged at debug level and otherwise ignored
    (best-effort discovery).
    """
    from pyrit.scenario.core import Scenario

    # Built-ins must already be in place so user entries written below win.
    self._ensure_discovered()

    # The helper reports a class once per module that exposes it; register and
    # log each class only once.
    registered: set = set()

    try:
        for _module_name, scenario_class in discover_subclasses_in_loaded_modules(
            base_class=Scenario  # type: ignore[type-abstract]
        ):
            # Skip built-ins (anything defined under pyrit.scenario.scenarios).
            if scenario_class.__module__.startswith("pyrit.scenario.scenarios"):
                continue
            if scenario_class in registered:
                continue
            registered.add(scenario_class)

            # Convert class name to snake_case for scenario name
            registry_name = class_name_to_registry_name(scenario_class.__name__, suffix="Scenario")
            self._class_entries[registry_name] = ClassEntry(registered_class=scenario_class)
            logger.info(f"Registered user-defined scenario: {registry_name} ({scenario_class.__name__})")

    except Exception as e:
        logger.debug(f"Failed to discover user scenarios: {e}")
    finally:
        # Entries may have changed even if the scan stopped early.
        self._metadata_cache = None
def _iter_concrete_subclasses(
    module: object,
    base_class: Type[T],
    *,
    local_only: bool,
) -> Iterator[Type[T]]:
    """
    Yield the concrete subclasses of base_class exposed by a module, in attribute-name order.

    Args:
        module: The module object to inspect.
        base_class: The base class to filter subclasses of (never yielded itself).
        local_only: If True, skip classes that were only imported into the module,
            i.e. whose ``__module__`` differs from the module's own ``__name__``.

    Yields:
        Each matching non-abstract class.
    """
    module_name = getattr(module, "__name__", None)
    for _, candidate in inspect.getmembers(module, inspect.isclass):
        if candidate is base_class or not issubclass(candidate, base_class):
            continue
        if inspect.isabstract(candidate):
            continue
        if local_only and candidate.__module__ != module_name:
            continue
        yield candidate


def _bare_module_name(prefix: str, name: str) -> str:
    """Default registry-name builder: ignore the sub-package prefix and use the module name alone."""
    return name


def discover_in_directory(
    *,
    directory: Path,
    base_class: Type[T],
    recursive: bool = True,
) -> Iterator[Tuple[str, Path, Type[T]]]:
    """
    Discover all subclasses of base_class in a directory by loading Python files.

    This function walks a directory, loads each ``.py`` file (except
    ``__init__.py``) dynamically, and yields the concrete subclasses of
    base_class defined in it. Entries are visited in sorted path order so the
    result does not depend on filesystem listing order. ``__pycache__``
    directories are never descended into.

    Args:
        directory: The directory to search for Python files.
        base_class: The base class to filter subclasses of.
        recursive: Whether to recursively search subdirectories. Defaults to True.

    Yields:
        Tuples of (filename_stem, file_path, class) for each discovered subclass.
    """
    if not directory.exists():
        logger.warning(f"Discovery directory not found: {directory}")
        return

    # iterdir() order is filesystem-dependent; sort for reproducible results.
    for item in sorted(directory.iterdir()):
        if item.is_file() and item.suffix == ".py" and item.stem != "__init__":
            yield from _process_file(file_path=item, base_class=base_class)
        elif recursive and item.is_dir() and item.name != "__pycache__":
            yield from discover_in_directory(directory=item, base_class=base_class, recursive=True)


def _process_file(*, file_path: Path, base_class: Type[T]) -> Iterator[Tuple[str, Path, Type[T]]]:
    """
    Process a Python file and yield the subclasses of the base class it defines.

    The file is executed as a module named ``discovered_module.<stem>``. Only
    classes defined in that file are reported; classes it imports from other
    modules are skipped. Any exception raised while loading is logged as a
    warning and ends processing of this file.

    Args:
        file_path: Path to the Python file to process.
        base_class: The base class to filter subclasses of.

    Yields:
        Tuples of (filename_stem, file_path, class) for each discovered subclass.
    """
    try:
        spec = importlib.util.spec_from_file_location(f"discovered_module.{file_path.stem}", file_path)
        if not spec or not spec.loader:
            return

        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)

        for discovered in _iter_concrete_subclasses(module, base_class, local_only=True):
            yield (file_path.stem, file_path, discovered)

    except Exception as e:
        logger.warning(f"Failed to load module from {file_path}: {e}")


def discover_in_package(
    *,
    package_path: Path,
    package_name: str,
    base_class: Type[T],
    recursive: bool = True,
    name_builder: Optional[Callable[[str, str], str]] = None,
) -> Iterator[Tuple[str, Type[T]]]:
    """
    Discover all subclasses using pkgutil.iter_modules on a package.

    This function uses Python's package infrastructure to discover modules,
    making it suitable for discovering classes in installed packages. Modules
    and sub-packages whose name starts with "_" are skipped. Only classes
    defined in the inspected module are reported; classes it imports from
    other modules are skipped.

    Args:
        package_path: The filesystem path to the package directory.
        package_name: The dotted module name of the package (e.g., "pyrit.scenario.scenarios").
        base_class: The base class to filter subclasses of.
        recursive: Whether to recursively search subpackages. Defaults to True.
        name_builder: Optional callable to build the registry name from (prefix, module_name),
            where prefix is the dotted sub-package path relative to package_name
            ("" for modules directly inside it). Defaults to returning just the module_name.

    Yields:
        Tuples of (registry_name, class) for each discovered subclass.
    """
    yield from _walk_package(
        package_path=package_path,
        package_name=package_name,
        base_class=base_class,
        recursive=recursive,
        name_builder=name_builder if name_builder is not None else _bare_module_name,
        prefix="",
    )


def _walk_package(
    *,
    package_path: Path,
    package_name: str,
    base_class: Type[T],
    recursive: bool,
    name_builder: Callable[[str, str], str],
    prefix: str,
) -> Iterator[Tuple[str, Type[T]]]:
    """Walk one package level for discover_in_package, tracking the sub-package prefix passed to name_builder."""
    for _, module_name, is_pkg in pkgutil.iter_modules([str(package_path)]):
        if module_name.startswith("_"):
            continue

        full_module_name = f"{package_name}.{module_name}"

        try:
            module = importlib.import_module(full_module_name)

            if not is_pkg:
                # For non-package modules, find and yield subclasses
                registry_name = name_builder(prefix, module_name)
                for discovered in _iter_concrete_subclasses(module, base_class, local_only=True):
                    yield (registry_name, discovered)
            elif recursive:
                # Recursively discover in subpackages
                yield from _walk_package(
                    package_path=package_path / module_name,
                    package_name=full_module_name,
                    base_class=base_class,
                    recursive=True,
                    name_builder=name_builder,
                    prefix=f"{prefix}.{module_name}" if prefix else module_name,
                )

        except Exception as e:
            logger.warning(f"Failed to load package module {full_module_name}: {e}")


def discover_subclasses_in_loaded_modules(
    *,
    base_class: Type[T],
    exclude_module_prefixes: Optional[Tuple[str, ...]] = None,
) -> Iterator[Tuple[str, Type[T]]]:
    """
    Discover subclasses of a base class from already-loaded modules.

    This is useful for discovering user-defined classes that were loaded
    via initialization scripts or dynamic imports. Every loaded module that
    exposes a matching class is reported, so the same class can be yielded
    more than once (once per module that defines or imports it). A module
    that raises while being inspected is skipped (logged at debug level)
    instead of aborting the whole scan.

    Args:
        base_class: The base class to filter subclasses of.
        exclude_module_prefixes: Module prefixes to exclude from search.
            Defaults to common system modules.

    Yields:
        Tuples of (module_name, class) for each discovered subclass.
    """
    import sys

    if exclude_module_prefixes is None:
        # NOTE(review): these are plain string prefixes, so "_" also excludes
        # "__main__" and "os"/"sys" exclude any module whose name merely starts
        # with those letters — confirm that is intended.
        exclude_module_prefixes = ("builtins", "_", "sys", "os", "importlib")
    excluded = tuple(exclude_module_prefixes)

    # Create a snapshot to avoid dictionary changed size during iteration
    modules_snapshot = list(sys.modules.items())

    for module_name, module in modules_snapshot:
        if module is None or not hasattr(module, "__dict__"):
            continue

        # Skip excluded modules
        if module_name.startswith(excluded):
            continue

        try:
            # Materialise before yielding so an introspection failure is
            # confined to this one module.
            matches = list(_iter_concrete_subclasses(module, base_class, local_only=False))
        except Exception as e:
            logger.debug(f"Skipping module {module_name} during subclass discovery: {e}")
            continue

        for discovered in matches:
            yield (module_name, discovered)
+# Licensed under the MIT license. + +""" +Base instance registry for PyRIT. + +This module provides the abstract base class for registries that store +pre-configured instances (not classes). Unlike class registries which +store Type[T] and create instances on demand, instance registries store +already-instantiated objects. + +Examples include: +- ScorerRegistry: stores Scorer instances configured with their chat_target +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, Generic, Iterator, List, Optional, TypeVar + +from pyrit.registry.base import RegistryItemMetadata, RegistryProtocol + +T = TypeVar("T") # The type of instances stored +MetadataT = TypeVar("MetadataT", bound=RegistryItemMetadata) + + +class BaseInstanceRegistry(ABC, RegistryProtocol[MetadataT], Generic[T, MetadataT]): + """ + Abstract base class for registries that store pre-configured instances. + + This class implements RegistryProtocol. Unlike BaseClassRegistry which stores + Type[T] and supports lazy discovery, instance registries store already-instantiated + objects that are registered explicitly (typically during initialization). + + Type Parameters: + T: The type of instances stored in the registry. + MetadataT: A TypedDict subclass for instance metadata. + + Subclasses must implement: + - _build_metadata(): Convert an instance to its metadata representation + """ + + # Class-level singleton instances, keyed by registry class + _instances: Dict[type, "BaseInstanceRegistry[Any, Any]"] = {} + + @classmethod + def get_registry_singleton(cls) -> "BaseInstanceRegistry[T, MetadataT]": + """ + Get the singleton instance of this registry. + + Creates the instance on first call with default parameters. + + Returns: + The singleton instance of this registry class. 
+ """ + if cls not in cls._instances: + cls._instances[cls] = cls() # type: ignore[assignment] + return cls._instances[cls] # type: ignore[return-value] + + @classmethod + def reset_instance(cls) -> None: + """ + Reset the singleton instance. + + Useful for testing or reinitializing the registry. + """ + if cls in cls._instances: + del cls._instances[cls] + + def __init__(self) -> None: + """Initialize the instance registry.""" + # Maps registry names to registered items + self._registry_items: Dict[str, T] = {} + self._metadata_cache: Optional[List[MetadataT]] = None + + def register( + self, + instance: T, + *, + name: str, + ) -> None: + """ + Register an instance. + + Args: + instance: The pre-configured instance to register. + name: The registry name for this instance. + """ + self._registry_items[name] = instance + self._metadata_cache = None + + def get(self, name: str) -> Optional[T]: + """ + Get a registered instance by name. + + Args: + name: The registry name of the instance. + + Returns: + The instance, or None if not found. + """ + return self._registry_items.get(name) + + def get_names(self) -> List[str]: + """ + Get a sorted list of all registered names. + + Returns: + Sorted list of registry names (keys). + """ + return sorted(self._registry_items.keys()) + + def list_metadata( + self, + *, + include_filters: Optional[Dict[str, object]] = None, + exclude_filters: Optional[Dict[str, object]] = None, + ) -> List[MetadataT]: + """ + List metadata for all registered instances, optionally filtered. + + Supports filtering on any metadata property: + - Simple types (str, int, bool): exact match + - List types: checks if filter value is in the list + + Args: + include_filters: Optional dict of filters that items must match. + Keys are metadata property names, values are the filter criteria. + All filters must match (AND logic). + exclude_filters: Optional dict of filters that items must NOT match. 
+ Keys are metadata property names, values are the filter criteria. + Any matching filter excludes the item. + + Returns: + List of metadata objects describing each registered instance. + """ + from pyrit.registry.base import _matches_filters + + if self._metadata_cache is None: + items = [] + for name in sorted(self._registry_items.keys()): + instance = self._registry_items[name] + items.append(self._build_metadata(name, instance)) + self._metadata_cache = items + + if not include_filters and not exclude_filters: + return self._metadata_cache + + return [ + m + for m in self._metadata_cache + if _matches_filters(m, include_filters=include_filters, exclude_filters=exclude_filters) + ] + + @abstractmethod + def _build_metadata(self, name: str, instance: T) -> MetadataT: + """ + Build metadata for an instance. + + Args: + name: The registry name of the instance. + instance: The instance. + + Returns: + A metadata object describing the instance. + """ + ... + + def __contains__(self, name: str) -> bool: + """ + Check if a name is registered. + + Returns: + True if the name is registered, False otherwise. + """ + return name in self._registry_items + + def __len__(self) -> int: + """ + Get the count of registered instances. + + Returns: + The number of registered instances. + """ + return len(self._registry_items) + + def __iter__(self) -> Iterator[str]: + """ + Iterate over registered names. + + Returns: + An iterator over sorted registered names. + """ + return iter(sorted(self._registry_items.keys())) diff --git a/pyrit/registry/instance_registries/scorer_registry.py b/pyrit/registry/instance_registries/scorer_registry.py new file mode 100644 index 000000000..7a9411a8b --- /dev/null +++ b/pyrit/registry/instance_registries/scorer_registry.py @@ -0,0 +1,138 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Scorer registry for discovering and managing PyRIT scorers.
+ +Scorers are registered explicitly via initializers as pre-configured instances. +""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass +from typing import TYPE_CHECKING, Optional + +from pyrit.registry.base import RegistryItemMetadata +from pyrit.registry.instance_registries.base_instance_registry import ( + BaseInstanceRegistry, +) +from pyrit.registry.name_utils import class_name_to_registry_name +from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer +from pyrit.score.true_false.true_false_scorer import TrueFalseScorer + +if TYPE_CHECKING: + from pyrit.score.scorer import Scorer + from pyrit.score.scorer_identifier import ScorerIdentifier + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class ScorerMetadata(RegistryItemMetadata): + """ + Metadata describing a registered scorer instance. + + Unlike ScenarioMetadata/InitializerMetadata which describe classes, + ScorerMetadata describes an already-instantiated scorer. + + Use get() to retrieve the actual scorer instance. + """ + + scorer_type: str + scorer_identifier: "ScorerIdentifier" + + +class ScorerRegistry(BaseInstanceRegistry["Scorer", ScorerMetadata]): + """ + Registry for managing available scorer instances. + + This registry stores pre-configured Scorer instances (not classes). + Scorers are registered explicitly via initializers after being instantiated + with their required parameters (e.g., chat_target). + + Scorers are identified by a snake_case name derived from the class name plus a + short scorer_identifier hash, or by a custom name provided during registration. + """ + + @classmethod + def get_registry_singleton(cls) -> "ScorerRegistry": + """ + Get the singleton instance of the ScorerRegistry. + + Returns: + The singleton ScorerRegistry instance.
+ """ + return super().get_registry_singleton() # type: ignore[return-value] + + def register_instance( + self, + scorer: "Scorer", + *, + name: Optional[str] = None, + ) -> None: + """ + Register a scorer instance. + + Note: Unlike ScenarioRegistry and InitializerRegistry which register classes, + ScorerRegistry registers pre-configured instances. + + Args: + scorer: The pre-configured scorer instance (not a class). + name: Optional custom registry name. If not provided, + derived from class name with the first 8 characters of the scorer_identifier hash appended + (e.g., SelfAskRefusalScorer -> self_ask_refusal_abc12345). + """ + if name is None: + base_name = class_name_to_registry_name(scorer.__class__.__name__, suffix="Scorer") + # Append a truncated (8-character) scorer_identifier hash for uniqueness + identifier_hash = scorer.scorer_identifier.compute_hash()[:8] + name = f"{base_name}_{identifier_hash}" + + self.register(scorer, name=name) + logger.debug(f"Registered scorer instance: {name} ({scorer.__class__.__name__})") + + def get_instance_by_name(self, name: str) -> Optional["Scorer"]: + """ + Get a registered scorer instance by name. + + Note: This returns an already-instantiated scorer, not a class. + + Args: + name: The registry name of the scorer. + + Returns: + The scorer instance, or None if not found. + """ + return self.get(name) + + def _build_metadata(self, name: str, instance: "Scorer") -> ScorerMetadata: + """ + Build metadata for a scorer instance. + + Args: + name: The registry name of the scorer. + instance: The scorer instance. + + Returns: + ScorerMetadata describing the scorer.
+ """ + # Get description from docstring + doc = instance.__class__.__doc__ or "" + description = " ".join(doc.split()) if doc else "No description available" + + # Determine scorer_type from class hierarchy + if isinstance(instance, TrueFalseScorer): + scorer_type = "true_false" + elif isinstance(instance, FloatScaleScorer): + scorer_type = "float_scale" + else: + scorer_type = "unknown" + + return ScorerMetadata( + name=name, + class_name=instance.__class__.__name__, + description=description, + scorer_type=scorer_type, + scorer_identifier=instance.scorer_identifier, + ) diff --git a/pyrit/registry/name_utils.py b/pyrit/registry/name_utils.py new file mode 100644 index 000000000..ae320a532 --- /dev/null +++ b/pyrit/registry/name_utils.py @@ -0,0 +1,59 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +""" +Name conversion utilities for PyRIT registries. + +This module provides functions for converting between different naming conventions +used in registries (e.g., CamelCase class names to snake_case registry names). +""" + +import re + + +def class_name_to_registry_name(class_name: str, *, suffix: str = "") -> str: + """ + Convert a CamelCase class name to a snake_case registry name. + + Args: + class_name: The class name to convert (e.g., "MyCustomScenario"). + suffix: Optional suffix to strip from the class name before conversion + (e.g., "Scenario" would convert "MyCustomScenario" to "my_custom"). + + Returns: + The snake_case registry name (e.g., "my_custom_scenario" or "my_custom"). 
+ """ + # Remove suffix if present + if suffix and class_name.endswith(suffix): + class_name = class_name[: -len(suffix)] + + # Convert CamelCase to snake_case + # First, handle transitions like "XMLParser" -> "XML_Parser" + name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", class_name) + # Then handle transitions like "getHTTPResponse" -> "get_HTTP_Response" + name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() + + return name + + +def registry_name_to_class_name(registry_name: str, *, suffix: str = "") -> str: + """ + Convert a snake_case registry name to a PascalCase class name. + + Args: + registry_name: The registry name to convert (e.g., "my_custom"). + suffix: Optional suffix to append to the class name + (e.g., "Scenario" would convert "my_custom" to "MyCustomScenario"). + + Returns: + The PascalCase class name (e.g., "MyCustomScenario"). + """ + # Split on underscores and capitalize each part + parts = registry_name.split("_") + pascal_case = "".join(part.capitalize() for part in parts) + + # Append suffix if provided + if suffix: + pascal_case += suffix + + return pascal_case diff --git a/pyrit/setup/initializers/scenarios/load_default_datasets.py b/pyrit/setup/initializers/scenarios/load_default_datasets.py index 0d23e2973..cca17dba6 100644 --- a/pyrit/setup/initializers/scenarios/load_default_datasets.py +++ b/pyrit/setup/initializers/scenarios/load_default_datasets.py @@ -12,9 +12,9 @@ import textwrap from typing import List -from pyrit.cli.scenario_registry import ScenarioRegistry from pyrit.datasets import SeedDatasetProvider from pyrit.memory import CentralMemory +from pyrit.registry import ScenarioRegistry from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer logger = logging.getLogger(__name__) @@ -54,16 +54,16 @@ def required_env_vars(self) -> List[str]: async def initialize_async(self) -> None: """Load default datasets from all registered scenarios.""" # Get ScenarioRegistry to discover all scenarios - registry = 
ScenarioRegistry() + registry = ScenarioRegistry.get_registry_singleton() # Collect all default datasets from all scenarios all_default_datasets: List[str] = [] # Get all scenario names from registry - scenario_names = registry.get_scenario_names() + scenario_names = registry.get_names() for scenario_name in scenario_names: - scenario_class = registry.get_scenario(scenario_name) + scenario_class = registry.get_class(scenario_name) if scenario_class: # Get default_dataset_config from the scenario class try: diff --git a/tests/end_to_end/test_scenarios.py b/tests/end_to_end/test_scenarios.py index 675fc7539..4e26e216f 100644 --- a/tests/end_to_end/test_scenarios.py +++ b/tests/end_to_end/test_scenarios.py @@ -11,7 +11,7 @@ import pytest from pyrit.cli.pyrit_scan import main as pyrit_scan_main -from pyrit.cli.scenario_registry import ScenarioRegistry +from pyrit.registry import ScenarioRegistry def get_all_scenarios(): @@ -21,8 +21,8 @@ def get_all_scenarios(): Returns: List[str]: Sorted list of scenario names. 
""" - registry = ScenarioRegistry() - return registry.get_scenario_names() + registry = ScenarioRegistry.get_registry_singleton() + return registry.get_names() @pytest.mark.timeout(7200) # 2 hour timeout per scenario diff --git a/tests/unit/cli/test_frontend_core.py b/tests/unit/cli/test_frontend_core.py index 4fd09cba7..305f249a9 100644 --- a/tests/unit/cli/test_frontend_core.py +++ b/tests/unit/cli/test_frontend_core.py @@ -11,6 +11,7 @@ import pytest from pyrit.cli import frontend_core +from pyrit.registry import InitializerMetadata, ScenarioMetadata class TestFrontendCore: @@ -53,8 +54,8 @@ def test_init_with_invalid_log_level(self): with pytest.raises(ValueError, match="Invalid log level"): frontend_core.FrontendCore(log_level="INVALID") - @patch("pyrit.cli.scenario_registry.ScenarioRegistry") - @patch("pyrit.cli.initializer_registry.InitializerRegistry") + @patch("pyrit.registry.ScenarioRegistry") + @patch("pyrit.registry.InitializerRegistry") @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) def test_initialize_loads_registries( self, @@ -70,11 +71,11 @@ def test_initialize_loads_registries( assert context._initialized is True mock_init_pyrit.assert_called_once() - mock_scenario_registry.assert_called_once() + mock_scenario_registry.get_instance.assert_called_once() mock_init_registry.assert_called_once() - @patch("pyrit.cli.scenario_registry.ScenarioRegistry") - @patch("pyrit.cli.initializer_registry.InitializerRegistry") + @patch("pyrit.registry.ScenarioRegistry") + @patch("pyrit.registry.InitializerRegistry") @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) async def test_scenario_registry_property_initializes( self, @@ -92,8 +93,8 @@ async def test_scenario_registry_property_initializes( assert context._initialized is True assert registry is not None - @patch("pyrit.cli.scenario_registry.ScenarioRegistry") - @patch("pyrit.cli.initializer_registry.InitializerRegistry") + @patch("pyrit.registry.ScenarioRegistry") + 
@patch("pyrit.registry.InitializerRegistry") @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) async def test_initializer_registry_property_initializes( self, @@ -249,7 +250,7 @@ def test_parse_memory_labels_non_string_key(self): class TestResolveInitializationScripts: """Tests for resolve_initialization_scripts function.""" - @patch("pyrit.cli.initializer_registry.InitializerRegistry.resolve_script_paths") + @patch("pyrit.registry.InitializerRegistry.resolve_script_paths") def test_resolve_initialization_scripts(self, mock_resolve: MagicMock): """Test resolve_initialization_scripts calls InitializerRegistry.""" mock_resolve.return_value = [Path("/test/script.py")] @@ -277,7 +278,7 @@ class TestListFunctions: async def test_list_scenarios(self): """Test list_scenarios_async returns scenarios from registry.""" mock_registry = MagicMock() - mock_registry.list_scenarios.return_value = [{"name": "test_scenario"}] + mock_registry.list_metadata.return_value = [{"name": "test_scenario"}] context = frontend_core.FrontendCore() context._scenario_registry = mock_registry @@ -286,12 +287,12 @@ async def test_list_scenarios(self): result = await frontend_core.list_scenarios_async(context=context) assert result == [{"name": "test_scenario"}] - mock_registry.list_scenarios.assert_called_once() + mock_registry.list_metadata.assert_called_once() async def test_list_initializers_without_discovery_path(self): """Test list_initializers_async without discovery path.""" mock_registry = MagicMock() - mock_registry.list_initializers.return_value = [{"name": "test_init"}] + mock_registry.list_metadata.return_value = [{"name": "test_init"}] context = frontend_core.FrontendCore() context._initializer_registry = mock_registry @@ -300,13 +301,13 @@ async def test_list_initializers_without_discovery_path(self): result = await frontend_core.list_initializers_async(context=context) assert result == [{"name": "test_init"}] - mock_registry.list_initializers.assert_called_once() 
+ mock_registry.list_metadata.assert_called_once() - @patch("pyrit.cli.initializer_registry.InitializerRegistry") + @patch("pyrit.registry.InitializerRegistry") async def test_list_initializers_with_discovery_path(self, mock_init_registry_class: MagicMock): """Test list_initializers_async with discovery path.""" mock_registry = MagicMock() - mock_registry.list_initializers.return_value = [{"name": "custom_init"}] + mock_registry.list_metadata.return_value = [{"name": "custom_init"}] mock_init_registry_class.return_value = mock_registry context = frontend_core.FrontendCore() @@ -325,12 +326,17 @@ async def test_print_scenarios_list_with_scenarios(self, capsys): """Test print_scenarios_list with scenarios.""" context = frontend_core.FrontendCore() mock_registry = MagicMock() - mock_registry.list_scenarios.return_value = [ - { - "name": "test_scenario", - "class_name": "TestScenario", - "description": "Test description", - } + mock_registry.list_metadata.return_value = [ + ScenarioMetadata( + name="test_scenario", + class_name="TestScenario", + description="Test description", + default_strategy="default", + all_strategies=(), + aggregate_strategies=(), + default_datasets=(), + max_dataset_size=None, + ) ] context._scenario_registry = mock_registry context._initialized = True @@ -346,7 +352,7 @@ async def test_print_scenarios_list_empty(self, capsys): """Test print_scenarios_list with no scenarios.""" context = frontend_core.FrontendCore() mock_registry = MagicMock() - mock_registry.list_scenarios.return_value = [] + mock_registry.list_metadata.return_value = [] context._scenario_registry = mock_registry context._initialized = True @@ -360,13 +366,15 @@ async def test_print_initializers_list_with_initializers(self, capsys): """Test print_initializers_list_async with initializers.""" context = frontend_core.FrontendCore() mock_registry = MagicMock() - mock_registry.list_initializers.return_value = [ - { - "name": "test_init", - "class_name": "TestInit", - 
"initializer_name": "test", - "execution_order": 100, - } + mock_registry.list_metadata.return_value = [ + InitializerMetadata( + name="test_init", + class_name="TestInit", + description="Test initializer", + initializer_name="test", + execution_order=100, + required_env_vars=(), + ) ] context._initializer_registry = mock_registry context._initialized = True @@ -382,7 +390,7 @@ async def test_print_initializers_list_empty(self, capsys): """Test print_initializers_list_async with no initializers.""" context = frontend_core.FrontendCore() mock_registry = MagicMock() - mock_registry.list_initializers.return_value = [] + mock_registry.list_metadata.return_value = [] context._initializer_registry = mock_registry context._initialized = True @@ -394,103 +402,114 @@ async def test_print_initializers_list_empty(self, capsys): class TestFormatFunctions: - """Tests for format_scenario_info and format_initializer_info.""" - - def test_format_scenario_info_basic(self, capsys): - """Test format_scenario_info with basic info.""" - scenario_info = { - "name": "test_scenario", - "class_name": "TestScenario", - } + """Tests for format_scenario_metadata and format_initializer_metadata.""" + + def test_format_scenario_metadata_basic(self, capsys): + """Test format_scenario_metadata with basic metadata.""" + + scenario_metadata = ScenarioMetadata( + name="test_scenario", + class_name="TestScenario", + description="", + default_strategy="", + all_strategies=(), + aggregate_strategies=(), + default_datasets=(), + max_dataset_size=None, + ) - frontend_core.format_scenario_info(scenario_info=scenario_info) + frontend_core.format_scenario_metadata(scenario_metadata=scenario_metadata) captured = capsys.readouterr() assert "test_scenario" in captured.out assert "TestScenario" in captured.out - def test_format_scenario_info_with_description(self, capsys): - """Test format_scenario_info with description.""" - scenario_info = { - "name": "test_scenario", - "class_name": "TestScenario", - 
"description": "This is a test scenario", - } + def test_format_scenario_metadata_with_description(self, capsys): + """Test format_scenario_metadata with description.""" + + scenario_metadata = ScenarioMetadata( + name="test_scenario", + class_name="TestScenario", + description="This is a test scenario", + default_strategy="", + all_strategies=(), + aggregate_strategies=(), + default_datasets=(), + max_dataset_size=None, + ) - frontend_core.format_scenario_info(scenario_info=scenario_info) + frontend_core.format_scenario_metadata(scenario_metadata=scenario_metadata) captured = capsys.readouterr() assert "This is a test scenario" in captured.out - def test_format_scenario_info_with_strategies(self, capsys): - """Test format_scenario_info with strategies.""" - scenario_info = { - "name": "test_scenario", - "class_name": "TestScenario", - "all_strategies": ["strategy1", "strategy2"], - "default_strategy": "strategy1", - } + def test_format_scenario_metadata_with_strategies(self, capsys): + """Test format_scenario_metadata with strategies.""" + scenario_metadata = ScenarioMetadata( + name="test_scenario", + class_name="TestScenario", + description="", + default_strategy="strategy1", + all_strategies=("strategy1", "strategy2"), + aggregate_strategies=(), + default_datasets=(), + max_dataset_size=None, + ) - frontend_core.format_scenario_info(scenario_info=scenario_info) + frontend_core.format_scenario_metadata(scenario_metadata=scenario_metadata) captured = capsys.readouterr() assert "strategy1" in captured.out assert "strategy2" in captured.out assert "Default Strategy" in captured.out - def test_format_initializer_info_basic(self, capsys) -> None: - """Test format_initializer_info with basic info.""" - from pyrit.cli.initializer_registry import InitializerInfo - - initializer_info: InitializerInfo = { - "name": "test_init", - "class_name": "TestInit", - "initializer_name": "test", - "description": "", - "required_env_vars": [], - "execution_order": 100, - } + def 
test_format_initializer_metadata_basic(self, capsys) -> None: + """Test format_initializer_metadata with basic metadata.""" + initializer_metadata = InitializerMetadata( + name="test_init", + class_name="TestInit", + description="", + initializer_name="test", + required_env_vars=(), + execution_order=100, + ) - frontend_core.format_initializer_info(initializer_info=initializer_info) + frontend_core.format_initializer_metadata(initializer_metadata=initializer_metadata) captured = capsys.readouterr() assert "test_init" in captured.out assert "TestInit" in captured.out assert "100" in captured.out - def test_format_initializer_info_with_env_vars(self, capsys) -> None: - """Test format_initializer_info with environment variables.""" - from pyrit.cli.initializer_registry import InitializerInfo - - initializer_info: InitializerInfo = { - "name": "test_init", - "class_name": "TestInit", - "initializer_name": "test", - "description": "", - "required_env_vars": ["VAR1", "VAR2"], - "execution_order": 100, - } + def test_format_initializer_metadata_with_env_vars(self, capsys) -> None: + """Test format_initializer_metadata with environment variables.""" + initializer_metadata = InitializerMetadata( + name="test_init", + class_name="TestInit", + description="", + initializer_name="test", + required_env_vars=("VAR1", "VAR2"), + execution_order=100, + ) - frontend_core.format_initializer_info(initializer_info=initializer_info) + frontend_core.format_initializer_metadata(initializer_metadata=initializer_metadata) captured = capsys.readouterr() assert "VAR1" in captured.out assert "VAR2" in captured.out - def test_format_initializer_info_with_description(self, capsys) -> None: - """Test format_initializer_info with description.""" - from pyrit.cli.initializer_registry import InitializerInfo - - initializer_info: InitializerInfo = { - "name": "test_init", - "class_name": "TestInit", - "initializer_name": "test", - "description": "Test description", - "required_env_vars": [], - 
"execution_order": 100, - } + def test_format_initializer_metadata_with_description(self, capsys) -> None: + """Test format_initializer_metadata with description.""" + initializer_metadata = InitializerMetadata( + name="test_init", + class_name="TestInit", + description="Test description", + initializer_name="test", + required_env_vars=(), + execution_order=100, + ) - frontend_core.format_initializer_info(initializer_info=initializer_info) + frontend_core.format_initializer_metadata(initializer_metadata=initializer_metadata) captured = capsys.readouterr() assert "Test description" in captured.out @@ -620,7 +639,7 @@ async def test_run_scenario_async_basic( mock_scenario_instance.initialize_async = AsyncMock() mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_registry.get_scenario.return_value = mock_scenario_class + mock_scenario_registry.get_class.return_value = mock_scenario_class mock_printer_class.return_value = mock_printer context._scenario_registry = mock_scenario_registry @@ -645,8 +664,8 @@ async def test_run_scenario_async_not_found(self, mock_init_pyrit: AsyncMock): """Test running non-existent scenario raises ValueError.""" context = frontend_core.FrontendCore() mock_scenario_registry = MagicMock() - mock_scenario_registry.get_scenario.return_value = None - mock_scenario_registry.get_scenario_names.return_value = ["other_scenario"] + mock_scenario_registry.get_class.return_value = None + mock_scenario_registry.get_names.return_value = ["other_scenario"] context._scenario_registry = mock_scenario_registry context._initializer_registry = MagicMock() @@ -684,7 +703,7 @@ class MockStrategy(Enum): mock_scenario_instance.initialize_async = AsyncMock() mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_registry.get_scenario.return_value = mock_scenario_class + 
mock_scenario_registry.get_class.return_value = mock_scenario_class mock_printer_class.return_value = mock_printer context._scenario_registry = mock_scenario_registry @@ -722,12 +741,12 @@ async def test_run_scenario_async_with_initializers( mock_printer.print_summary_async = AsyncMock() mock_initializer_class = MagicMock() - mock_initializer_registry.get_initializer_class.return_value = mock_initializer_class + mock_initializer_registry.get_class.return_value = mock_initializer_class mock_scenario_instance.initialize_async = AsyncMock() mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_registry.get_scenario.return_value = mock_scenario_class + mock_scenario_registry.get_class.return_value = mock_scenario_class mock_printer_class.return_value = mock_printer context._scenario_registry = mock_scenario_registry @@ -741,7 +760,7 @@ async def test_run_scenario_async_with_initializers( ) # Verify initializer was retrieved - mock_initializer_registry.get_initializer_class.assert_called_once_with(name="test_init") + mock_initializer_registry.get_class.assert_called_once_with("test_init") @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) @patch("pyrit.scenario.printer.console_printer.ConsoleScenarioResultPrinter") @@ -762,7 +781,7 @@ async def test_run_scenario_async_with_max_concurrency( mock_scenario_instance.initialize_async = AsyncMock() mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_registry.get_scenario.return_value = mock_scenario_class + mock_scenario_registry.get_class.return_value = mock_scenario_class mock_printer_class.return_value = mock_printer context._scenario_registry = mock_scenario_registry @@ -800,7 +819,7 @@ async def test_run_scenario_async_without_print_summary( mock_scenario_instance.initialize_async = AsyncMock() 
mock_scenario_instance.run_async = AsyncMock(return_value=mock_result) mock_scenario_class.return_value = mock_scenario_instance - mock_scenario_registry.get_scenario.return_value = mock_scenario_class + mock_scenario_registry.get_class.return_value = mock_scenario_class mock_printer_class.return_value = mock_printer context._scenario_registry = mock_scenario_registry diff --git a/tests/unit/cli/test_initializer_registry.py b/tests/unit/cli/test_initializer_registry.py deleted file mode 100644 index c87818f56..000000000 --- a/tests/unit/cli/test_initializer_registry.py +++ /dev/null @@ -1,620 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Unit tests for the InitializerRegistry module. -""" - -from pathlib import Path -from unittest.mock import MagicMock, patch - -import pytest - -from pyrit.cli.initializer_registry import InitializerInfo, InitializerRegistry -from pyrit.setup.initializers.pyrit_initializer import PyRITInitializer - - -class MockInitializer(PyRITInitializer): - """Mock initializer for testing.""" - - def __init__( - self, - *, - mock_name: str = "test_initializer", - mock_description: str = "Test description", - mock_required_env_vars: list[str] | None = None, - mock_execution_order: int = 100, - ) -> None: - """Initialize mock initializer.""" - super().__init__() - self._mock_name = mock_name - self._mock_description = mock_description - self._mock_required_env_vars = mock_required_env_vars or [] - self._mock_execution_order = mock_execution_order - - @property - def name(self) -> str: - """Get the name.""" - return self._mock_name - - @property - def description(self) -> str: - """Get the description.""" - return self._mock_description - - @property - def required_env_vars(self) -> list[str]: - """Get required environment variables.""" - return self._mock_required_env_vars - - @property - def execution_order(self) -> int: - """Get execution order.""" - return self._mock_execution_order - - async def 
initialize_async(self) -> None: - """Mock initialization.""" - pass - - -class TestInitializerRegistry: - """Tests for InitializerRegistry class.""" - - @patch("pyrit.cli.initializer_registry.Path") - def test_init_with_nonexistent_directory(self, mock_path_class): - """Test initialization when scenarios directory doesn't exist.""" - # Create a mock path that represents a non-existent directory - mock_path = MagicMock() - mock_path.exists.return_value = False - mock_path.is_file.return_value = False - mock_path.is_dir.return_value = False - - # Make Path() constructor and division operations return the mock path - mock_path_class.return_value = mock_path - mock_path.__truediv__ = MagicMock(return_value=mock_path) - - registry = InitializerRegistry() - - assert registry._initializers == {} - - @patch("pyrit.cli.initializer_registry.Path") - def test_init_discovers_initializers(self, mock_path_class): - """Test initialization discovers initializers.""" - # Mock the directory structure - mock_path = MagicMock() - mock_path.exists.return_value = True - mock_path.is_file.return_value = False - mock_path.is_dir.return_value = True - - # Create a mock file - mock_file = MagicMock() - mock_file.is_file.return_value = True - mock_file.is_dir.return_value = False - mock_file.suffix = ".py" - mock_file.stem = "test_init" - - mock_path.iterdir.return_value = [mock_file] - mock_path_class.return_value = mock_path - - with patch.object(InitializerRegistry, "_process_file"): - registry = InitializerRegistry() - # Just verify it attempted to process files - assert registry is not None - - def test_get_initializer_existing(self) -> None: - """Test getting an existing initializer.""" - registry = InitializerRegistry() - test_info: InitializerInfo = { - "name": "test", - "class_name": "TestInitializer", - "initializer_name": "test_init", - "description": "Test", - "required_env_vars": [], - "execution_order": 100, - } - registry._initializers["test"] = test_info - - result = 
registry.get_initializer("test") - assert result == test_info - - def test_get_initializer_nonexistent(self): - """Test getting a non-existent initializer returns None.""" - registry = InitializerRegistry() - result = registry.get_initializer("nonexistent") - assert result is None - - def test_get_initializer_names_empty(self): - """Test get_initializer_names with no initializers.""" - registry = InitializerRegistry() - registry._initializers = {} - names = registry.get_initializer_names() - assert names == [] - - def test_get_initializer_names_sorted(self) -> None: - """Test get_initializer_names returns sorted list.""" - registry = InitializerRegistry() - test_info: InitializerInfo = { - "name": "test", - "class_name": "Test", - "initializer_name": "test", - "description": "Test", - "required_env_vars": [], - "execution_order": 100, - } - registry._initializers = { - "zebra": test_info, - "apple": test_info, - "middle": test_info, - } - - names = registry.get_initializer_names() - assert names == ["apple", "middle", "zebra"] - - def test_list_initializers_sorted_by_execution_order(self): - """Test list_initializers returns sorted list by execution order.""" - registry = InitializerRegistry() - registry._initializers = { - "first": { - "name": "first", - "class_name": "First", - "initializer_name": "first", - "description": "First", - "required_env_vars": [], - "execution_order": 10, - }, - "third": { - "name": "third", - "class_name": "Third", - "initializer_name": "third", - "description": "Third", - "required_env_vars": [], - "execution_order": 30, - }, - "second": { - "name": "second", - "class_name": "Second", - "initializer_name": "second", - "description": "Second", - "required_env_vars": [], - "execution_order": 20, - }, - } - - initializers = registry.list_initializers() - - assert len(initializers) == 3 - assert initializers[0]["name"] == "first" - assert initializers[1]["name"] == "second" - assert initializers[2]["name"] == "third" - - def 
test_list_initializers_sorted_by_name_when_same_order(self): - """Test list_initializers sorts by name when execution order is same.""" - registry = InitializerRegistry() - registry._initializers = { - "zebra": { - "name": "zebra", - "class_name": "Zebra", - "initializer_name": "zebra", - "description": "Zebra", - "required_env_vars": [], - "execution_order": 10, - }, - "apple": { - "name": "apple", - "class_name": "Apple", - "initializer_name": "apple", - "description": "Apple", - "required_env_vars": [], - "execution_order": 10, - }, - } - - initializers = registry.list_initializers() - - assert len(initializers) == 2 - assert initializers[0]["name"] == "apple" - assert initializers[1]["name"] == "zebra" - - def test_list_initializers_empty(self): - """Test list_initializers with no initializers.""" - registry = InitializerRegistry() - registry._initializers = {} - - initializers = registry.list_initializers() - assert initializers == [] - - def test_discover_in_directory_skips_init_files(self): - """Test that __init__.py files are skipped.""" - registry = InitializerRegistry() - - mock_init_file = MagicMock() - mock_init_file.is_file.return_value = True - mock_init_file.is_dir.return_value = False - mock_init_file.suffix = ".py" - mock_init_file.stem = "__init__" - - mock_directory = MagicMock() - mock_directory.iterdir.return_value = [mock_init_file] - - with patch.object(registry, "_process_file") as mock_process: - registry._discover_in_directory(directory=mock_directory) - # Should not process __init__.py - mock_process.assert_not_called() - - def test_discover_in_directory_skips_pycache(self): - """Test that __pycache__ directories are skipped.""" - registry = InitializerRegistry() - - mock_pycache = MagicMock() - mock_pycache.is_file.return_value = False - mock_pycache.is_dir.return_value = True - mock_pycache.name = "__pycache__" - - mock_directory = MagicMock() - mock_directory.iterdir.return_value = [mock_pycache] - - with patch.object(registry, 
"_discover_in_directory") as mock_discover: - registry._discover_in_directory(directory=mock_directory) - # Should only be called once (the initial call), not recursively for __pycache__ - assert mock_discover.call_count == 1 - - def test_discover_in_directory_processes_valid_files(self): - """Test that valid Python files are processed.""" - registry = InitializerRegistry() - - mock_file = MagicMock() - mock_file.is_file.return_value = True - mock_file.is_dir.return_value = False - mock_file.suffix = ".py" - mock_file.stem = "valid_init" - - mock_directory = MagicMock() - mock_directory.iterdir.return_value = [mock_file] - - with patch.object(registry, "_process_file") as mock_process: - registry._discover_in_directory(directory=mock_directory) - mock_process.assert_called_once_with(file_path=mock_file) - - def test_discover_in_directory_recurses_subdirectories(self): - """Test that subdirectories are recursively processed.""" - registry = InitializerRegistry() - - mock_subdir = MagicMock() - mock_subdir.is_file.return_value = False - mock_subdir.is_dir.return_value = True - mock_subdir.name = "subdir" - - mock_directory = MagicMock() - mock_directory.iterdir.return_value = [mock_subdir] - - with patch.object(registry, "_discover_in_directory", wraps=registry._discover_in_directory) as mock_discover: - # Call once to start - registry._discover_in_directory(directory=mock_directory) - # Should be called twice: initial + recursive call for subdir - assert mock_discover.call_count == 2 - # Second call should have the subdirectory - second_call_kwargs = mock_discover.call_args_list[1][1] - assert second_call_kwargs["directory"] == mock_subdir - - @patch("pyrit.cli.initializer_registry.importlib.util.spec_from_file_location") - @patch("pyrit.cli.initializer_registry.PYRIT_PATH", "/fake/pyrit") - def test_process_file_handles_import_errors(self, mock_spec): - """Test that import errors are handled gracefully.""" - registry = InitializerRegistry() - 
registry._initializers.clear() - registry._initializer_paths.clear() - - mock_spec.side_effect = Exception("Import error") - - # Create a proper Path object that can handle relative_to - mock_file = Path("/fake/pyrit/setup/initializers/scenarios/broken.py") - - # Should not raise exception - registry._process_file(file_path=mock_file) - - assert "broken" not in registry._initializers - - @patch("pyrit.cli.initializer_registry.importlib.util.spec_from_file_location") - @patch("pyrit.cli.initializer_registry.PYRIT_PATH", "/fake/pyrit") - def test_process_file_handles_no_spec(self, mock_spec): - """Test handling when spec_from_file_location returns None.""" - registry = InitializerRegistry() - registry._initializers.clear() - registry._initializer_paths.clear() - - mock_spec.return_value = None - - # Create a proper Path object - mock_file = Path("/fake/pyrit/setup/initializers/scenarios/no_spec.py") - - registry._process_file(file_path=mock_file) - - assert "no_spec" not in registry._initializers - - @patch("pyrit.cli.initializer_registry.importlib.util.spec_from_file_location") - @patch("pyrit.cli.initializer_registry.PYRIT_PATH", "/fake/pyrit") - def test_process_file_discovers_initializer_class(self, mock_spec): - """Test that PyRITInitializer subclasses are discovered.""" - registry = InitializerRegistry() - registry._initializers.clear() - registry._initializer_paths.clear() - - # Create a mock module with our test initializer - mock_module = MagicMock() - mock_module.MockInitializer = MockInitializer - - mock_spec_obj = MagicMock() - mock_spec_obj.loader = MagicMock() - mock_spec.return_value = mock_spec_obj - - # Create a proper Path object - mock_file = Path("/fake/pyrit/setup/initializers/scenarios/test_init.py") - - with patch("pyrit.cli.initializer_registry.importlib.util.module_from_spec", return_value=mock_module): - with patch("pyrit.cli.initializer_registry.dir", return_value=["MockInitializer"]): - with patch("pyrit.cli.initializer_registry.getattr", 
return_value=MockInitializer): - with patch("pyrit.cli.initializer_registry.inspect.isclass", return_value=True): - registry._process_file(file_path=mock_file) - - # Verify the initializer was registered - assert "test_init" in registry._initializers - - def test_try_register_initializer_success(self): - """Test successful registration of an initializer.""" - registry = InitializerRegistry() - mock_path = Path("/fake/path/test.py") - - # Clear any auto-discovered initializers to ensure clean test - registry._initializers.clear() - registry._initializer_paths.clear() - - registry._try_register_initializer(initializer_class=MockInitializer, short_name="test", file_path=mock_path) - - assert "test" in registry._initializers - info = registry._initializers["test"] - assert info["name"] == "test" - assert info["class_name"] == "MockInitializer" - assert info["initializer_name"] == "test_initializer" - assert info["description"] == "Test description" - assert info["required_env_vars"] == [] - assert info["execution_order"] == 100 - assert registry._initializer_paths["test"] == mock_path - - def test_try_register_initializer_with_env_vars(self): - """Test registration with required environment variables.""" - registry = InitializerRegistry() - mock_path = Path("/fake/path/env_test.py") - - class EnvVarInitializer(PyRITInitializer): - @property - def name(self) -> str: - return "env_test" - - @property - def description(self) -> str: - return "Test with env vars" - - @property - def required_env_vars(self) -> list[str]: - return ["API_KEY", "ENDPOINT"] - - @property - def execution_order(self) -> int: - return 50 - - async def initialize_async(self) -> None: - pass - - # Clear any auto-discovered initializers to ensure clean test - registry._initializers.clear() - registry._initializer_paths.clear() - - registry._try_register_initializer( - initializer_class=EnvVarInitializer, short_name="env_test", file_path=mock_path - ) - - assert "env_test" in registry._initializers - 
info = registry._initializers["env_test"] - assert info["required_env_vars"] == ["API_KEY", "ENDPOINT"] - assert info["execution_order"] == 50 - - def test_try_register_initializer_handles_instantiation_error(self): - """Test that instantiation errors are handled gracefully.""" - registry = InitializerRegistry() - mock_path = Path("/fake/path/broken.py") - - class BrokenInitializer(PyRITInitializer): - def __init__(self) -> None: - raise ValueError("Cannot instantiate") - - @property - def name(self) -> str: - return "broken" - - def initialize(self) -> None: - pass - - # Should not raise exception - registry._try_register_initializer( - initializer_class=BrokenInitializer, short_name="broken", file_path=mock_path - ) - - # Should not be registered - assert "broken" not in registry._initializers - - def test_initializer_info_typed_dict_structure(self) -> None: - """Test that InitializerInfo TypedDict has correct structure.""" - info: InitializerInfo = { - "name": "test", - "class_name": "TestClass", - "initializer_name": "test_init", - "description": "Description", - "required_env_vars": ["VAR1"], - "execution_order": 10, - } - - assert info["name"] == "test" - assert info["class_name"] == "TestClass" - assert info["initializer_name"] == "test_init" - assert info["description"] == "Description" - assert info["required_env_vars"] == ["VAR1"] - assert info["execution_order"] == 10 - - -class TestResolveInitializerPaths: - """Tests for resolve_initializer_paths method.""" - - def test_resolve_single_initializer(self): - """Test resolving a single valid initializer name.""" - registry = InitializerRegistry() - registry._initializers.clear() - registry._initializer_paths.clear() - - test_path = Path("/fake/simple.py") - registry._initializers["simple"] = { - "name": "simple", - "class_name": "SimpleInitializer", - "initializer_name": "simple_init", - "description": "Test", - "required_env_vars": [], - "execution_order": 100, - } - registry._initializer_paths["simple"] = 
test_path - - result = registry.resolve_initializer_paths(initializer_names=["simple"]) - - assert len(result) == 1 - assert result[0] == test_path - - def test_resolve_multiple_initializers(self): - """Test resolving multiple initializer names.""" - registry = InitializerRegistry() - registry._initializers.clear() - registry._initializer_paths.clear() - - path1 = Path("/fake/simple.py") - path2 = Path("/fake/objective_target.py") - - registry._initializers["simple"] = { - "name": "simple", - "class_name": "SimpleInitializer", - "initializer_name": "simple_init", - "description": "Test", - "required_env_vars": [], - "execution_order": 100, - } - registry._initializer_paths["simple"] = path1 - - registry._initializers["objective_target"] = { - "name": "objective_target", - "class_name": "ObjectiveTargetInitializer", - "initializer_name": "obj_target_init", - "description": "Test", - "required_env_vars": [], - "execution_order": 200, - } - registry._initializer_paths["objective_target"] = path2 - - result = registry.resolve_initializer_paths(initializer_names=["simple", "objective_target"]) - - assert len(result) == 2 - assert path1 in result - assert path2 in result - - def test_resolve_invalid_initializer_name(self): - """Test resolving an invalid initializer name raises ValueError.""" - registry = InitializerRegistry() - registry._initializers.clear() - registry._initializer_paths.clear() - - with pytest.raises(ValueError, match="Built-in initializer 'invalid' not found"): - registry.resolve_initializer_paths(initializer_names=["invalid"]) - - def test_resolve_initializer_without_file_path(self): - """Test resolving initializer without file path raises ValueError.""" - registry = InitializerRegistry() - registry._initializers.clear() - registry._initializer_paths.clear() - - registry._initializers["simple"] = { - "name": "simple", - "class_name": "SimpleInitializer", - "initializer_name": "simple_init", - "description": "Test", - "required_env_vars": [], - 
"execution_order": 100, - } - # Intentionally not adding to _initializer_paths - - with pytest.raises(ValueError, match="Could not locate file for initializer 'simple'"): - registry.resolve_initializer_paths(initializer_names=["simple"]) - - -class TestResolveScriptPaths: - """Tests for resolve_script_paths static method.""" - - @patch("pyrit.cli.initializer_registry.Path") - def test_resolve_absolute_path_exists(self, mock_path_class): - """Test resolving absolute path that exists.""" - mock_path = MagicMock() - mock_path.is_absolute.return_value = True - mock_path.exists.return_value = True - mock_path_class.return_value = mock_path - - result = InitializerRegistry.resolve_script_paths(script_paths=["/absolute/script.py"]) - - assert len(result) == 1 - - @patch("pyrit.cli.initializer_registry.Path") - def test_resolve_relative_path_exists(self, mock_path_class): - """Test resolving relative path that exists.""" - mock_path = MagicMock() - mock_path.is_absolute.return_value = False - mock_path.exists.return_value = True - - mock_cwd = MagicMock() - mock_resolved_path = MagicMock() - mock_resolved_path.exists.return_value = True - - mock_cwd.__truediv__ = lambda self, other: mock_resolved_path - mock_path_class.return_value = mock_path - mock_path_class.cwd.return_value = mock_cwd - - result = InitializerRegistry.resolve_script_paths(script_paths=["script.py"]) - - assert len(result) == 1 - - @patch("pyrit.cli.initializer_registry.Path") - def test_resolve_path_not_exists(self, mock_path_class): - """Test resolving path that doesn't exist raises FileNotFoundError.""" - mock_path = MagicMock() - mock_path.is_absolute.return_value = True - mock_path.exists.return_value = False - mock_path.absolute.return_value = "/fake/missing.py" - mock_path_class.return_value = mock_path - - with pytest.raises(FileNotFoundError, match="Initialization script not found"): - InitializerRegistry.resolve_script_paths(script_paths=["/fake/missing.py"]) - - 
@patch("pyrit.cli.initializer_registry.Path") - def test_resolve_multiple_paths(self, mock_path_class): - """Test resolving multiple script paths.""" - mock_path1 = MagicMock() - mock_path1.is_absolute.return_value = True - mock_path1.exists.return_value = True - - mock_path2 = MagicMock() - mock_path2.is_absolute.return_value = True - mock_path2.exists.return_value = True - - # Make Path() return different mocks for different inputs - def path_side_effect(path_str): - if "script1" in str(path_str): - return mock_path1 - return mock_path2 - - mock_path_class.side_effect = path_side_effect - - result = InitializerRegistry.resolve_script_paths(script_paths=["/fake/script1.py", "/fake/script2.py"]) - - assert len(result) == 2 diff --git a/tests/unit/cli/test_pyrit_scan.py b/tests/unit/cli/test_pyrit_scan.py index 00893d543..58517c523 100644 --- a/tests/unit/cli/test_pyrit_scan.py +++ b/tests/unit/cli/test_pyrit_scan.py @@ -390,7 +390,7 @@ class TestMainIntegration: """Integration-style tests for main function.""" @patch("pyrit.cli.frontend_core.print_scenarios_list_async", new_callable=AsyncMock) - @patch("pyrit.cli.scenario_registry.ScenarioRegistry") + @patch("pyrit.registry.ScenarioRegistry") @patch("pyrit.setup.initialize_pyrit_async", new_callable=AsyncMock) def test_main_list_scenarios_integration( self, diff --git a/tests/unit/cli/test_scenario_registry.py b/tests/unit/cli/test_scenario_registry.py deleted file mode 100644 index ca8156367..000000000 --- a/tests/unit/cli/test_scenario_registry.py +++ /dev/null @@ -1,333 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -""" -Unit tests for the ScenarioRegistry module. 
-""" - -from typing import Type -from unittest.mock import MagicMock, patch - -import pytest - -from pyrit.cli.scenario_registry import ScenarioRegistry -from pyrit.scenario.core.scenario import Scenario -from pyrit.scenario.core.scenario_strategy import ScenarioStrategy - - -class MockStrategy(ScenarioStrategy): - """Mock strategy for testing.""" - - ALL = ("all", {"all"}) - TestStrategy = ("test_strategy", {"test"}) - - @classmethod - def get_aggregate_tags(cls) -> set[str]: - return {"all"} - - -class MockScenario(Scenario): - """Mock scenario for testing.""" - - async def _get_atomic_attacks_async(self): - return [] - - @classmethod - def get_strategy_class(cls) -> Type[ScenarioStrategy]: - return MockStrategy - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - return MockStrategy.ALL - - @classmethod - def default_dataset_config(cls): - from pyrit.scenario.core.dataset_configuration import DatasetConfiguration - - return DatasetConfiguration(dataset_names=[]) - - -class TestScenarioRegistry: - """Tests for ScenarioRegistry class.""" - - def test_discover_builtin_scenarios(self): - """Test discovery of built-in scenarios.""" - # Create a registry which will automatically discover built-in scenarios - registry = ScenarioRegistry() - - # Verify that some scenarios were discovered - scenario_names = registry.get_scenario_names() - # We should find at least the built-in scenarios - # Note: This is an integration test that depends on actual scenario files existing - assert len(scenario_names) >= 0 # May be 0 if run in isolation - - # If scenarios were found, verify they're Scenario subclasses - for name in scenario_names: - scenario_class = registry.get_scenario(name) - assert scenario_class is not None - assert issubclass(scenario_class, Scenario) - - def test_discover_builtin_scenarios_correct_module_paths(self): - """Test that builtin scenario discovery uses correct module paths without duplication. 
- - This is a regression test for a bug where module paths were incorrectly constructed - as 'pyrit.scenario.scenarios.pyrit.scenarios.scenarios.xxx' instead of - 'pyrit.scenario.scenarios.xxx'. - """ - registry = ScenarioRegistry() - registry._discover_builtin_scenarios() - - # Verify that scenarios were discovered - assert len(registry._scenarios) > 0, "No scenarios were discovered" - - # Check that some expected scenarios are present - # These are real scenarios that exist in the codebase - discovered_names = list(registry._scenarios.keys()) - - # Verify naming convention: should not have duplicated path components - for scenario_name in discovered_names: - # Should not see 'pyrit' in the scenario name (it's just relative path) - assert "pyrit" not in scenario_name.lower(), f"Scenario name has 'pyrit' in it: {scenario_name}" - - # Should not see 'scenario.scenarios' duplication - assert "scenarios.scenarios" not in scenario_name, ( - f"Scenario name has duplicated 'scenarios': {scenario_name}" - ) - - # Verify that nested scenarios use dot notation (e.g., "airt.content_harms") - nested_scenarios = [name for name in discovered_names if "." 
in name] - - # Should have nested scenarios (all scenarios are now in subdirectories) - assert len(nested_scenarios) > 0, "No nested scenarios found" - - def test_get_scenario_existing(self): - """Test getting an existing scenario.""" - registry = ScenarioRegistry() - # Manually add a scenario for testing - registry._scenarios["test_scenario"] = MockScenario - - result = registry.get_scenario("test_scenario") - assert result == MockScenario - - def test_get_scenario_nonexistent(self): - """Test getting a non-existent scenario returns None.""" - registry = ScenarioRegistry() - result = registry.get_scenario("nonexistent_scenario") - assert result is None - - def test_get_scenario_names_empty(self): - """Test get_scenario_names with no scenarios.""" - registry = ScenarioRegistry() - registry._scenarios = {} - registry._discovered = True # Prevent auto-discovery - names = registry.get_scenario_names() - assert names == [] - - def test_get_scenario_names_sorted(self): - """Test get_scenario_names returns sorted list.""" - registry = ScenarioRegistry() - registry._scenarios = { - "zebra_scenario": MockScenario, - "apple_scenario": MockScenario, - "middle_scenario": MockScenario, - } - registry._discovered = True # Prevent auto-discovery - - names = registry.get_scenario_names() - assert names == ["apple_scenario", "middle_scenario", "zebra_scenario"] - - def test_list_scenarios_with_descriptions(self): - """Test list_scenarios returns scenario information.""" - - class DocumentedScenario(Scenario): - """This is a test scenario for unit testing.""" - - async def _get_atomic_attacks_async(self): - return [] - - @classmethod - def get_strategy_class(cls) -> Type[ScenarioStrategy]: - return MockStrategy - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - return MockStrategy.ALL - - @classmethod - def default_dataset_config(cls): - from pyrit.scenario.core.dataset_configuration import ( - DatasetConfiguration, - ) - - return 
DatasetConfiguration(dataset_names=["test_dataset_1", "test_dataset_2"]) - - registry = ScenarioRegistry() - registry._scenarios = { - "test_scenario": DocumentedScenario, - } - registry._discovered = True # Prevent auto-discovery - - scenarios = registry.list_scenarios() - - assert len(scenarios) == 1 - assert scenarios[0]["name"] == "test_scenario" - assert scenarios[0]["class_name"] == "DocumentedScenario" - assert "test scenario" in scenarios[0]["description"].lower() - assert scenarios[0]["default_datasets"] == ["test_dataset_1", "test_dataset_2"] - - def test_list_scenarios_no_description(self): - """Test list_scenarios with scenario lacking docstring.""" - - class UndocumentedScenario(Scenario): - async def _get_atomic_attacks_async(self): - return [] - - @classmethod - def get_strategy_class(cls) -> Type[ScenarioStrategy]: - return MockStrategy - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - return MockStrategy.ALL - - @classmethod - def default_dataset_config(cls): - from pyrit.scenario.core.dataset_configuration import ( - DatasetConfiguration, - ) - - return DatasetConfiguration(dataset_names=[]) - - # Remove docstring - UndocumentedScenario.__doc__ = None - - registry = ScenarioRegistry() - registry._scenarios = {"undocumented": UndocumentedScenario} - registry._discovered = True # Prevent auto-discovery - - scenarios = registry.list_scenarios() - - assert len(scenarios) == 1 - assert scenarios[0]["description"] == "No description available" - - def test_list_scenarios_with_required_datasets_error(self): - """Test list_scenarios raises error when default_dataset_config fails.""" - - class BrokenScenario(Scenario): - """Scenario that raises error on default_dataset_config.""" - - async def _get_atomic_attacks_async(self): - return [] - - @classmethod - def get_strategy_class(cls) -> Type[ScenarioStrategy]: - return MockStrategy - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - return MockStrategy.ALL - - 
@classmethod - def default_dataset_config(cls): - raise ValueError("Cannot get datasets") - - registry = ScenarioRegistry() - registry._scenarios = {"broken": BrokenScenario} - registry._discovered = True - - # Should raise the exception instead of catching it - with pytest.raises(ValueError, match="Cannot get datasets"): - registry.list_scenarios() - - def test_class_name_to_scenario_name_with_scenario_suffix(self): - """Test converting class name with 'Scenario' suffix.""" - registry = ScenarioRegistry() - result = registry._class_name_to_scenario_name("EncodingScenario") - assert result == "encoding" - - def test_class_name_to_scenario_name_without_scenario_suffix(self): - """Test converting class name without 'Scenario' suffix.""" - registry = ScenarioRegistry() - result = registry._class_name_to_scenario_name("CustomTest") - assert result == "custom_test" - - def test_class_name_to_scenario_name_camelcase(self): - """Test converting CamelCase to snake_case.""" - registry = ScenarioRegistry() - result = registry._class_name_to_scenario_name("MyCustomScenario") - assert result == "my_custom" - - def test_class_name_to_scenario_name_with_numbers(self): - """Test converting class name with numbers.""" - registry = ScenarioRegistry() - result = registry._class_name_to_scenario_name("Test123Scenario") - assert result == "test123" - - def test_discover_user_scenarios_no_modules(self): - """Test discover_user_scenarios with no user modules.""" - registry = ScenarioRegistry() - registry._scenarios = {} - - # Should not raise an exception - registry.discover_user_scenarios() - - @patch("pyrit.cli.scenario_registry.inspect.getmembers") - def test_discover_user_scenarios_with_user_class(self, mock_getmembers): - """Test discover_user_scenarios finds user-defined scenarios.""" - - class UserScenario(Scenario): - """User-defined scenario.""" - - async def _get_atomic_attacks_async(self): - return [] - - @classmethod - def get_strategy_class(cls) -> Type[ScenarioStrategy]: - 
return MockStrategy - - @classmethod - def get_default_strategy(cls) -> ScenarioStrategy: - return MockStrategy.ALL - - @classmethod - def default_dataset_config(cls): - from pyrit.scenario.core.dataset_configuration import ( - DatasetConfiguration, - ) - - return DatasetConfiguration(dataset_names=[]) - - UserScenario.__module__ = "user_module" - - mock_module = MagicMock() - mock_module.__dict__ = {} - - # Need to patch sys.modules which is imported inside the function - import sys - - original_modules = sys.modules.copy() - try: - sys.modules["user_module"] = mock_module - mock_getmembers.return_value = [("UserScenario", UserScenario)] - - registry = ScenarioRegistry() - registry._scenarios = {} - registry.discover_user_scenarios() - - # Verify user scenario was registered - assert "user" in registry._scenarios - finally: - # Restore original sys.modules - sys.modules.clear() - sys.modules.update(original_modules) - - def test_discover_user_scenarios_skips_builtins(self): - """Test discover_user_scenarios skips built-in modules.""" - registry = ScenarioRegistry() - initial_count = len(registry._scenarios) - - registry.discover_user_scenarios() - - # Should not add any scenarios from built-in modules - # (may have same count or more if user modules exist, but not from builtins) - assert len(registry._scenarios) >= initial_count diff --git a/tests/unit/registry/__init__.py b/tests/unit/registry/__init__.py new file mode 100644 index 000000000..9a0454564 --- /dev/null +++ b/tests/unit/registry/__init__.py @@ -0,0 +1,2 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. diff --git a/tests/unit/registry/test_base.py b/tests/unit/registry/test_base.py new file mode 100644 index 000000000..e96d695ec --- /dev/null +++ b/tests/unit/registry/test_base.py @@ -0,0 +1,182 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from dataclasses import dataclass + +import pytest + +from pyrit.registry.base import RegistryItemMetadata, _matches_filters + + +@dataclass(frozen=True) +class MetadataWithTags(RegistryItemMetadata): + """Test metadata with a tags field for list filtering tests.""" + + tags: tuple[str, ...] + + +class TestMatchesFilters: + """Tests for the _matches_filters function.""" + + def test_matches_filters_exact_match_string(self): + """Test that exact string matches work.""" + metadata = RegistryItemMetadata( + name="test_item", + class_name="TestClass", + description="A test item", + ) + assert _matches_filters(metadata, include_filters={"name": "test_item"}) is True + assert _matches_filters(metadata, include_filters={"class_name": "TestClass"}) is True + + def test_matches_filters_no_match_string(self): + """Test that non-matching strings return False.""" + metadata = RegistryItemMetadata( + name="test_item", + class_name="TestClass", + description="A test item", + ) + assert _matches_filters(metadata, include_filters={"name": "other_item"}) is False + assert _matches_filters(metadata, include_filters={"class_name": "OtherClass"}) is False + + def test_matches_filters_multiple_filters_all_match(self): + """Test that all filters must match.""" + metadata = RegistryItemMetadata( + name="test_item", + class_name="TestClass", + description="A test item", + ) + assert _matches_filters(metadata, include_filters={"name": "test_item", "class_name": "TestClass"}) is True + + def test_matches_filters_multiple_filters_partial_match(self): + """Test that partial matches return False when not all filters match.""" + metadata = RegistryItemMetadata( + name="test_item", + class_name="TestClass", + description="A test item", + ) + assert _matches_filters(metadata, include_filters={"name": "test_item", "class_name": "OtherClass"}) is False + + def test_matches_filters_key_not_in_metadata(self): + """Test that filtering on a non-existent key returns False.""" + metadata = 
RegistryItemMetadata( + name="test_item", + class_name="TestClass", + description="A test item", + ) + assert _matches_filters(metadata, include_filters={"nonexistent_key": "value"}) is False + + def test_matches_filters_empty_filters(self): + """Test that empty filters return True.""" + metadata = RegistryItemMetadata( + name="test_item", + class_name="TestClass", + description="A test item", + ) + assert _matches_filters(metadata) is True + + def test_matches_filters_list_value_contains_filter(self): + """Test filtering when metadata value is a list and filter value is in the list.""" + metadata = MetadataWithTags( + name="test_item", + class_name="TestClass", + description="A test item", + tags=("tag1", "tag2", "tag3"), + ) + assert _matches_filters(metadata, include_filters={"tags": "tag1"}) is True + assert _matches_filters(metadata, include_filters={"tags": "tag2"}) is True + + def test_matches_filters_list_value_not_contains_filter(self): + """Test filtering when metadata value is a list and filter value is not in the list.""" + metadata = MetadataWithTags( + name="test_item", + class_name="TestClass", + description="A test item", + tags=("tag1", "tag2", "tag3"), + ) + assert _matches_filters(metadata, include_filters={"tags": "missing_tag"}) is False + + def test_matches_filters_exclude_exact_match(self): + """Test that exclude filters work for exact matches.""" + metadata = RegistryItemMetadata( + name="test_item", + class_name="TestClass", + description="A test item", + ) + assert _matches_filters(metadata, exclude_filters={"name": "test_item"}) is False + assert _matches_filters(metadata, exclude_filters={"name": "other_item"}) is True + + def test_matches_filters_exclude_list_value(self): + """Test exclude filters work for list values.""" + metadata = MetadataWithTags( + name="test_item", + class_name="TestClass", + description="A test item", + tags=("tag1", "tag2", "tag3"), + ) + assert _matches_filters(metadata, exclude_filters={"tags": "tag1"}) is 
False + assert _matches_filters(metadata, exclude_filters={"tags": "missing_tag"}) is True + + def test_matches_filters_exclude_nonexistent_key(self): + """Test that exclude filters for non-existent keys don't exclude the item.""" + metadata = RegistryItemMetadata( + name="test_item", + class_name="TestClass", + description="A test item", + ) + # Non-existent key in exclude filter should not exclude the item + assert _matches_filters(metadata, exclude_filters={"nonexistent_key": "value"}) is True + + def test_matches_filters_combined_include_and_exclude(self): + """Test combined include and exclude filters.""" + metadata = RegistryItemMetadata( + name="test_item", + class_name="TestClass", + description="A test item", + ) + # Include matches, exclude doesn't -> should pass + assert ( + _matches_filters( + metadata, include_filters={"name": "test_item"}, exclude_filters={"class_name": "OtherClass"} + ) + is True + ) + # Include matches, exclude also matches -> should fail + assert ( + _matches_filters( + metadata, include_filters={"name": "test_item"}, exclude_filters={"class_name": "TestClass"} + ) + is False + ) + # Include doesn't match, exclude doesn't match -> should fail (include takes precedence) + assert ( + _matches_filters( + metadata, include_filters={"name": "other_item"}, exclude_filters={"class_name": "OtherClass"} + ) + is False + ) + + +class TestRegistryItemMetadata: + """Tests for the RegistryItemMetadata dataclass.""" + + def test_registry_item_metadata_creation(self): + """Test creating a RegistryItemMetadata instance.""" + metadata = RegistryItemMetadata( + name="test_scorer", + class_name="TestScorer", + description="A test scorer for testing", + ) + assert metadata.name == "test_scorer" + assert metadata.class_name == "TestScorer" + assert metadata.description == "A test scorer for testing" + + def test_registry_item_metadata_is_frozen(self): + """Test that RegistryItemMetadata is immutable.""" + metadata = RegistryItemMetadata( + 
name="test_item", + class_name="TestClass", + description="Description here", + ) + + with pytest.raises(AttributeError): + metadata.name = "new_name" # type: ignore[misc] diff --git a/tests/unit/registry/test_base_instance_registry.py b/tests/unit/registry/test_base_instance_registry.py new file mode 100644 index 000000000..886de0028 --- /dev/null +++ b/tests/unit/registry/test_base_instance_registry.py @@ -0,0 +1,274 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from dataclasses import dataclass + +from pyrit.registry.base import RegistryItemMetadata +from pyrit.registry.instance_registries.base_instance_registry import BaseInstanceRegistry + + +@dataclass(frozen=True) +class SampleItemMetadata(RegistryItemMetadata): + """Sample metadata with an extra field.""" + + category: str + + +class ConcreteTestRegistry(BaseInstanceRegistry[str, SampleItemMetadata]): + """Concrete implementation of BaseInstanceRegistry for testing.""" + + def _build_metadata(self, name: str, instance: str) -> SampleItemMetadata: + """Build test metadata from a string instance.""" + return SampleItemMetadata( + name=name, + class_name="str", + description=f"Description for {instance}", + category="test" if "test" in instance.lower() else "other", + ) + + +class TestBaseInstanceRegistrySingleton: + """Tests for the singleton pattern in BaseInstanceRegistry.""" + + def setup_method(self): + """Reset the singleton before each test.""" + ConcreteTestRegistry.reset_instance() + + def teardown_method(self): + """Reset the singleton after each test.""" + ConcreteTestRegistry.reset_instance() + + def test_get_registry_singleton_returns_same_instance(self): + """Test that get_registry_singleton returns the same singleton each time.""" + instance1 = ConcreteTestRegistry.get_registry_singleton() + instance2 = ConcreteTestRegistry.get_registry_singleton() + + assert instance1 is instance2 + + def test_reset_instance_clears_singleton(self): + """Test that 
reset_instance clears the singleton.""" + instance1 = ConcreteTestRegistry.get_registry_singleton() + ConcreteTestRegistry.reset_instance() + instance2 = ConcreteTestRegistry.get_registry_singleton() + + assert instance1 is not instance2 + + def test_reset_instance_when_not_exists_does_not_raise(self): + """Test that reset_instance works even when no instance exists.""" + # Should not raise any exception + ConcreteTestRegistry.reset_instance() + ConcreteTestRegistry.reset_instance() + + +class TestBaseInstanceRegistryRegistration: + """Tests for registration functionality in BaseInstanceRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + ConcreteTestRegistry.reset_instance() + self.registry = ConcreteTestRegistry.get_registry_singleton() + + def teardown_method(self): + """Reset the singleton after each test.""" + ConcreteTestRegistry.reset_instance() + + def test_register_adds_instance(self): + """Test that register adds an instance to the registry.""" + self.registry.register("test_value", name="test_name") + + assert "test_name" in self.registry + assert self.registry.get("test_name") == "test_value" + + def test_register_multiple_instances(self): + """Test registering multiple instances.""" + self.registry.register("value1", name="name1") + self.registry.register("value2", name="name2") + self.registry.register("value3", name="name3") + + assert len(self.registry) == 3 + assert self.registry.get("name1") == "value1" + assert self.registry.get("name2") == "value2" + assert self.registry.get("name3") == "value3" + + def test_register_overwrites_existing(self): + """Test that registering with the same name overwrites the existing instance.""" + self.registry.register("original", name="name") + self.registry.register("updated", name="name") + + assert len(self.registry) == 1 + assert self.registry.get("name") == "updated" + + def test_register_invalidates_metadata_cache(self): + """Test that registering a new instance 
invalidates the metadata cache.""" + self.registry.register("value1", name="name1") + # Build cache by calling list_metadata + metadata1 = self.registry.list_metadata() + assert len(metadata1) == 1 + + # Register new instance - should invalidate cache + self.registry.register("value2", name="name2") + metadata2 = self.registry.list_metadata() + + assert len(metadata2) == 2 + + +class TestBaseInstanceRegistryGet: + """Tests for get functionality in BaseInstanceRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + ConcreteTestRegistry.reset_instance() + self.registry = ConcreteTestRegistry.get_registry_singleton() + self.registry.register("test_value", name="test_name") + + def teardown_method(self): + """Reset the singleton after each test.""" + ConcreteTestRegistry.reset_instance() + + def test_get_existing_instance(self): + """Test getting an existing instance by name.""" + result = self.registry.get("test_name") + assert result == "test_value" + + def test_get_nonexistent_returns_none(self): + """Test that getting a non-existent instance returns None.""" + result = self.registry.get("nonexistent") + assert result is None + + +class TestBaseInstanceRegistryGetNames: + """Tests for get_names functionality in BaseInstanceRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + ConcreteTestRegistry.reset_instance() + self.registry = ConcreteTestRegistry.get_registry_singleton() + + def teardown_method(self): + """Reset the singleton after each test.""" + ConcreteTestRegistry.reset_instance() + + def test_get_names_empty_registry(self): + """Test get_names on an empty registry.""" + names = self.registry.get_names() + assert names == [] + + def test_get_names_returns_sorted_list(self): + """Test that get_names returns a sorted list of names.""" + self.registry.register("value3", name="zeta") + self.registry.register("value1", name="alpha") + self.registry.register("value2", name="beta") 
+ + names = self.registry.get_names() + assert names == ["alpha", "beta", "zeta"] + + +class TestBaseInstanceRegistryListMetadata: + """Tests for list_metadata functionality in BaseInstanceRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + ConcreteTestRegistry.reset_instance() + self.registry = ConcreteTestRegistry.get_registry_singleton() + self.registry.register("test_item_1", name="item1") + self.registry.register("other_item_2", name="item2") + self.registry.register("test_item_3", name="item3") + + def teardown_method(self): + """Reset the singleton after each test.""" + ConcreteTestRegistry.reset_instance() + + def test_list_metadata_returns_all_items(self): + """Test that list_metadata returns metadata for all items.""" + metadata = self.registry.list_metadata() + assert len(metadata) == 3 + + def test_list_metadata_sorted_by_name(self): + """Test that metadata is sorted by name.""" + metadata = self.registry.list_metadata() + names = [m.name for m in metadata] + assert names == ["item1", "item2", "item3"] + + def test_list_metadata_with_filter(self): + """Test filtering metadata by a field.""" + metadata = self.registry.list_metadata(include_filters={"category": "test"}) + assert len(metadata) == 2 + assert all(m.category == "test" for m in metadata) + + def test_list_metadata_filter_no_match(self): + """Test filtering with no matches returns empty list.""" + metadata = self.registry.list_metadata(include_filters={"category": "nonexistent"}) + assert metadata == [] + + def test_list_metadata_with_exclude_filter(self): + """Test excluding metadata by a field.""" + metadata = self.registry.list_metadata(exclude_filters={"category": "test"}) + assert len(metadata) == 1 + assert all(m.category == "other" for m in metadata) + + def test_list_metadata_combined_include_and_exclude(self): + """Test combined include and exclude filters.""" + # Add another test item to have more variety + 
self.registry.register("another_test_item", name="item4") + + # Get items with category "test" but exclude by name + metadata = self.registry.list_metadata(include_filters={"category": "test"}, exclude_filters={"name": "item1"}) + assert len(metadata) == 2 + assert all(m.category == "test" for m in metadata) + assert all(m.name != "item1" for m in metadata) + + def test_list_metadata_caching(self): + """Test that metadata is cached after first call.""" + # First call builds cache + metadata1 = self.registry.list_metadata() + # Second call uses cache + metadata2 = self.registry.list_metadata() + + # Should be the same list object (cached) + assert metadata1 is metadata2 + + +class TestBaseInstanceRegistryDunderMethods: + """Tests for dunder methods (__contains__, __len__, __iter__) in BaseInstanceRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + ConcreteTestRegistry.reset_instance() + self.registry = ConcreteTestRegistry.get_registry_singleton() + self.registry.register("value1", name="name1") + self.registry.register("value2", name="name2") + + def teardown_method(self): + """Reset the singleton after each test.""" + ConcreteTestRegistry.reset_instance() + + def test_contains_existing_name(self): + """Test __contains__ returns True for existing name.""" + assert "name1" in self.registry + assert "name2" in self.registry + + def test_contains_nonexistent_name(self): + """Test __contains__ returns False for non-existent name.""" + assert "nonexistent" not in self.registry + + def test_len_returns_count(self): + """Test __len__ returns the correct count.""" + assert len(self.registry) == 2 + + def test_len_empty_registry(self): + """Test __len__ returns 0 for empty registry.""" + ConcreteTestRegistry.reset_instance() + empty_registry = ConcreteTestRegistry.get_registry_singleton() + assert len(empty_registry) == 0 + + def test_iter_returns_sorted_names(self): + """Test __iter__ returns names in sorted order.""" + names = 
list(self.registry) + assert names == ["name1", "name2"] + + def test_iter_allows_for_loop(self): + """Test that the registry can be used in a for loop.""" + collected = [] + for name in self.registry: + collected.append(name) + assert collected == ["name1", "name2"] diff --git a/tests/unit/registry/test_scorer_registry.py b/tests/unit/registry/test_scorer_registry.py new file mode 100644 index 000000000..455841200 --- /dev/null +++ b/tests/unit/registry/test_scorer_registry.py @@ -0,0 +1,376 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import Optional + +from pyrit.models import Message, MessagePiece, Score +from pyrit.registry.instance_registries.scorer_registry import ScorerMetadata, ScorerRegistry +from pyrit.score.float_scale.float_scale_scorer import FloatScaleScorer +from pyrit.score.scorer import Scorer +from pyrit.score.scorer_identifier import ScorerIdentifier +from pyrit.score.scorer_prompt_validator import ScorerPromptValidator +from pyrit.score.true_false.true_false_scorer import TrueFalseScorer + + +class DummyValidator(ScorerPromptValidator): + """Dummy validator for testing.""" + + def validate(self, message, objective=None): + pass + + def is_message_piece_supported(self, message_piece): + return True + + +class MockTrueFalseScorer(TrueFalseScorer): + """Mock TrueFalseScorer for testing.""" + + def __init__(self): + super().__init__(validator=DummyValidator()) + + def _build_scorer_identifier(self) -> None: + """Build the scorer evaluation identifier for this mock scorer.""" + self._set_scorer_identifier() + + async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + return [] + + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + return [] + + def validate_return_scores(self, scores: list[Score]): + pass + + +class MockFloatScaleScorer(FloatScaleScorer): + """Mock FloatScaleScorer for 
testing.""" + + def __init__(self): + super().__init__(validator=DummyValidator()) + + def _build_scorer_identifier(self) -> None: + """Build the scorer evaluation identifier for this mock scorer.""" + self._set_scorer_identifier() + + async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + return [] + + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + return [] + + def validate_return_scores(self, scores: list[Score]): + pass + + +class MockGenericScorer(Scorer): + """Mock generic Scorer (not TrueFalse or FloatScale) for testing.""" + + scorer_type = "true_false" # type: ignore[assignment] + + def __init__(self): + super().__init__(validator=DummyValidator()) + + def _build_scorer_identifier(self) -> None: + """Build the scorer evaluation identifier for this mock scorer.""" + self._set_scorer_identifier() + + async def _score_async(self, message: Message, *, objective: Optional[str] = None) -> list[Score]: + return [] + + async def _score_piece_async(self, message_piece: MessagePiece, *, objective: Optional[str] = None) -> list[Score]: + return [] + + def validate_return_scores(self, scores: list[Score]): + pass + + def get_scorer_metrics(self): + return None + + +class TestScorerRegistrySingleton: + """Tests for the singleton pattern in ScorerRegistry.""" + + def setup_method(self): + """Reset the singleton before each test.""" + ScorerRegistry.reset_instance() + + def teardown_method(self): + """Reset the singleton after each test.""" + ScorerRegistry.reset_instance() + + def test_get_registry_singleton_returns_same_instance(self): + """Test that get_registry_singleton returns the same singleton each time.""" + instance1 = ScorerRegistry.get_registry_singleton() + instance2 = ScorerRegistry.get_registry_singleton() + + assert instance1 is instance2 + + def test_get_registry_singleton_returns_scorer_registry_type(self): + """Test that 
get_registry_singleton returns a ScorerRegistry instance.""" + instance = ScorerRegistry.get_registry_singleton() + assert isinstance(instance, ScorerRegistry) + + def test_reset_instance_clears_singleton(self): + """Test that reset_instance clears the singleton.""" + instance1 = ScorerRegistry.get_registry_singleton() + ScorerRegistry.reset_instance() + instance2 = ScorerRegistry.get_registry_singleton() + + assert instance1 is not instance2 + + +class TestScorerRegistryRegisterInstance: + """Tests for register_instance functionality in ScorerRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + ScorerRegistry.reset_instance() + self.registry = ScorerRegistry.get_registry_singleton() + + def teardown_method(self): + """Reset the singleton after each test.""" + ScorerRegistry.reset_instance() + + def test_register_instance_with_custom_name(self): + """Test registering a scorer with a custom name.""" + scorer = MockTrueFalseScorer() + self.registry.register_instance(scorer, name="custom_scorer") + + assert "custom_scorer" in self.registry + assert self.registry.get("custom_scorer") is scorer + + def test_register_instance_generates_name_from_class(self): + """Test that register_instance generates a name from class name when not provided.""" + scorer = MockTrueFalseScorer() + self.registry.register_instance(scorer) + + # Name should be derived from class name with hash suffix + names = self.registry.get_names() + assert len(names) == 1 + assert names[0].startswith("mock_true_false_") + + def test_register_instance_multiple_scorers_unique_names(self): + """Test registering multiple scorers generates unique names.""" + scorer1 = MockTrueFalseScorer() + scorer2 = MockFloatScaleScorer() + + self.registry.register_instance(scorer1) + self.registry.register_instance(scorer2) + + assert len(self.registry) == 2 + + def test_register_instance_same_scorer_type_different_hash(self): + """Test that same scorer class can be registered 
with different identifiers.""" + scorer1 = MockTrueFalseScorer() + scorer2 = MockTrueFalseScorer() + + # Register with explicit names since scorers may have same hash + self.registry.register_instance(scorer1, name="scorer_1") + self.registry.register_instance(scorer2, name="scorer_2") + + assert len(self.registry) == 2 + + +class TestScorerRegistryGetInstanceByName: + """Tests for get_instance_by_name functionality in ScorerRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + ScorerRegistry.reset_instance() + self.registry = ScorerRegistry.get_registry_singleton() + self.scorer = MockTrueFalseScorer() + self.registry.register_instance(self.scorer, name="test_scorer") + + def teardown_method(self): + """Reset the singleton after each test.""" + ScorerRegistry.reset_instance() + + def test_get_instance_by_name_returns_scorer(self): + """Test getting a registered scorer by name.""" + result = self.registry.get_instance_by_name("test_scorer") + assert result is self.scorer + + def test_get_instance_by_name_nonexistent_returns_none(self): + """Test that getting a non-existent scorer returns None.""" + result = self.registry.get_instance_by_name("nonexistent") + assert result is None + + +class TestScorerRegistryBuildMetadata: + """Tests for _build_metadata functionality in ScorerRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry for each test.""" + ScorerRegistry.reset_instance() + self.registry = ScorerRegistry.get_registry_singleton() + + def teardown_method(self): + """Reset the singleton after each test.""" + ScorerRegistry.reset_instance() + + def test_build_metadata_true_false_scorer(self): + """Test that metadata correctly identifies TrueFalseScorer type.""" + scorer = MockTrueFalseScorer() + self.registry.register_instance(scorer, name="tf_scorer") + + metadata = self.registry.list_metadata() + assert len(metadata) == 1 + assert metadata[0].scorer_type == "true_false" + assert 
metadata[0].class_name == "MockTrueFalseScorer" + assert metadata[0].name == "tf_scorer" + + def test_build_metadata_float_scale_scorer(self): + """Test that metadata correctly identifies FloatScaleScorer type.""" + scorer = MockFloatScaleScorer() + self.registry.register_instance(scorer, name="fs_scorer") + + metadata = self.registry.list_metadata() + assert len(metadata) == 1 + assert metadata[0].scorer_type == "float_scale" + assert metadata[0].class_name == "MockFloatScaleScorer" + + def test_build_metadata_unknown_scorer_type(self): + """Test that non-standard scorers get 'unknown' scorer_type.""" + scorer = MockGenericScorer() + self.registry.register_instance(scorer, name="generic_scorer") + + metadata = self.registry.list_metadata() + assert len(metadata) == 1 + assert metadata[0].scorer_type == "unknown" + + def test_build_metadata_includes_scorer_identifier(self): + """Test that metadata includes the scorer_identifier.""" + scorer = MockTrueFalseScorer() + self.registry.register_instance(scorer, name="tf_scorer") + + metadata = self.registry.list_metadata() + assert hasattr(metadata[0], "scorer_identifier") + assert isinstance(metadata[0].scorer_identifier, ScorerIdentifier) + + def test_build_metadata_description_from_docstring(self): + """Test that description is derived from the scorer's docstring.""" + scorer = MockTrueFalseScorer() + self.registry.register_instance(scorer, name="tf_scorer") + + metadata = self.registry.list_metadata() + # MockTrueFalseScorer has a docstring + assert "Mock TrueFalseScorer for testing" in metadata[0].description + + +class TestScorerRegistryListMetadataFiltering: + """Tests for list_metadata filtering in ScorerRegistry.""" + + def setup_method(self): + """Reset and get a fresh registry with multiple scorers.""" + ScorerRegistry.reset_instance() + self.registry = ScorerRegistry.get_registry_singleton() + + self.tf_scorer1 = MockTrueFalseScorer() + self.tf_scorer2 = MockTrueFalseScorer() + self.fs_scorer = 
MockFloatScaleScorer() + + self.registry.register_instance(self.tf_scorer1, name="tf_scorer_1") + self.registry.register_instance(self.tf_scorer2, name="tf_scorer_2") + self.registry.register_instance(self.fs_scorer, name="fs_scorer") + + def teardown_method(self): + """Reset the singleton after each test.""" + ScorerRegistry.reset_instance() + + def test_list_metadata_filter_by_scorer_type(self): + """Test filtering metadata by scorer_type.""" + tf_metadata = self.registry.list_metadata(include_filters={"scorer_type": "true_false"}) + assert len(tf_metadata) == 2 + assert all(m.scorer_type == "true_false" for m in tf_metadata) + + fs_metadata = self.registry.list_metadata(include_filters={"scorer_type": "float_scale"}) + assert len(fs_metadata) == 1 + assert fs_metadata[0].scorer_type == "float_scale" + + def test_list_metadata_filter_by_name(self): + """Test filtering metadata by name.""" + metadata = self.registry.list_metadata(include_filters={"name": "tf_scorer_1"}) + assert len(metadata) == 1 + assert metadata[0].name == "tf_scorer_1" + + def test_list_metadata_no_filter_returns_all(self): + """Test that list_metadata without filters returns all items.""" + metadata = self.registry.list_metadata() + assert len(metadata) == 3 + + def test_list_metadata_exclude_by_scorer_type(self): + """Test excluding metadata by scorer_type.""" + metadata = self.registry.list_metadata(exclude_filters={"scorer_type": "true_false"}) + assert len(metadata) == 1 + assert metadata[0].scorer_type == "float_scale" + + def test_list_metadata_combined_include_and_exclude(self): + """Test combined include and exclude filters.""" + metadata = self.registry.list_metadata( + include_filters={"scorer_type": "true_false"}, exclude_filters={"name": "tf_scorer_1"} + ) + assert len(metadata) == 1 + assert metadata[0].name == "tf_scorer_2" + + +class TestScorerRegistryInheritedMethods: + """Tests for inherited methods from BaseInstanceRegistry.""" + + def setup_method(self): + """Reset and get 
a fresh registry.""" + ScorerRegistry.reset_instance() + self.registry = ScorerRegistry.get_registry_singleton() + self.scorer = MockTrueFalseScorer() + self.registry.register_instance(self.scorer, name="test_scorer") + + def teardown_method(self): + """Reset the singleton after each test.""" + ScorerRegistry.reset_instance() + + def test_contains_registered_name(self): + """Test __contains__ for registered name.""" + assert "test_scorer" in self.registry + + def test_contains_unregistered_name(self): + """Test __contains__ for unregistered name.""" + assert "unknown_scorer" not in self.registry + + def test_len_returns_count(self): + """Test __len__ returns correct count.""" + assert len(self.registry) == 1 + + def test_iter_yields_names(self): + """Test __iter__ yields registered names.""" + names = list(self.registry) + assert "test_scorer" in names + + def test_get_names_returns_sorted_list(self): + """Test get_names returns sorted list of names.""" + self.registry.register_instance(MockFloatScaleScorer(), name="alpha_scorer") + self.registry.register_instance(MockFloatScaleScorer(), name="zeta_scorer") + + names = self.registry.get_names() + assert names == ["alpha_scorer", "test_scorer", "zeta_scorer"] + + +class TestScorerMetadata: + """Tests for ScorerMetadata dataclass.""" + + def test_scorer_metadata_has_required_fields(self): + """Test that ScorerMetadata includes all required fields.""" + # Create a mock scorer identifier + mock_identifier = ScorerIdentifier(type="test_type") + + metadata = ScorerMetadata( + name="test_scorer", + class_name="TestScorer", + description="A test scorer", + scorer_type="true_false", + scorer_identifier=mock_identifier, + ) + + assert metadata.name == "test_scorer" + assert metadata.class_name == "TestScorer" + assert metadata.description == "A test scorer" + assert metadata.scorer_type == "true_false" + assert metadata.scorer_identifier == mock_identifier diff --git a/tests/unit/setup/test_load_default_datasets.py 
b/tests/unit/setup/test_load_default_datasets.py index 3c24906b3..68db0d8d0 100644 --- a/tests/unit/setup/test_load_default_datasets.py +++ b/tests/unit/setup/test_load_default_datasets.py @@ -10,10 +10,10 @@ import pytest -from pyrit.cli.scenario_registry import ScenarioRegistry from pyrit.datasets import SeedDatasetProvider from pyrit.memory import CentralMemory from pyrit.models import SeedDataset +from pyrit.registry import ScenarioRegistry from pyrit.scenario.core.scenario import Scenario from pyrit.setup.initializers.scenarios.load_default_datasets import LoadDefaultDatasets @@ -51,7 +51,7 @@ async def test_initialize_async_no_scenarios(self) -> None: """Test initialization when no scenarios are registered.""" initializer = LoadDefaultDatasets() - with patch.object(ScenarioRegistry, "get_scenario_names", return_value=[]): + with patch.object(ScenarioRegistry, "get_names", return_value=[]): with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: with patch.object(CentralMemory, "get_memory_instance") as mock_memory: mock_memory_instance = MagicMock() @@ -75,8 +75,8 @@ async def test_initialize_async_with_scenarios(self) -> None: mock_scenario_class = MagicMock(spec=Scenario) mock_scenario_class.default_dataset_config.return_value = mock_dataset_config - with patch.object(ScenarioRegistry, "get_scenario_names", return_value=["mock_scenario"]): - with patch.object(ScenarioRegistry, "get_scenario", return_value=mock_scenario_class): + with patch.object(ScenarioRegistry, "get_names", return_value=["mock_scenario"]): + with patch.object(ScenarioRegistry, "get_class", return_value=mock_scenario_class): with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: mock_dataset1 = MagicMock(spec=SeedDataset) mock_dataset2 = MagicMock(spec=SeedDataset) @@ -122,8 +122,8 @@ def get_scenario_side_effect(name: str): return mock_scenario2 return None - with patch.object(ScenarioRegistry, 
"get_scenario_names", return_value=["scenario1", "scenario2"]): - with patch.object(ScenarioRegistry, "get_scenario", side_effect=get_scenario_side_effect): + with patch.object(ScenarioRegistry, "get_names", return_value=["scenario1", "scenario2"]): + with patch.object(ScenarioRegistry, "get_class", side_effect=get_scenario_side_effect): with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: mock_fetch.return_value = [] @@ -162,8 +162,8 @@ def get_scenario_side_effect(name: str): return mock_scenario_bad return None - with patch.object(ScenarioRegistry, "get_scenario_names", return_value=["good_scenario", "bad_scenario"]): - with patch.object(ScenarioRegistry, "get_scenario", side_effect=get_scenario_side_effect): + with patch.object(ScenarioRegistry, "get_names", return_value=["good_scenario", "bad_scenario"]): + with patch.object(ScenarioRegistry, "get_class", side_effect=get_scenario_side_effect): with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: mock_fetch.return_value = [] @@ -191,15 +191,15 @@ async def test_all_required_datasets_available_in_seed_provider(self) -> None: available_datasets = set(SeedDatasetProvider.get_all_dataset_names()) # Get ScenarioRegistry to discover all scenarios - registry = ScenarioRegistry() - scenario_names = registry.get_scenario_names() + registry = ScenarioRegistry.get_registry_singleton() + scenario_names = registry.get_names() # Collect all required datasets from all scenarios missing_datasets: List[str] = [] scenario_dataset_map: dict[str, List[str]] = {} for scenario_name in scenario_names: - scenario_class = registry.get_scenario(scenario_name) + scenario_class = registry.get_class(scenario_name) if scenario_class: try: required = scenario_class.default_dataset_config().get_default_dataset_names() @@ -228,8 +228,8 @@ async def test_initialize_async_empty_dataset_list(self) -> None: mock_scenario = MagicMock(spec=Scenario) 
mock_scenario.default_dataset_config.return_value = mock_dataset_config - with patch.object(ScenarioRegistry, "get_scenario_names", return_value=["empty_scenario"]): - with patch.object(ScenarioRegistry, "get_scenario", return_value=mock_scenario): + with patch.object(ScenarioRegistry, "get_names", return_value=["empty_scenario"]): + with patch.object(ScenarioRegistry, "get_class", return_value=mock_scenario): with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: with patch.object(CentralMemory, "get_memory_instance") as mock_memory: mock_memory_instance = MagicMock() @@ -247,8 +247,8 @@ async def test_initialize_async_none_scenario_class(self) -> None: """Test initialization when get_scenario returns None for a scenario.""" initializer = LoadDefaultDatasets() - with patch.object(ScenarioRegistry, "get_scenario_names", return_value=["nonexistent_scenario"]): - with patch.object(ScenarioRegistry, "get_scenario", return_value=None): + with patch.object(ScenarioRegistry, "get_names", return_value=["nonexistent_scenario"]): + with patch.object(ScenarioRegistry, "get_class", return_value=None): with patch.object(SeedDatasetProvider, "fetch_datasets_async", new_callable=AsyncMock) as mock_fetch: with patch.object(CentralMemory, "get_memory_instance") as mock_memory: mock_memory_instance = MagicMock()