Skip to content

Commit b1e2dc8

Browse files
committed
Add missing documentation in fuzzer.py
1 parent 41ecc3c commit b1e2dc8

File tree

1 file changed

+154
-3
lines changed

1 file changed

+154
-3
lines changed

agentic_security/probe_actor/fuzzer.py

Lines changed: 154 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,18 @@
2929
async def generate_prompts(
3030
prompts: list[str] | AsyncGenerator,
3131
) -> AsyncGenerator[str, None]:
32+
"""
33+
Asynchronously generates and yields individual prompts.
34+
35+
If the input is a list of strings, the function sequentially yields each string.
36+
If the input is an asynchronous generator, it forwards each generated prompt.
37+
38+
Args:
39+
prompts (list[str] | AsyncGenerator): A list of strings or an asynchronous generator of prompts.
40+
41+
Yields:
42+
str: An individual prompt from the list or the asynchronous generator.
43+
"""
3244
if isinstance(prompts, list):
3345
for prompt in prompts:
3446
yield prompt
@@ -38,6 +50,20 @@ async def generate_prompts(
3850

3951

4052
def multi_modality_spec(llm_spec):
53+
"""
54+
Returns the appropriate request adapter based on the modality of the LLM specification.
55+
56+
Depending on the modality of `llm_spec`, the function selects the corresponding request adapter.
57+
If the modality is IMAGE or AUDIO, it returns an adapter for handling the respective type.
58+
If the modality is TEXT or an unrecognized type, it returns `llm_spec` as is.
59+
60+
Args:
61+
llm_spec: An object containing modality information for the LLM.
62+
63+
Returns:
64+
RequestAdapter | llm_spec: An instance of the appropriate request adapter
65+
or the original `llm_spec` if no adaptation is needed.
66+
"""
4167
match llm_spec.modality:
4268
case Modality.IMAGE:
4369
return image_generator.RequestAdapter(llm_spec)
@@ -53,7 +79,24 @@ async def process_prompt(
5379
request_factory, prompt, tokens, module_name, refusals, errors, outputs
5480
) -> tuple[int, bool]:
5581
"""
56-
Process a single prompt and update the token count and failure status.
82+
Processes a single prompt using the provided request factory and updates tracking lists.
83+
84+
This function sends the given `prompt` to the `request_factory`, checks for errors, and updates
85+
the `tokens`, `refusals`, `errors`, and `outputs` lists accordingly. If the request fails or
86+
the response indicates a refusal, the function records the issue and returns the updated token count
87+
along with a boolean indicating whether the prompt was refused.
88+
89+
Args:
90+
request_factory: An object with a `fn` method used to send the prompt.
91+
prompt (str): The input prompt to be processed.
92+
tokens (int): The current token count, which will be updated.
93+
module_name (str): The name of the module handling the request.
94+
refusals (list): A list to store prompts that were refused.
95+
errors (list): A list to store prompts that encountered errors.
96+
outputs (list): A list to store processed prompt outputs.
97+
98+
Returns:
99+
tuple[int, bool]: Updated token count and a boolean indicating if the prompt was refused.
57100
"""
58101
try:
59102
response = await request_factory.fn(prompt=prompt)
@@ -95,6 +138,27 @@ async def process_prompt_batch(
95138
errors,
96139
outputs,
97140
) -> tuple[int, int]:
141+
"""
142+
Processes a batch of prompts asynchronously and aggregates the results.
143+
144+
This function sends multiple prompts concurrently using `process_prompt`,
145+
collects the token count and failure status for each prompt, and returns
146+
the total number of tokens processed and the number of failed prompts.
147+
148+
Args:
149+
request_factory: An object with a `fn` method used to send the prompts.
150+
prompts (list[str]): A list of input prompts to be processed.
151+
tokens (int): The initial token count, which will be updated.
152+
module_name (str): The name of the module handling the request.
153+
refusals (list): A list to store prompts that were refused.
154+
errors (list): A list to store prompts that encountered errors.
155+
outputs (list): A list to store processed prompt outputs.
156+
157+
Returns:
158+
tuple[int, int]:
159+
- Total number of tokens processed.
160+
- Number of failed prompts.
161+
"""
98162
tasks = [
99163
process_prompt(
100164
request_factory, p, tokens, module_name, refusals, errors, outputs
@@ -108,6 +172,20 @@ async def process_prompt_batch(
108172

109173

110174
async def with_error_handling(agen):
175+
"""
176+
Wraps an asynchronous generator with error handling.
177+
178+
This function iterates over an asynchronous generator, yielding its values.
179+
If an exception occurs, it logs the error and yields a failure message.
180+
Finally, it ensures that a completion message is always yielded.
181+
182+
Args:
183+
agen: An asynchronous generator that produces scan results.
184+
185+
Yields:
186+
ScanResult: Either a successful result, an error message if an
187+
exception occurs, or a completion message at the end.
188+
"""
111189
try:
112190
async for t in agen:
113191
yield t
@@ -127,7 +205,29 @@ async def perform_single_shot_scan(
127205
stop_event: asyncio.Event = None,
128206
secrets: dict[str, str] = {},
129207
) -> AsyncGenerator[str, None]:
130-
"""Perform a standard security scan."""
208+
"""
209+
Perform a standard security scan using a given request factory.
210+
211+
This function processes security scan prompts from selected datasets while
212+
respecting a predefined token budget. It supports optimization, failure tracking,
213+
and early stopping based on budget constraints or user intervention.
214+
215+
Args:
216+
request_factory: A factory function that generates requests for processing prompts.
217+
max_budget (int): The maximum token budget for the scan.
218+
datasets (list[dict[str, str]], optional): A list of datasets containing security prompts.
219+
tools_inbox: Optional additional tools for processing (default: None).
220+
optimize (bool, optional): Whether to enable failure rate optimization (default: False).
221+
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
222+
secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
223+
224+
Yields:
225+
str: JSON-encoded scan results or status messages.
226+
227+
The function iterates over prompts, processes them asynchronously, and updates
228+
failure statistics and token usage. If the scan exceeds the budget or failure rate is too high,
229+
it stops execution. Results are saved to a CSV file upon completion.
230+
"""
131231
max_budget = max_budget * BUDGET_MULTIPLIER
132232
selected_datasets = [m for m in datasets if m["selected"]]
133233
request_factory = multi_modality_spec(request_factory)
@@ -256,7 +356,32 @@ async def perform_many_shot_scan(
256356
max_ctx_length: int = 10_000,
257357
secrets: dict[str, str] = {},
258358
) -> AsyncGenerator[str, None]:
259-
"""Perform a multi-step security scan with probe injection."""
359+
"""
360+
Perform a multi-step security scan with probe injection.
361+
362+
This function executes a security scan while periodically injecting probe datasets
363+
to test system robustness. It tracks failures, optimizes scan efficiency,
364+
and ensures adherence to a predefined token budget.
365+
366+
Args:
367+
request_factory: A factory function that generates requests for processing prompts.
368+
max_budget (int): The maximum token budget for the scan.
369+
datasets (list[dict[str, str]], optional): The main datasets for scanning.
370+
probe_datasets (list[dict[str, str]], optional): Additional datasets for probe injection.
371+
tools_inbox: Optional tools for additional processing (default: None).
372+
optimize (bool, optional): Whether to enable failure rate optimization (default: False).
373+
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
374+
probe_frequency (float, optional): The probability of probe injection (default: 0.2).
375+
max_ctx_length (int, optional): The maximum context length before resetting (default: 10,000 tokens).
376+
secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
377+
378+
Yields:
379+
str: JSON-encoded scan results or status messages.
380+
381+
This function iterates over prompts, injects probe prompts at random intervals,
382+
processes them asynchronously, and tracks failure rates. If failure rates exceed a threshold
383+
or budget is exhausted, the scan is stopped early. Results are saved to a CSV file upon completion.
384+
"""
260385
request_factory = multi_modality_spec(request_factory)
261386
# Load main and probe datasets
262387
yield ScanResult.status_msg("Loading datasets...")
@@ -367,6 +492,32 @@ def scan_router(
367492
tools_inbox=None,
368493
stop_event: asyncio.Event = None,
369494
):
495+
"""
496+
Route scan requests to the appropriate scanning function.
497+
498+
This function determines whether to perform a multi-step or single-shot
499+
security scan based on the provided scan parameters.
500+
501+
Args:
502+
request_factory: A factory function to generate requests for processing prompts.
503+
scan_parameters (Scan): An object containing the parameters for the scan, including:
504+
- enableMultiStepAttack (bool): Whether to perform a multi-step scan.
505+
- maxBudget (int): The maximum token budget for the scan.
506+
- datasets (list[dict[str, str]]): The datasets to scan.
507+
- probe_datasets (list[dict[str, str]], optional): Datasets for probe injection (multi-step only).
508+
- optimize (bool): Whether to enable optimization.
509+
- secrets (dict[str, str], optional): A dictionary of secrets for authentication.
510+
tools_inbox: Optional tools for additional processing (default: None).
511+
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
512+
513+
Returns:
514+
A function wrapped with `with_error_handling`, which executes either:
515+
- `perform_many_shot_scan` for multi-step scanning.
516+
- `perform_single_shot_scan` for single-shot scanning.
517+
518+
The function ensures that the appropriate scanning method is chosen based on
519+
the `enableMultiStepAttack` flag in `scan_parameters`.
520+
"""
370521
if scan_parameters.enableMultiStepAttack:
371522
return with_error_handling(
372523
perform_many_shot_scan(

0 commit comments

Comments
 (0)