Commit e0eed6f

fix(rl_model.Module):
1 parent 21c37b8 commit e0eed6f

5 files changed (+142, -36 lines)

agentic_security/probe_data/__init__.py

Lines changed: 15 additions & 0 deletions
@@ -408,6 +408,21 @@
         },
         "modality": "text",
     },
+    {
+        "dataset_name": "Reinforcement Learning Optimization",
+        "num_prompts": 0,
+        "tokens": 0,
+        "approx_cost": 0.0,
+        "source": "Cloud hosted model",
+        "selected": False,
+        "url": "",
+        "dynamic": True,
+        "opts": {
+            "port": 8718,
+            "modules": ["encoding"],
+        },
+        "modality": "text",
+    },
     {
         "dataset_name": "InspectAI",
         "num_prompts": 0,

agentic_security/probe_data/data.py

Lines changed: 6 additions & 0 deletions
@@ -16,6 +16,7 @@
     fine_tuned,
     garak_tool,
     inspect_ai_tool,
+    rl_model,
 )


@@ -265,6 +266,11 @@ def prepare_prompts(dataset_names, budget, tools_inbox=None, options=[]):
             garak_tool.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
             lazy=True,
         ),
+        "Reinforcement Learning Optimization": lambda opts: dataset_from_iterator(
+            "Reinforcement Learning Optimization",
+            rl_model.Module(group, tools_inbox=tools_inbox, opts=opts).apply(),
+            lazy=True,
+        ),
         "InspectAI": lambda opts: dataset_from_iterator(
             "InspectAI",
             inspect_ai_tool.Module(group, tools_inbox=tools_inbox).apply(),
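
For orientation, a hedged sketch of how the new entry might be exercised through prepare_prompts. Only the signature shown in this hunk (dataset_names, budget, tools_inbox, options) is assumed; the budget value and the return handling are illustrative, not taken from the rest of data.py.

import asyncio
from agentic_security.probe_data.data import prepare_prompts

# Hypothetical invocation; budget is illustrative and the return value is not
# inspected here.
datasets = prepare_prompts(
    dataset_names=["Reinforcement Learning Optimization"],
    budget=100,
    tools_inbox=asyncio.Queue(),
)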

agentic_security/probe_data/modules/rl_model.py

Lines changed: 49 additions & 2 deletions
@@ -1,5 +1,7 @@
+import asyncio
 import os
 import random
+import uuid as U
 from abc import ABC, abstractmethod
 from collections import deque
 from typing import Deque
@@ -78,13 +80,15 @@ def __init__(
         auth_token: str = AUTH_TOKEN,
         history_size: int = 300,
         timeout: int = 5,
+        run_id: str = "",
     ):
         if not prompts:
             raise ValueError("Prompts list cannot be empty")
         self.prompts = prompts
         self.api_url = api_url
         self.headers = {"Authorization": f"Bearer {auth_token}"}
         self.timeout = timeout
+        self.run_id = run_id or U.uuid4().hex

     def select_next_prompt(self, current_prompt: str, passed_guard: bool) -> list[str]:
         return self.select_next_prompts(current_prompt, passed_guard)[0]
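
A brief sketch of the new constructor parameter (not part of the commit): run_id is optional and, per the added line above, falls back to a fresh uuid4().hex when left empty. The prompt list is illustrative; the api_url keyword mirrors the test fixture added further down.

from agentic_security.probe_data.modules.rl_model import CloudRLPromptSelector

selector = CloudRLPromptSelector(
    ["What is AI?"],
    api_url="https://edge.metaheuristic.co",
    run_id="",  # empty string -> a uuid4().hex run identifier is generated
)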
@@ -94,6 +98,7 @@ def select_next_prompts(self, current_prompt: str, passed_guard: bool) -> str:
         response = requests.post(
             f"{self.api_url}/rl-model/select-next-prompt",
             json={
+                "run_id": U.uuid4().hex,
                 "current_prompt": current_prompt,
                 "passed_guard": passed_guard,
             },
@@ -115,8 +120,7 @@ def update_rewards(
         current_prompt: str,
         reward: float,
         passed_guard: bool,
-    ) -> None:
-        ...
+    ) -> None: ...


 class QLearningPromptSelector(PromptSelectionInterface):
@@ -197,3 +201,46 @@ def update_rewards(

         # Update Q-value
         self.q_table[previous_prompt][current_prompt] += self.learning_rate * td_error
+
+
+class Module:
+    def __init__(
+        self, prompt_groups: list[str], tools_inbox: asyncio.Queue, opts: dict = {}
+    ):
+        self.tools_inbox = tools_inbox
+        self.opts = opts
+        self.prompt_groups = prompt_groups
+        self.max_prompts = self.opts.get("max_prompts", 10)  # Default max M prompts
+        self.run_id = U.uuid4().hex
+        self.batch_size = self.opts.get("batch_size", 500)
+        self.rl_model = CloudRLPromptSelector(
+            prompt_groups, "https://edge.metaheuristic.co", run_id=self.run_id
+        )
+
+    async def apply(self):
+        current_prompt = "What is AI?"
+        passed_guard = False
+        for _ in range(max(self.max_prompts, 1)):
+            # Fetch prompts from the API
+            prompts = await asyncio.to_thread(
+                lambda: self.rl_model.select_next_prompts(
+                    current_prompt, passed_guard=passed_guard
+                )
+            )
+
+            if not prompts:
+                logger.error("No prompts retrieved from the API.")
+                return
+
+            logger.info(f"Retrieved {len(prompts)} prompts.")
+
+            for i, prompt in enumerate(prompts):
+                logger.info(f"Processing prompt {i+1}/{len(prompts)}: {prompt}")
+                yield prompt
+                current_prompt = prompt
+            while not self.tools_inbox.empty():
+                ref = await self.tools_inbox.get()
+                print(ref, "ref")
+                message, _, ready = ref["message"], ref["reply"], ref["ready"]
+                yield message
+                ready.set()
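
For readers skimming the diff, a minimal consumption sketch of the new Module (a hypothetical driver, not part of the commit). It relies only on the constructor and async-generator apply() shown above; note that apply() issues real HTTP requests through the cloud selector unless that call is mocked.

import asyncio
from agentic_security.probe_data.modules.rl_model import Module

async def drive():
    inbox: asyncio.Queue = asyncio.Queue()
    # max_prompts bounds the outer selection loop; 2 keeps the sketch short.
    module = Module(["What is AI?"], tools_inbox=inbox, opts={"max_prompts": 2})
    async for prompt in module.apply():
        print(prompt)  # each selected prompt, plus any message drained from the inbox

asyncio.run(drive())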

agentic_security/probe_data/modules/test_rl_model.py

Lines changed: 72 additions & 0 deletions
@@ -1,3 +1,4 @@
+import asyncio
 from collections import deque
 from unittest.mock import Mock, patch

@@ -8,6 +9,7 @@
 # Import the classes to be tested
 from agentic_security.probe_data.modules.rl_model import (
     CloudRLPromptSelector,
+    Module,
     QLearningPromptSelector,
     RandomPromptSelector,
 )
@@ -30,6 +32,19 @@ def mock_requests() -> Mock:
     yield mock_requests


+@pytest.fixture
+def mock_rl_selector() -> Mock:
+    return CloudRLPromptSelector(
+        dataset_prompts,
+        api_url="https://edge.metaheuristic.co",
+    )
+
+
+@pytest.fixture
+def tools_inbox() -> asyncio.Queue:
+    return asyncio.Queue()
+
+
 # Tests for RandomPromptSelector
 class TestRandomPromptSelector:
     def test_initialization(self, dataset_prompts):
@@ -141,3 +156,60 @@ def test_cloud_rl_selector_invalid_url(dataset_prompts):
 def test_q_learning_selector_invalid_reward(dataset_prompts):
     selector = QLearningPromptSelector(dataset_prompts)
     selector.update_rewards("What is AI?", "How does RL work?", np.nan, True)
+
+
+# Tests for Module class
+class TestModule:
+    @pytest.fixture
+    def mock_uuid(self):
+        with patch("uuid.uuid4") as mock:
+            mock.return_value.hex = "test_run_id"
+            yield mock
+
+    def test_initialization(self, dataset_prompts, tools_inbox, mock_uuid):
+        module = Module(dataset_prompts, tools_inbox)
+        assert module.prompt_groups == dataset_prompts
+        assert module.tools_inbox == tools_inbox
+        assert module.max_prompts == 2000
+        assert module.batch_size == 500
+        assert module.run_id == "test_run_id"
+        assert isinstance(module.rl_model, CloudRLPromptSelector)
+
+    def test_initialization_with_options(self, dataset_prompts, tools_inbox, mock_uuid):
+        opts = {
+            "max_prompts": 100,
+            "batch_size": 50,
+        }
+        module = Module(dataset_prompts, tools_inbox, opts)
+        assert module.max_prompts == 100
+        assert module.batch_size == 50
+
+    @pytest.mark.asyncio
+    async def test_apply_basic_flow(
+        self, dataset_prompts, tools_inbox, mock_rl_selector
+    ):
+        module = Module(dataset_prompts, tools_inbox)
+
+        count = 0
+        async for prompt in module.apply():
+            assert prompt == "Test prompt"
+            count += 1
+            if count >= 3:  # Test a few iterations
+                break
+
+    @pytest.mark.asyncio
+    async def test_apply_rl_with_tools_inbox(self, dataset_prompts, tools_inbox):
+        # Add a test message to the tools inbox
+        test_message = {
+            "message": "Test message",
+            "reply": None,
+            "ready": asyncio.Event(),
+        }
+        await tools_inbox.put(test_message)
+
+        module = Module(dataset_prompts, tools_inbox)
+
+        async for output in module.apply():
+            if output == "Test message":
+                test_message["ready"].set()
+                break
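
Assuming pytest and pytest-asyncio are installed (the new cases use @pytest.mark.asyncio), the added tests could presumably be run in isolation with a small programmatic helper like the following; the -k filter and file path are the only arguments used.

import pytest

# Hypothetical runner: select only the new TestModule cases in this file.
pytest.main(
    ["agentic_security/probe_data/modules/test_rl_model.py", "-k", "TestModule"]
)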

agentic_security/test_registry.py

Lines changed: 0 additions & 34 deletions
This file was deleted.
