Commit eb27f7b

feat: add Reinforcement Learning Optimization doc

1 parent e0eed6f commit eb27f7b

File tree

4 files changed: +255 −2 lines

agentic_security/probe_data/modules/rl_model.py

Lines changed: 3 additions & 2 deletions

```diff
@@ -41,7 +41,7 @@
 class RandomPromptSelector(PromptSelectionInterface):
     """Random prompt selector with cycle prevention using history."""
 
-    def __init__(self, prompts: list[str], history_size: int = 3):
+    def __init__(self, prompts: list[str], history_size: int = 300):
         if not prompts:
             raise ValueError("Prompts list cannot be empty")
         self.prompts = prompts
@@ -120,7 +120,8 @@ def update_rewards(
         current_prompt: str,
         reward: float,
         passed_guard: bool,
-    ) -> None: ...
+    ) -> None:
+        ...
 
 
 class QLearningPromptSelector(PromptSelectionInterface):
```
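The change above raises the `RandomPromptSelector` history default from 3 to 300. A minimal sketch of the cycle-prevention idea behind that parameter, assuming a bounded `deque` of recent picks (an illustration, not the project's actual implementation):

```python
import random
from collections import deque


# Hypothetical selector sketching cycle prevention via a bounded history;
# names and internals are assumptions, not the library's real code.
class SketchRandomSelector:
    def __init__(self, prompts: list[str], history_size: int = 300):
        if not prompts:
            raise ValueError("Prompts list cannot be empty")
        self.prompts = prompts
        # Oldest entries fall off automatically once maxlen is reached.
        self.history: deque[str] = deque(maxlen=history_size)

    def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
        # Prefer prompts not seen recently; fall back to any prompt
        # only if the history currently covers the whole pool.
        candidates = [p for p in self.prompts if p not in self.history]
        choice = random.choice(candidates or self.prompts)
        self.history.append(choice)
        return choice


selector = SketchRandomSelector(["a", "b", "c"], history_size=2)
picks = [selector.select_next_prompt("a", passed_guard=True) for _ in range(10)]
```

With `history_size=2` and three prompts the candidate pool is never empty, so consecutive picks always differ; a larger default like 300 simply widens the window of prompts excluded from re-selection.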

docs/probe_data.md

Lines changed: 57 additions & 0 deletions

````diff
@@ -50,6 +50,50 @@ The `probe_data` module is a core component of the Agentic Security project, res
 - `base64_encode(data)`: Encodes data in base64 format.
 - `mirror_words(text)`: Mirrors words in the text.
 
+### rl_model.py
+
+- **Classes:**
+  - `PromptSelectionInterface`: Abstract base class for prompt selection strategies.
+    - Methods:
+      - `select_next_prompt(current_prompt: str, passed_guard: bool) -> str`: Selects the next prompt
+      - `select_next_prompts(current_prompt: str, passed_guard: bool) -> list[str]`: Selects multiple prompts
+      - `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> None`: Updates rewards
+  - `RandomPromptSelector`: Basic random selection with history tracking.
+    - Parameters:
+      - `prompts: list[str]`: List of available prompts
+      - `history_size: int = 300`: Size of the history used to prevent cycles
+  - `CloudRLPromptSelector`: Cloud-based RL implementation with fallback.
+    - Parameters:
+      - `prompts: list[str]`: List of available prompts
+      - `api_url: str`: URL of the RL service
+      - `auth_token: str = AUTH_TOKEN`: Authentication token
+      - `history_size: int = 300`: Size of the history
+      - `timeout: int = 5`: Request timeout in seconds
+      - `run_id: str = ""`: Unique run identifier
+  - `QLearningPromptSelector`: Local Q-learning implementation.
+    - Parameters:
+      - `prompts: list[str]`: List of available prompts
+      - `learning_rate: float = 0.1`: Learning rate
+      - `discount_factor: float = 0.9`: Discount factor
+      - `initial_exploration: float = 1.0`: Initial exploration rate
+      - `exploration_decay: float = 0.995`: Exploration decay rate
+      - `min_exploration: float = 0.01`: Minimum exploration rate
+      - `history_size: int = 300`: Size of the history
+  - `Module`: Main class that uses `CloudRLPromptSelector`.
+    - Parameters:
+      - `prompt_groups: list[str]`: Groups of prompts
+      - `tools_inbox: asyncio.Queue`: Queue for tool communication
+      - `opts: dict = {}`: Configuration options
+
 ## Usage Examples
 
 ### Generating Audio
@@ -68,6 +112,19 @@ from agentic_security.probe_data.data import load_dataset_general
 dataset = load_dataset_general("example_dataset")
 ```
 
+### Using the RL Model
+
+```python
+from agentic_security.probe_data.modules.rl_model import QLearningPromptSelector
+
+prompts = ["What is AI?", "Explain machine learning"]
+selector = QLearningPromptSelector(prompts)
+
+current_prompt = "What is AI?"
+next_prompt = selector.select_next_prompt(current_prompt, passed_guard=True)
+selector.update_rewards(current_prompt, next_prompt, reward=1.0, passed_guard=True)
+```
+
 ## Conclusion
 
 The `probe_data` module provides essential functionality for handling and transforming datasets within the Agentic Security project. This documentation serves as a guide to understanding and utilizing the module's capabilities.
````
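The `learning_rate` and `discount_factor` parameters documented for `QLearningPromptSelector` suggest the standard tabular Q-learning update. A worked sketch of that rule with the documented defaults (an assumption about the internals, not the project's actual code):

```python
# Hypothetical one-step tabular Q-learning update; variable names are
# illustrative, not taken from the library.
learning_rate = 0.1      # documented default
discount_factor = 0.9    # documented default

# Q-table keyed by (previous_prompt, current_prompt) pairs.
key = ("What is AI?", "Explain machine learning")
q = {key: 0.0}
q_next_best = 0.0  # best Q-value reachable from the next state

reward = 1.0  # e.g. the selected prompt passed the guard

# Q(s, a) += lr * (r + gamma * max_a' Q(s', a') - Q(s, a))
q[key] += learning_rate * (reward + discount_factor * q_next_best - q[key])
print(q[key])  # 0.1 after a single positive-reward update
```

Because the table starts at zero, the first update moves the value by exactly `learning_rate * reward`; subsequent updates converge more gradually as the error term shrinks.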

docs/rl_model.md

Lines changed: 194 additions & 0 deletions

````diff
@@ -0,0 +1,194 @@
+# RL Model Module
+
+The RL Model module provides reinforcement learning-based prompt selection strategies for the probe system.
+
+## Overview
+
+The module implements several prompt selection strategies that use reinforcement learning techniques to optimize prompt selection based on guard results and rewards.
+
+## Classes
+
+### PromptSelectionInterface
+
+Abstract base class defining the interface for prompt selection strategies.
+
+**Methods:**
+
+- `select_next_prompt(current_prompt: str, passed_guard: bool) -> str`
+- `select_next_prompts(current_prompt: str, passed_guard: bool) -> list[str]`
+- `update_rewards(previous_prompt: str, current_prompt: str, reward: float, passed_guard: bool) -> None`
+
+### RandomPromptSelector
+
+Basic random selection strategy with cycle prevention using history.
+
+**Configuration:**
+
+- `prompts`: List of available prompts
+- `history_size`: Size of the history buffer used to prevent cycles (default: 300)
+
+### CloudRLPromptSelector
+
+Cloud-based reinforcement learning prompt selector with fallback to random selection.
+
+**Configuration:**
+
+- `prompts`: List of available prompts
+- `api_url`: URL of the RL service
+- `auth_token`: Authentication token (default: `AS_TOKEN` environment variable)
+- `history_size`: Size of the history buffer (default: 300)
+- `timeout`: Request timeout in seconds (default: 5)
+- `run_id`: Unique identifier for the run
+
+### QLearningPromptSelector
+
+Q-learning-based prompt selector with an exploration/exploitation tradeoff.
+
+**Configuration:**
+
+- `prompts`: List of available prompts
+- `learning_rate`: Learning rate (default: 0.1)
+- `discount_factor`: Discount factor (default: 0.9)
+- `initial_exploration`: Initial exploration rate (default: 1.0)
+- `exploration_decay`: Exploration decay rate (default: 0.995)
+- `min_exploration`: Minimum exploration rate (default: 0.01)
+- `history_size`: Size of the history buffer (default: 300)
+
+### Module
+
+Main class that implements the RL-based prompt selection functionality.
+
+**Configuration:**
+
+- `prompt_groups`: List of prompt groups
+- `tools_inbox`: `asyncio.Queue` for tool communication
+- `opts`: Additional options
+  - `max_prompts`: Maximum number of prompts to generate (default: 10)
+  - `batch_size`: Batch size for processing (default: 500)
+
+## Usage Example
+
+```python
+import asyncio
+
+from agentic_security.probe_data.modules.rl_model import Module
+
+# Initialize with prompt groups
+prompt_groups = ["What is AI?", "Explain ML", "Describe RL"]
+
+
+async def main():
+    module = Module(prompt_groups, asyncio.Queue())
+
+    # Use the module
+    async for prompt in module.apply():
+        print(f"Selected prompt: {prompt}")
+
+
+asyncio.run(main())
+```
+
+## API Reference
+
+### PromptSelectionInterface
+
+```python
+class PromptSelectionInterface(ABC):
+    @abstractmethod
+    def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
+        """Select the next prompt based on the current state and guard result."""
+
+    @abstractmethod
+    def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
+        """Select the next prompts based on the current state and guard result."""
+
+    @abstractmethod
+    def update_rewards(
+        self,
+        previous_prompt: str,
+        current_prompt: str,
+        reward: float,
+        passed_guard: bool,
+    ) -> None:
+        """Update internal rewards based on the outcome of the last selected prompt."""
+```
+
+### RandomPromptSelector
+
+```python
+class RandomPromptSelector(PromptSelectionInterface):
+    def __init__(self, prompts: list[str], history_size: int = 300):
+        """Initialize with prompts and history size."""
+
+    def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
+        """Select the next prompt randomly with cycle prevention."""
+
+    def update_rewards(
+        self,
+        previous_prompt: str,
+        current_prompt: str,
+        reward: float,
+        passed_guard: bool,
+    ) -> None:
+        """No learning in random selection."""
+```
+
+### CloudRLPromptSelector
+
+```python
+class CloudRLPromptSelector(PromptSelectionInterface):
+    def __init__(
+        self,
+        prompts: list[str],
+        api_url: str,
+        auth_token: str = AUTH_TOKEN,
+        history_size: int = 300,
+        timeout: int = 5,
+        run_id: str = "",
+    ):
+        """Initialize with cloud RL configuration."""
+
+    def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> list[str]:
+        """Select the next prompts using cloud RL with fallback."""
+
+    def _fallback_selection(self) -> str:
+        """Fall back to random selection if the cloud request fails."""
+```
+
+### QLearningPromptSelector
+
+```python
+class QLearningPromptSelector(PromptSelectionInterface):
+    def __init__(
+        self,
+        prompts: list[str],
+        learning_rate: float = 0.1,
+        discount_factor: float = 0.9,
+        initial_exploration: float = 1.0,
+        exploration_decay: float = 0.995,
+        min_exploration: float = 0.01,
+        history_size: int = 300,
+    ):
+        """Initialize the Q-learning configuration."""
+
+    def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> str:
+        """Select the next prompt using Q-learning with exploration/exploitation."""
+
+    def update_rewards(
+        self,
+        previous_prompt: str,
+        current_prompt: str,
+        reward: float,
+        passed_guard: bool,
+    ) -> None:
+        """Update Q-values based on the reward."""
+```
+
+### Module
+
+```python
+class Module:
+    def __init__(
+        self, prompt_groups: list[str], tools_inbox: asyncio.Queue, opts: dict = {}
+    ):
+        """Initialize the module with prompt groups and configuration."""
+
+    async def apply(self):
+        """Apply the RL model to generate prompts."""
+```
````

mkdocs.yml

Lines changed: 1 addition & 0 deletions

```diff
@@ -25,6 +25,7 @@ nav:
       - Bayesian Optimization: optimizer.md
       - Image Generation: image_generation.md
       - Stenography Functions: stenography.md
+      - Reinforcement Learning Optimization: rl_model.md
   - Reference:
       - API Reference: api_reference.md
   - Community:
```
