# RL Model Module

The RL Model module provides reinforcement learning-based prompt selection strategies for the probe system.

## Overview

The module implements several selection strategies that use reinforcement learning techniques to choose the next prompt based on guard results and rewards.
## Classes

### PromptSelectionInterface

Abstract base class defining the interface for prompt selection strategies.

**Methods:**

- `select_next_prompt(current_prompt: str, passed_guard: bool) -> str`
- `select_next_prompts(current_prompt: str, passed_guard: bool) -> list[str]`
- `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> None`

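Concrete selectors implement all three methods. As a minimal sketch, a hypothetical custom selector might look like this (the class and its logic are illustrative, and it assumes `PromptSelectionInterface` is importable from the same module as the selectors below):

```python
import random

from agentic_security.probe_data.modules.rl_model import PromptSelectionInterface


class StickySelector(PromptSelectionInterface):
    """Hypothetical strategy: keep a prompt while it passes the guard."""

    def __init__(self, prompts: list[str]):
        self.prompts = prompts

    def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
        # Stay on a working prompt; otherwise try a random one.
        return current_prompt if passed_guard else random.choice(self.prompts)

    def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
        return [self.select_next_prompt(current_prompt, passed_guard)]

    def update_rewards(
        self,
        previous_prompt: str,
        current_prompt: str,
        reward: float,
        passed_guard: bool,
    ) -> None:
        # This stateless strategy does not learn from rewards.
        pass
```
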
### RandomPromptSelector

Basic random selection strategy that uses a history buffer to prevent cycles; see the usage sketch below.

**Configuration:**

- `prompts`: List of available prompts
- `history_size`: Size of the history buffer used to prevent cycles (default: 300)

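A minimal usage sketch, assuming the class is exported from `agentic_security.probe_data.modules.rl_model` alongside the other selectors (prompt strings are illustrative):

```python
from agentic_security.probe_data.modules.rl_model import RandomPromptSelector

selector = RandomPromptSelector(
    prompts=["What is AI?", "Explain ML", "Describe RL"],
    history_size=100,  # remember the last 100 prompts to avoid cycles
)

# passed_guard does not influence random selection, but the call
# signature is shared with the learning-based selectors.
next_prompt = selector.select_next_prompt("What is AI?", passed_guard=True)
```
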
### CloudRLPromptSelector

Cloud-based reinforcement learning prompt selector that falls back to random selection when the service is unavailable; see the usage sketch below.

**Configuration:**

- `prompts`: List of available prompts
- `api_url`: URL of the RL service
- `auth_token`: Authentication token (default: the `AS_TOKEN` environment variable)
- `history_size`: Size of the history buffer (default: 300)
- `timeout`: Request timeout in seconds (default: 5)
- `run_id`: Unique identifier for the run

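A minimal usage sketch (the endpoint URL, run ID, and prompt strings are illustrative):

```python
import os

from agentic_security.probe_data.modules.rl_model import CloudRLPromptSelector

selector = CloudRLPromptSelector(
    prompts=["What is AI?", "Explain ML", "Describe RL"],
    api_url="https://rl.example.com/select",  # hypothetical RL service endpoint
    auth_token=os.environ.get("AS_TOKEN", ""),
    timeout=5,
    run_id="run-001",
)

# If the service call fails or times out, the selector falls back
# to random selection internally.
candidates = selector.select_next_prompts("What is AI?", passed_guard=False)
```
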
### QLearningPromptSelector

Q-Learning based prompt selector with an exploration/exploitation tradeoff; see the usage sketch below.

**Configuration:**

- `prompts`: List of available prompts
- `learning_rate`: Learning rate (default: 0.1)
- `discount_factor`: Discount factor (default: 0.9)
- `initial_exploration`: Initial exploration rate (default: 1.0)
- `exploration_decay`: Exploration decay rate (default: 0.995)
- `min_exploration`: Minimum exploration rate (default: 0.01)
- `history_size`: Size of the history buffer (default: 300)

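The exploration rate starts at `initial_exploration` and decays by `exploration_decay` toward `min_exploration`, shifting from exploration to exploitation over time; rewards presumably feed the standard Q-learning update `Q(s, a) ← Q(s, a) + α * (r + γ * max_a' Q(s', a') − Q(s, a))`, with `learning_rate` as α and `discount_factor` as γ. A minimal usage sketch (prompt strings are illustrative):

```python
from agentic_security.probe_data.modules.rl_model import QLearningPromptSelector

selector = QLearningPromptSelector(
    prompts=["What is AI?", "Explain ML", "Describe RL"],
    learning_rate=0.1,    # alpha: step size of the Q-update
    discount_factor=0.9,  # gamma: weight given to future rewards
)

previous = "What is AI?"
current = selector.select_next_prompt(previous, passed_guard=True)

# Reward the transition so the Q-table favors it in the future.
selector.update_rewards(previous, current, reward=1.0, passed_guard=True)
```
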
### Module

Main class that implements the RL-based prompt selection functionality.

**Configuration:**

- `prompt_groups`: List of prompt groups
- `tools_inbox`: `asyncio.Queue` used for tool communication
- `opts`: Additional options (see the sketch after this list)
  - `max_prompts`: Maximum number of prompts to generate (default: 10)
  - `batch_size`: Batch size for processing (default: 500)

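The defaults can be overridden through `opts`, for example (values are illustrative):

```python
import asyncio

from agentic_security.probe_data.modules.rl_model import Module

module = Module(
    prompt_groups=["What is AI?", "Explain ML"],
    tools_inbox=asyncio.Queue(),
    opts={"max_prompts": 5, "batch_size": 100},
)
```
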
## Usage Example

```python
import asyncio

from agentic_security.probe_data.modules.rl_model import Module


async def main():
    # Initialize with prompt groups
    prompt_groups = ["What is AI?", "Explain ML", "Describe RL"]
    module = Module(prompt_groups, asyncio.Queue())

    # Use the module
    async for prompt in module.apply():
        print(f"Selected prompt: {prompt}")


asyncio.run(main())
```

## API Reference

### PromptSelectionInterface

```python
class PromptSelectionInterface(ABC):
    @abstractmethod
    def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
        """Select next prompt based on current state and guard result."""

    @abstractmethod
    def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
        """Select next prompts based on current state and guard result."""

    @abstractmethod
    def update_rewards(
        self,
        previous_prompt: str,
        current_prompt: str,
        reward: float,
        passed_guard: bool,
    ) -> None:
        """Update internal rewards based on outcome of last selected prompt."""
```

### RandomPromptSelector

```python
class RandomPromptSelector(PromptSelectionInterface):
    def __init__(self, prompts: list[str], history_size: int = 300):
        """Initialize with prompts and history size."""

    def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
        """Select next prompt randomly with cycle prevention."""

    def update_rewards(
        self,
        previous_prompt: str,
        current_prompt: str,
        reward: float,
        passed_guard: bool,
    ) -> None:
        """No learning in random selection."""
```

### CloudRLPromptSelector

```python
class CloudRLPromptSelector(PromptSelectionInterface):
    def __init__(
        self,
        prompts: list[str],
        api_url: str,
        auth_token: str = AUTH_TOKEN,
        history_size: int = 300,
        timeout: int = 5,
        run_id: str = "",
    ):
        """Initialize with cloud RL configuration."""

    def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
        """Select next prompts using cloud RL with fallback."""

    def _fallback_selection(self) -> str:
        """Fallback to random selection if cloud request fails."""
```

### QLearningPromptSelector

```python
class QLearningPromptSelector(PromptSelectionInterface):
    def __init__(
        self,
        prompts: list[str],
        learning_rate: float = 0.1,
        discount_factor: float = 0.9,
        initial_exploration: float = 1.0,
        exploration_decay: float = 0.995,
        min_exploration: float = 0.01,
        history_size: int = 300,
    ):
        """Initialize Q-Learning configuration."""

    def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
        """Select next prompt using Q-Learning with exploration/exploitation."""

    def update_rewards(
        self,
        previous_prompt: str,
        current_prompt: str,
        reward: float,
        passed_guard: bool,
    ) -> None:
        """Update Q-values based on reward."""
```

### Module

```python
class Module:
    def __init__(
        self, prompt_groups: list[str], tools_inbox: asyncio.Queue, opts: dict = {}
    ):
        """Initialize module with prompt groups and configuration."""

    async def apply(self):
        """Apply the RL model to generate prompts."""
```