Skip to content

Commit ac3f2f8

Browse files
committed
feat(move optimizer to module lvl):
1 parent bd6d2f3 commit ac3f2f8

File tree

1 file changed

+13
-26
lines changed

1 file changed

+13
-26
lines changed

agentic_security/probe_actor/fuzzer.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,6 @@ async def scan_module(
153153
max_budget: int = 0,
154154
total_tokens: int = 0,
155155
optimize: bool = False,
156-
optimizer=None,
157156
stop_event: asyncio.Event | None = None,
158157
) -> AsyncGenerator[dict[str, Any], None]:
159158
"""
@@ -168,7 +167,6 @@ async def scan_module(
168167
max_budget: Maximum token budget
169168
total_tokens: Current token count
170169
optimize: Whether to use optimization
171-
optimizer: The optimizer to use
172170
stop_event: Event to stop scanning
173171
174172
Yields:
@@ -180,6 +178,15 @@ async def scan_module(
180178
failure_rates = []
181179
should_stop = False
182180

181+
# Initialize optimizer if optimization is enabled
182+
optimizer = (
183+
Optimizer(
184+
[Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS
185+
)
186+
if optimize
187+
else None
188+
)
189+
183190
module_size = 0 if module.lazy else len(module.prompts)
184191
logger.info(f"Scanning {module.dataset_name} {module_size}")
185192

@@ -313,14 +320,6 @@ async def perform_single_shot_scan(
313320
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
314321
processed_prompts = 0
315322

316-
optimizer = (
317-
Optimizer(
318-
[Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS
319-
)
320-
if optimize
321-
else None
322-
)
323-
324323
total_tokens = 0
325324
for module in prompt_modules:
326325
module_gen = scan_module(
@@ -332,7 +331,6 @@ async def perform_single_shot_scan(
332331
max_budget=max_budget,
333332
total_tokens=total_tokens,
334333
optimize=optimize,
335-
optimizer=optimizer,
336334
stop_event=stop_event,
337335
)
338336
try:
@@ -396,13 +394,6 @@ async def perform_many_shot_scan(
396394
total_prompts = sum(len(m.prompts) for m in prompt_modules if not m.lazy)
397395
processed_prompts = 0
398396

399-
optimizer = (
400-
Optimizer(
401-
[Real(0, 1)], base_estimator="GP", n_initial_points=INITIAL_OPTIMIZER_POINTS
402-
)
403-
if optimize
404-
else None
405-
)
406397
failure_rates = []
407398

408399
for module in prompt_modules:
@@ -466,14 +457,10 @@ async def perform_many_shot_scan(
466457
).model_dump_json()
467458

468459
if optimize and len(failure_rates) >= MIN_FAILURE_SAMPLES:
469-
next_point = optimizer.ask()
470-
optimizer.tell(next_point, -failure_rate)
471-
best_failure_rate = -optimizer.get_result().fun
472-
if best_failure_rate > FAILURE_RATE_THRESHOLD:
473-
yield ScanResult.status_msg(
474-
f"High failure rate detected ({best_failure_rate:.2%}). Stopping this module..."
475-
)
476-
break
460+
yield ScanResult.status_msg(
461+
f"High failure rate detected ({failure_rate:.2%}). Stopping this module..."
462+
)
463+
break
477464

478465
yield ScanResult.status_msg("Scan completed.")
479466
fuzzer_state.export_failures("failures.csv")

0 commit comments

Comments
 (0)