2929async def generate_prompts (
3030 prompts : list [str ] | AsyncGenerator ,
3131) -> AsyncGenerator [str , None ]:
32+ """
33+ Asynchronously generates and yields individual prompts.
34+
35+ If the input is a list of strings, the function sequentially yields each string.
36+ If the input is an asynchronous generator, it forwards each generated prompt.
37+
38+ Args:
39+ prompts (list[str] | AsyncGenerator): A list of strings or an asynchronous generator of prompts.
40+
41+ Yields:
42+ str: An individual prompt from the list or the asynchronous generator.
43+ """
3244 if isinstance (prompts , list ):
3345 for prompt in prompts :
3446 yield prompt
@@ -38,6 +50,20 @@ async def generate_prompts(
3850
3951
4052def multi_modality_spec (llm_spec ):
53+ """
54+ Returns the appropriate request adapter based on the modality of the LLM specification.
55+
56+ Depending on the modality of `llm_spec`, the function selects the corresponding request adapter.
57+ If the modality is IMAGE or AUDIO, it returns an adapter for handling the respective type.
58+ If the modality is TEXT or an unrecognized type, it returns `llm_spec` as is.
59+
60+ Args:
61+ llm_spec: An object containing modality information for the LLM.
62+
63+ Returns:
64+ RequestAdapter | llm_spec: An instance of the appropriate request adapter
65+ or the original `llm_spec` if no adaptation is needed.
66+ """
4167 match llm_spec .modality :
4268 case Modality .IMAGE :
4369 return image_generator .RequestAdapter (llm_spec )
@@ -53,7 +79,24 @@ async def process_prompt(
5379 request_factory , prompt , tokens , module_name , refusals , errors , outputs
5480) -> tuple [int , bool ]:
5581 """
56- Process a single prompt and update the token count and failure status.
82+ Processes a single prompt using the provided request factory and updates tracking lists.
83+
84+ This function sends the given `prompt` to the `request_factory`, checks for errors, and updates
85+ the `tokens`, `refusals`, `errors`, and `outputs` lists accordingly. If the request fails or
86+ the response indicates a refusal, the function records the issue and returns the updated token count
87+ along with a boolean indicating whether the prompt was refused.
88+
89+ Args:
90+ request_factory: An object with a `fn` method used to send the prompt.
91+ prompt (str): The input prompt to be processed.
92+ tokens (int): The current token count, which will be updated.
93+ module_name (str): The name of the module handling the request.
94+ refusals (list): A list to store prompts that were refused.
95+ errors (list): A list to store prompts that encountered errors.
96+ outputs (list): A list to store processed prompt outputs.
97+
98+ Returns:
99+ tuple[int, bool]: Updated token count and a boolean indicating if the prompt was refused.
57100 """
58101 try :
59102 response = await request_factory .fn (prompt = prompt )
@@ -95,6 +138,27 @@ async def process_prompt_batch(
95138 errors ,
96139 outputs ,
97140) -> tuple [int , int ]:
141+ """
142+ Processes a batch of prompts asynchronously and aggregates the results.
143+
144+ This function sends multiple prompts concurrently using `process_prompt`,
145+ collects the token count and failure status for each prompt, and returns
146+ the total number of tokens processed and the number of failed prompts.
147+
148+ Args:
149+ request_factory: An object with a `fn` method used to send the prompts.
150+ prompts (list[str]): A list of input prompts to be processed.
151+ tokens (int): The initial token count, which will be updated.
152+ module_name (str): The name of the module handling the request.
153+ refusals (list): A list to store prompts that were refused.
154+ errors (list): A list to store prompts that encountered errors.
155+ outputs (list): A list to store processed prompt outputs.
156+
157+ Returns:
158+ tuple[int, int]:
159+ - Total number of tokens processed.
160+ - Number of failed prompts.
161+ """
98162 tasks = [
99163 process_prompt (
100164 request_factory , p , tokens , module_name , refusals , errors , outputs
@@ -108,6 +172,20 @@ async def process_prompt_batch(
108172
109173
110174async def with_error_handling (agen ):
175+ """
176+ Wraps an asynchronous generator with error handling.
177+
178+ This function iterates over an asynchronous generator, yielding its values.
179+ If an exception occurs, it logs the error and yields a failure message.
180+ Finally, it ensures that a completion message is always yielded.
181+
182+ Args:
183+ agen: An asynchronous generator that produces scan results.
184+
185+ Yields:
186+ ScanResult: Either a successful result, an error message if an
187+ exception occurs, or a completion message at the end.
188+ """
111189 try :
112190 async for t in agen :
113191 yield t
@@ -127,7 +205,29 @@ async def perform_single_shot_scan(
127205 stop_event : asyncio .Event = None ,
128206 secrets : dict [str , str ] = {},
129207) -> AsyncGenerator [str , None ]:
130- """Perform a standard security scan."""
208+ """
209+ Perform a standard security scan using a given request factory.
210+
211+ This function processes security scan prompts from selected datasets while
212+ respecting a predefined token budget. It supports optimization, failure tracking,
213+ and early stopping based on budget constraints or user intervention.
214+
215+ Args:
216+ request_factory: A factory function that generates requests for processing prompts.
217+ max_budget (int): The maximum token budget for the scan.
218+ datasets (list[dict[str, str]], optional): A list of datasets containing security prompts.
219+ tools_inbox: Optional additional tools for processing (default: None).
220+ optimize (bool, optional): Whether to enable failure rate optimization (default: False).
221+ stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
222+ secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
223+
224+ Yields:
225+ str: JSON-encoded scan results or status messages.
226+
227+ The function iterates over prompts, processes them asynchronously, and updates
228+ failure statistics and token usage. If the scan exceeds the budget or failure rate is too high,
229+ it stops execution. Results are saved to a CSV file upon completion.
230+ """
131231 max_budget = max_budget * BUDGET_MULTIPLIER
132232 selected_datasets = [m for m in datasets if m ["selected" ]]
133233 request_factory = multi_modality_spec (request_factory )
@@ -256,7 +356,32 @@ async def perform_many_shot_scan(
256356 max_ctx_length : int = 10_000 ,
257357 secrets : dict [str , str ] = {},
258358) -> AsyncGenerator [str , None ]:
259- """Perform a multi-step security scan with probe injection."""
359+ """
360+ Perform a multi-step security scan with probe injection.
361+
362+ This function executes a security scan while periodically injecting probe datasets
363+ to test system robustness. It tracks failures, optimizes scan efficiency,
364+ and ensures adherence to a predefined token budget.
365+
366+ Args:
367+ request_factory: A factory function that generates requests for processing prompts.
368+ max_budget (int): The maximum token budget for the scan.
369+ datasets (list[dict[str, str]], optional): The main datasets for scanning.
370+ probe_datasets (list[dict[str, str]], optional): Additional datasets for probe injection.
371+ tools_inbox: Optional tools for additional processing (default: None).
372+ optimize (bool, optional): Whether to enable failure rate optimization (default: False).
373+ stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
374+ probe_frequency (float, optional): The probability of probe injection (default: 0.2).
375+ max_ctx_length (int, optional): The maximum context length before resetting (default: 10,000 tokens).
376+ secrets (dict[str, str], optional): A dictionary of secrets for authentication (default: {}).
377+
378+ Yields:
379+ str: JSON-encoded scan results or status messages.
380+
381+ This function iterates over prompts, injects probe prompts at random intervals,
382+ processes them asynchronously, and tracks failure rates. If failure rates exceed a threshold
383+ or budget is exhausted, the scan is stopped early. Results are saved to a CSV file upon completion.
384+ """
260385 request_factory = multi_modality_spec (request_factory )
261386 # Load main and probe datasets
262387 yield ScanResult .status_msg ("Loading datasets..." )
@@ -367,6 +492,32 @@ def scan_router(
367492 tools_inbox = None ,
368493 stop_event : asyncio .Event = None ,
369494):
495+ """
496+ Route scan requests to the appropriate scanning function.
497+
498+ This function determines whether to perform a multi-step or single-shot
499+ security scan based on the provided scan parameters.
500+
501+ Args:
502+ request_factory: A factory function to generate requests for processing prompts.
503+ scan_parameters (Scan): An object containing the parameters for the scan, including:
504+ - enableMultiStepAttack (bool): Whether to perform a multi-step scan.
505+ - maxBudget (int): The maximum token budget for the scan.
506+ - datasets (list[dict[str, str]]): The datasets to scan.
507+ - probe_datasets (list[dict[str, str]], optional): Datasets for probe injection (multi-step only).
508+ - optimize (bool): Whether to enable optimization.
509+ - secrets (dict[str, str], optional): A dictionary of secrets for authentication.
510+ tools_inbox: Optional tools for additional processing (default: None).
511+ stop_event (asyncio.Event, optional): An event to signal early termination (default: None).
512+
513+ Returns:
514+ A function wrapped with `with_error_handling`, which executes either:
515+ - `perform_many_shot_scan` for multi-step scanning.
516+ - `perform_single_shot_scan` for single-shot scanning.
517+
518+ The function ensures that the appropriate scanning method is chosen based on
519+ the `enableMultiStepAttack` flag in `scan_parameters`.
520+ """
370521 if scan_parameters .enableMultiStepAttack :
371522 return with_error_handling (
372523 perform_many_shot_scan (
0 commit comments