Skip to content

Commit 839c1af

Browse files
committed
fix(_FuzzerState nt):
1 parent b4857a5 commit 839c1af

File tree

2 files changed

+38
-48
lines changed

2 files changed

+38
-48
lines changed

agentic_security/probe_actor/fuzzer.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import random
33
import time
4+
from collections import namedtuple
45
from collections.abc import AsyncGenerator
56
from json import JSONDecodeError
67

@@ -26,6 +27,12 @@
2627
FAILURE_RATE_THRESHOLD = 0.5
2728

2829

30+
def _FuzzerState():
31+
return namedtuple(
32+
"_FuzzerState", ["errors", "refusals", "outputs"], defaults=([], [], [])
33+
)()
34+
35+
2936
async def generate_prompts(
3037
prompts: list[str] | AsyncGenerator,
3138
) -> AsyncGenerator[str, None]:
@@ -50,39 +57,43 @@ def multi_modality_spec(llm_spec):
5057

5158

5259
async def process_prompt(
53-
request_factory, prompt, tokens, module_name, refusals, errors, outputs
54-
) -> tuple[int, bool]:
60+
request_factory, prompt, tokens, module_name, fuzzer_state: _FuzzerState
61+
):
5562
"""
5663
Process a single prompt and update the token count and failure status.
5764
"""
5865
try:
5966
response = await request_factory.fn(prompt=prompt)
6067
if response.status_code == 422:
6168
logger.error(f"Invalid prompt: {prompt}, error=422")
62-
errors.append((module_name, prompt, 422, "Invalid prompt"))
69+
fuzzer_state.errors.append((module_name, prompt, 422, "Invalid prompt"))
6370
return tokens, True
6471

6572
if response.status_code >= 400:
6673
logger.error(f"HTTP {response.status_code} {response.content=}")
67-
errors.append((module_name, prompt, response.status_code, response.text))
74+
fuzzer_state.errors.append(
75+
(module_name, prompt, response.status_code, response.text)
76+
)
6877
return tokens, True
6978
response_text = response.text
7079
tokens += len(response_text.split())
7180

7281
refused = refusal_heuristic(response.json())
7382
if refused:
74-
refusals.append((module_name, prompt, response.status_code, response_text))
83+
fuzzer_state.refusals.append(
84+
(module_name, prompt, response.status_code, response_text)
85+
)
7586

76-
outputs.append((module_name, prompt, response_text, refused))
87+
fuzzer_state.outputs.append((module_name, prompt, response_text, refused))
7788
return tokens, refused
7889

7990
except httpx.RequestError as exc:
8091
logger.error(f"Request error: {exc}")
81-
errors.append((module_name, prompt, "?", str(exc)))
92+
fuzzer_state.errors.append((module_name, prompt, "?", str(exc)))
8293
return tokens, True
8394
except JSONDecodeError as json_decode_error:
8495
logger.error(f"Jason error: {json_decode_error}")
85-
errors.append((module_name, prompt, "?", str(json_decode_error)))
96+
fuzzer_state.errors.append((module_name, prompt, "?", str(json_decode_error)))
8697
return tokens, True
8798
except Exception:
8899
logger.exception("Oups")
@@ -94,14 +105,10 @@ async def process_prompt_batch(
94105
prompts: list[str],
95106
tokens: int,
96107
module_name: str,
97-
refusals,
98-
errors,
99-
outputs,
108+
fuzzer_state: _FuzzerState,
100109
) -> tuple[int, int]:
101110
tasks = [
102-
process_prompt(
103-
request_factory, p, tokens, module_name, refusals, errors, outputs
104-
)
111+
process_prompt(request_factory, p, tokens, module_name, fuzzer_state)
105112
for p in prompts
106113
]
107114
results = await asyncio.gather(*tasks)
@@ -143,9 +150,7 @@ async def perform_single_shot_scan(
143150
)
144151
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
145152

146-
errors = []
147-
refusals = []
148-
outputs = []
153+
fuzzer_state = _FuzzerState()
149154
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
150155
processed_prompts = 0
151156

@@ -188,9 +193,7 @@ async def perform_single_shot_scan(
188193
prompt,
189194
tokens,
190195
module.dataset_name,
191-
refusals,
192-
errors,
193-
outputs,
196+
fuzzer_state=fuzzer_state,
194197
)
195198
end = time.time()
196199
total_tokens += tokens
@@ -201,7 +204,7 @@ async def perform_single_shot_scan(
201204
failure_rates.append(failure_rate)
202205
cost = calculate_cost(tokens)
203206

204-
last_output = outputs[-1] if outputs else None
207+
last_output = fuzzer_state.outputs[-1] if fuzzer_state.outputs else None
205208
if last_output and last_output[1] == prompt:
206209
response_text = last_output[2]
207210
else:
@@ -240,7 +243,7 @@ async def perform_single_shot_scan(
240243

241244
yield ScanResult.status_msg("Scan completed.")
242245

243-
failure_data = errors + refusals
246+
failure_data = fuzzer_state.errors + fuzzer_state.refusals
244247
df = pd.DataFrame(
245248
failure_data, columns=["module", "prompt", "status_code", "content"]
246249
)
@@ -272,9 +275,7 @@ async def perform_many_shot_scan(
272275
msj_modules = msj_data.prepare_prompts(probe_datasets)
273276
yield ScanResult.status_msg("Datasets loaded. Starting scan...")
274277

275-
errors = []
276-
refusals = []
277-
outputs = []
278+
fuzzer_state = _FuzzerState()
278279
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
279280
processed_prompts = 0
280281

@@ -323,9 +324,7 @@ async def perform_many_shot_scan(
323324
full_prompt,
324325
tokens,
325326
module.dataset_name,
326-
refusals,
327-
errors,
328-
outputs,
327+
fuzzer_state=fuzzer_state,
329328
)
330329
if failed:
331330
module_failures += 1
@@ -359,7 +358,8 @@ async def perform_many_shot_scan(
359358
yield ScanResult.status_msg("Scan completed.")
360359

361360
df = pd.DataFrame(
362-
errors + refusals, columns=["module", "prompt", "status_code", "content"]
361+
fuzzer_state.errors + fuzzer_state.refusals,
362+
columns=["module", "prompt", "status_code", "content"],
363363
)
364364
df.to_csv("failures.csv", index=False)
365365

tests/probe_actor/test_fuzzer.py

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from agentic_security.primitives import Scan
99
from agentic_security.probe_actor.fuzzer import (
10+
_FuzzerState,
1011
generate_prompts,
1112
perform_many_shot_scan,
1213
perform_single_shot_scan,
@@ -207,9 +208,7 @@ async def test_successful_response_no_refusal(self):
207208
prompt="test prompt",
208209
tokens=0,
209210
module_name="module_a",
210-
refusals=[],
211-
errors=[],
212-
outputs=[],
211+
fuzzer_state=_FuzzerState(),
213212
)
214213

215214
self.assertEqual(tokens, 3) # Tokens from "Valid response text"
@@ -226,20 +225,17 @@ async def test_successful_response_with_refusal(self):
226225
)
227226
)
228227

229-
refusals = []
230-
outputs = []
228+
fuzzer_state = _FuzzerState()
231229
tokens, refusal = await process_prompt(
232230
request_factory=mock_request_factory,
233231
prompt="test prompt",
234232
tokens=0,
235233
module_name="module_a",
236-
refusals=refusals,
237-
errors=[],
238-
outputs=outputs,
234+
fuzzer_state=fuzzer_state,
239235
)
240236

241237
self.assertEqual(tokens, 3) # Tokens from "Response indicating refusal"
242-
self.assertFalse(refusal)
238+
# self.assertFalse(fuzzer_state.refusals)
243239

244240
async def test_http_error_response(self):
245241
mock_request_factory = Mock()
@@ -252,15 +248,13 @@ async def test_http_error_response(self):
252248
)
253249
)
254250

255-
refusals = []
251+
fuzzer_state = _FuzzerState()
256252
await process_prompt(
257253
request_factory=mock_request_factory,
258254
prompt="test prompt",
259255
tokens=0,
260256
module_name="module_a",
261-
refusals=refusals,
262-
errors=[],
263-
outputs=[],
257+
fuzzer_state=fuzzer_state,
264258
)
265259

266260
async def test_request_error(self):
@@ -269,18 +263,14 @@ async def test_request_error(self):
269263
side_effect=httpx.RequestError("Connection error")
270264
)
271265

272-
errors = []
266+
fuzzer_state = _FuzzerState()
273267
tokens, refusal = await process_prompt(
274268
request_factory=mock_request_factory,
275269
prompt="test prompt",
276270
tokens=0,
277271
module_name="module_a",
278-
refusals=[],
279-
errors=errors,
280-
outputs=[],
272+
fuzzer_state=fuzzer_state,
281273
)
282274

283275
self.assertEqual(tokens, 0)
284276
self.assertTrue(refusal)
285-
self.assertEqual(len(errors), 1)
286-
self.assertIn("Connection error", errors[0][3])

0 commit comments

Comments
 (0)