Skip to content

Commit 4332e4a

Browse files
authored
Merge pull request #182 from nemanjaASE/issue-166-missing-documentation
Add missing documentation in fuzzer.py
2 parents b9802fd + e871443 commit 4332e4a

File tree

1 file changed

+119
-42
lines changed

1 file changed

+119
-42
lines changed

agentic_security/probe_actor/fuzzer.py

Lines changed: 119 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,18 @@
2929
async def generate_prompts(
3030
prompts: list[str] | AsyncGenerator,
3131
) -> AsyncGenerator[str, None]:
32-
"""Convert list of prompts or async generator to a unified async generator."""
32+
"""
33+
Asynchronously generates and yields individual prompts.
34+
35+
If the input is a list of strings, the function sequentially yields each string.
36+
If the input is an asynchronous generator, it forwards each generated prompt.
37+
38+
Args:
39+
prompts (list[str] | AsyncGenerator): A list of strings or an asynchronous generator of prompts.
40+
41+
Yields:
42+
str: An individual prompt from the list or the asynchronous generator.
43+
"""
3344
if isinstance(prompts, list):
3445
for prompt in prompts:
3546
yield prompt
@@ -39,7 +50,20 @@ async def generate_prompts(
3950

4051

4152
def get_modality_adapter(llm_spec):
42-
"""Get the appropriate modality adapter based on the LLM spec."""
53+
"""
54+
Returns the appropriate request adapter based on the modality of the LLM specification.
55+
56+
Depending on the modality of `llm_spec`, the function selects the corresponding request adapter.
57+
If the modality is IMAGE or AUDIO, it returns an adapter for handling the respective type.
58+
If the modality is TEXT or an unrecognized type, it returns `llm_spec` as is.
59+
60+
Args:
61+
llm_spec: An object containing modality information for the LLM.
62+
63+
Returns:
64+
RequestAdapter | llm_spec: An instance of the appropriate request adapter
65+
or the original `llm_spec` if no adaptation is needed.
66+
"""
4367
match llm_spec.modality:
4468
case Modality.IMAGE:
4569
return image_generator.RequestAdapter(llm_spec)
@@ -59,17 +83,22 @@ async def process_prompt(
5983
fuzzer_state: FuzzerState,
6084
) -> tuple[int, bool]:
6185
"""
62-
Process a single prompt and update the token count and failure status.
86+
Processes a single prompt using the provided request factory and updates tracking lists.
87+
88+
This function sends the given `prompt` to the `request_factory`, checks for errors, and updates
89+
the `tokens`, `refusals`, `errors`, and `outputs` lists accordingly. If the request fails or
90+
the response indicates a refusal, the function records the issue and returns the updated token count
91+
along with a boolean indicating whether the prompt was refused.
6392
6493
Args:
65-
request_factory: The factory for creating requests
66-
prompt: The prompt to process
67-
tokens: Current token count
68-
module_name: Name of the module being processed
94+
request_factory: An object with a `fn` method used to send the prompt.
95+
prompt (str): The input prompt to be processed.
96+
tokens (int): The current token count, which will be updated.
97+
module_name (str): The name of the module handling the request.
6998
fuzzer_state: State tracking object for the fuzzer
7099
71100
Returns:
72-
Tuple of (updated token count, whether the prompt resulted in a failure)
101+
tuple[int, bool]: Updated token count and a boolean indicating if the prompt was refused.
73102
"""
74103
try:
75104
response = await request_factory.fn(prompt=prompt)
@@ -122,17 +151,23 @@ async def process_prompt_batch(
122151
fuzzer_state: FuzzerState,
123152
) -> tuple[int, int]:
124153
"""
125-
Process a batch of prompts in parallel.
154+
Processes a batch of prompts asynchronously and aggregates the results.
155+
156+
This function sends multiple prompts concurrently using `process_prompt`,
157+
collects the token count and failure status for each prompt, and returns
158+
the total number of tokens processed and the number of failed prompts.
126159
127160
Args:
128-
request_factory: The factory for creating requests
129-
prompts: List of prompts to process
130-
tokens: Current token count
131-
module_name: Name of the module being processed
161+
request_factory: An object with a `fn` method used to send the prompts.
162+
prompts (list[str]): A list of input prompts to be processed.
163+
tokens (int): The initial token count, which will be updated.
164+
module_name (str): The name of the module handling the request.
132165
fuzzer_state: State tracking object for the fuzzer
133166
134167
Returns:
135-
Tuple of (updated token count, number of failures)
168+
tuple[int, int]:
169+
- Total number of tokens processed.
170+
- Number of failed prompts.
136171
"""
137172
tasks = [
138173
process_prompt(request_factory, p, tokens, module_name, fuzzer_state)
@@ -268,7 +303,20 @@ async def scan_module(
268303

269304

270305
async def with_error_handling(agen):
271-
"""Wrapper to handle errors in async generators."""
306+
"""
307+
Wraps an asynchronous generator with error handling.
308+
309+
This function iterates over an asynchronous generator, yielding its values.
310+
If an exception occurs, it logs the error and yields a failure message.
311+
Finally, it ensures that a completion message is always yielded.
312+
313+
Args:
314+
agen: An asynchronous generator that produces scan results.
315+
316+
Yields:
317+
ScanResult: Either a successful result, an error message if an
318+
exception occurs, or a completion message at the end.
319+
"""
272320
try:
273321
async for t in agen:
274322
yield t
@@ -289,19 +337,27 @@ async def perform_single_shot_scan(
289337
secrets: dict[str, str] = {},
290338
) -> AsyncGenerator[str, None]:
291339
"""
292-
Perform a standard security scan across all selected datasets.
340+
Perform a standard security scan using a given request factory.
341+
342+
This function processes security scan prompts from selected datasets while
343+
respecting a predefined token budget. It supports optimization, failure tracking,
344+
and early stopping based on budget constraints or user intervention.
293345
294346
Args:
295-
request_factory: The factory for creating requests
296-
max_budget: Maximum token budget
297-
datasets: List of datasets to scan
298-
tools_inbox: Tools inbox
299-
optimize: Whether to use optimization
300-
stop_event: Event to stop scanning
301-
secrets: Secrets to use in the scan
347+
request_factory: A factory function that generates requests for processing prompts.
348+
max_budget (int): The maximum token budget for the scan.
349+
datasets (list[dict[str, str]], optional): A list of datasets containing security prompts.
350+
tools_inbox: Optional additional tools for processing (default: None).
351+
optimize (bool, optional): Whether to enable failure rate optimization (default: False).
352+
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
353+
secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
302354
303355
Yields:
304-
ScanResult objects as the scan progresses
356+
str: JSON-encoded scan results or status messages.
357+
358+
The function iterates over prompts, processes them asynchronously, and updates
359+
failure statistics and token usage. If the scan exceeds the budget or failure rate is too high,
360+
it stops execution. Results are saved to a CSV file upon completion.
305361
"""
306362
max_budget = max_budget * BUDGET_MULTIPLIER
307363
selected_datasets = [m for m in datasets if m["selected"]]
@@ -362,23 +418,30 @@ async def perform_many_shot_scan(
362418
"""
363419
Perform a multi-step security scan with probe injection.
364420
421+
This function executes a security scan while periodically injecting probe datasets
422+
to test system robustness. It tracks failures, optimizes scan efficiency,
423+
and ensures adherence to a predefined token budget.
424+
365425
Args:
366-
request_factory: The factory for creating requests
367-
max_budget: Maximum token budget
368-
datasets: List of datasets to scan
369-
probe_datasets: List of probe datasets to inject
370-
tools_inbox: Tools inbox
371-
optimize: Whether to use optimization
372-
stop_event: Event to stop scanning
373-
probe_frequency: Frequency of probe injection
374-
max_ctx_length: Maximum context length
375-
secrets: Secrets to use in the scan
426+
request_factory: A factory function that generates requests for processing prompts.
427+
max_budget (int): The maximum token budget for the scan.
428+
datasets (list[dict[str, str]], optional): The main datasets for scanning.
429+
probe_datasets (list[dict[str, str]], optional): Additional datasets for probe injection.
430+
tools_inbox: Optional tools for additional processing (default: None).
431+
optimize (bool, optional): Whether to enable failure rate optimization (default: False).
432+
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
433+
probe_frequency (float, optional): The probability of probe injection (default: 0.2).
434+
max_ctx_length (int, optional): The maximum context length before resetting (default: 10,000 tokens).
435+
secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
376436
377437
Yields:
378-
ScanResult objects as the scan progresses
438+
str: JSON-encoded scan results or status messages.
439+
440+
This function iterates over prompts, injects probe prompts at random intervals,
441+
processes them asynchronously, and tracks failure rates. If failure rates exceed a threshold
442+
or budget is exhausted, the scan is stopped early. Results are saved to a CSV file upon completion.
379443
"""
380444
request_factory = get_modality_adapter(request_factory)
381-
382445
# Load main and probe datasets
383446
yield ScanResult.status_msg("Loading datasets...")
384447
prompt_modules = prepare_prompts(
@@ -473,16 +536,30 @@ def scan_router(
473536
stop_event: asyncio.Event | None = None,
474537
):
475538
"""
476-
Route to the appropriate scan function based on scan parameters.
539+
Route scan requests to the appropriate scanning function.
540+
541+
This function determines whether to perform a multi-step or single-shot
542+
security scan based on the provided scan parameters.
477543
478544
Args:
479-
request_factory: The factory for creating requests
480-
scan_parameters: Scan parameters
481-
tools_inbox: Tools inbox
482-
stop_event: Event to stop scanning
545+
request_factory: A factory function to generate requests for processing prompts.
546+
scan_parameters (Scan): An object containing the parameters for the scan, including:
547+
- enableMultiStepAttack (bool): Whether to perform a multi-step scan.
548+
- maxBudget (int): The maximum token budget for the scan.
549+
- datasets (list[dict[str, str]]): The datasets to scan.
550+
- probe_datasets (list[dict[str, str]], optional): Datasets for probe injection (multi-step only).
551+
- optimize (bool): Whether to enable optimization.
552+
- secrets (dict[str, str], optional): A dictionary of secrets for authentication.
553+
tools_inbox: Optional tools for additional processing (default: None).
554+
stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
483555
484556
Returns:
485-
Async generator of scan results
557+
A function wrapped with `with_error_handling`, which executes either:
558+
- `perform_many_shot_scan` for multi-step scanning.
559+
- `perform_single_shot_scan` for single-shot scanning.
560+
561+
The function ensures that the appropriate scanning method is chosen based on
562+
the `enableMultiStepAttack` flag in `scan_parameters`.
486563
"""
487564
if scan_parameters.enableMultiStepAttack:
488565
return with_error_handling(

0 commit comments

Comments
 (0)