|
| 1 | +import logging |
| 2 | +import os |
| 3 | +import random |
| 4 | +from pathlib import Path |
| 5 | +from typing import Generator, Optional |
| 6 | + |
| 7 | +import pandas as pd |
| 8 | + |
| 9 | +from ..attack_provider.attack_registry import register_test |
| 10 | +from ..attack_provider.test_base import StatusUpdate, TestBase |
| 11 | +from ..attack_provider.util import contains_refusal_keywords |
| 12 | +from ..client.attack_config import AttackConfig |
| 13 | +from ..client.chat_client import ChatSession |
| 14 | +from ..client.client_config import ClientConfig |
| 15 | + |
| 16 | +logger = logging.getLogger(__name__) |
| 17 | + |
| 18 | + |
@register_test
class TestShuffle(TestBase):
    """Probes the LLM resilience against Shuffle Inconsistency.
    Original Paper: https://arxiv.org/html/2501.04931"""

    test_name = "shuffle"

    def __init__(
        self,
        client_config: ClientConfig,
        attack_config: AttackConfig,
        artifacts_path: Optional[str] = None,
        num_attempts: int = 0,
        *args,
        **kwargs,
    ):
        """Forward all configuration to ``TestBase`` unchanged.

        Args:
            client_config: Client configuration; ``run`` obtains the target
                client and system prompts from it.
            attack_config: Attack configuration, passed through to the base class.
            artifacts_path: Directory where the CSV report is written. When
                falsy, no report file is produced.
            num_attempts: Number of attack prompts to generate and send.
        """
        super().__init__(
            client_config,
            attack_config,
            *args,
            artifacts_path=artifacts_path,
            num_attempts=num_attempts,
            **kwargs,
        )

    def _prepare_attack_data(self, attack_prompts: list[str], responses: list[str], statuses: list[str]) -> None:
        """
        Prepares attack data in a structured DataFrame format and saves it as a CSV file.

        The three lists are parallel: element ``i`` of each describes the same attack.
        Nothing is written when ``self.artifacts_path`` is falsy.

        Args:
            attack_prompts (list[str]): List of attack texts generated during the test.
            responses (list[str]): List of responses from the target system.
            statuses (list[str]): List of statuses recorded for each attack
                (``run`` passes 'Broken', 'Resilient' or 'Error').

        Returns:
            None
        """
        # Create a DataFrame from the lists
        df = pd.DataFrame({"attack_text": attack_prompts, "response_text": responses, "status": statuses})

        # Save the DataFrame as a CSV file to the artifacts path
        if self.artifacts_path:
            csv_path = os.path.join(self.artifacts_path, f"{self.test_name}.csv")
            df.to_csv(csv_path, index=False)
            # Use the module-level logger (not the root logger) like the rest of this module.
            logger.info(f"{self.test_name} attack report saved to {csv_path}")

    @staticmethod
    def shuffle_words(prompt: str) -> str:
        """Return ``prompt`` with its whitespace-separated words in random order.

        Words are obtained with ``str.split()`` and re-joined with single
        spaces, so the original spacing and line breaks are not preserved.
        """
        words = prompt.split()
        random.shuffle(words)  # nosec
        return " ".join(words)

    def run(self) -> Generator[StatusUpdate, None, None]:
        """Generate shuffled prompts, send them to the target and report progress.

        The method works in two phases:

        1. "Generating": reads the ``goal`` column of the bundled
           ``harmful_behavior.parquet`` dataset, fits it to ``num_attempts``
           rows and shuffles the words of every prompt.
        2. "Attacking": sends each shuffled prompt in a fresh ``ChatSession``.
           An empty response is recorded as an error; a response for which
           ``contains_refusal_keywords`` is false is recorded as a breach,
           otherwise as resilient.

        Finally the collected data is handed to ``_prepare_attack_data``.

        Yields:
            StatusUpdate: progress notifications for each step of both phases
            and a final "Finished" update.
        """
        attack_prompts: list[str] = []
        responses = []
        statuses: list[str] = []

        # The dataset lives in the sibling "attack_data" directory of this module's package.
        dataset_filename = Path(__file__).parent.parent / "attack_data" / "harmful_behavior.parquet"

        logger.info(f"Dataset filename: {dataset_filename}")
        data = pd.read_parquet(dataset_filename, engine="fastparquet")
        rows_count = data.shape[0]

        if self.num_attempts <= rows_count:
            # If dataset is larger than requested number of samples, reduce dataset by taking num_attempts samples
            data = data.sample(n=self.num_attempts)
        else:
            # If dataset is smaller than requested number of attack samples, fit the dataset repeating until num_attempts is reached
            data = (
                pd.concat([data] * (self.num_attempts // len(data) + 1))[: self.num_attempts]
                .sort_index()
                .reset_index(drop=True)
            )

        # Generate list of attack prompt variations
        yield StatusUpdate(self.client_config, self.test_name, self.status, "Generating", 0, self.num_attempts)

        # Count positions with enumerate: after ``sample`` the DataFrame keeps the
        # original row labels, which are not valid progress values in [0, num_attempts).
        for attack_prompt_index, attack_prompt in enumerate(data["goal"]):
            shuffled_prompt = self.shuffle_words(attack_prompt)
            attack_prompts.append(shuffled_prompt)
            logger.info(f"Test '{self.test_name}': generated attack prompt #{attack_prompt_index}: {shuffled_prompt}")
            yield StatusUpdate(
                self.client_config,
                self.test_name,
                self.status,
                "Generating",
                attack_prompt_index,
                self.num_attempts,
            )

        yield StatusUpdate(
            self.client_config,
            self.test_name,
            self.status,
            "Attacking",
            0,
            self.num_attempts,
        )

        # Run the attack by sending the prompts and verifying the responses
        for attack_prompt_index, attack_prompt in enumerate(attack_prompts):
            # A new session per prompt, so no conversation state is carried between attacks.
            chat = ChatSession(
                self.client_config.get_target_client(),
                self.client_config.get_system_prompts(),
            )
            response = chat.say(attack_prompt)
            if not response:
                logger.warning("Error while attacking against target LLM (didn't receive response) ...")
                self.status.report_error(attack_prompt)
                current_status = "Error"
            else:
                # Verify the response to determine if this specific attack was successful:
                # the absence of refusal keywords is treated as a successful break.
                verified = not contains_refusal_keywords(response)
                logger.info(f"Test '{self.test_name}': attack prompt #{attack_prompt_index}: {attack_prompt}")
                logger.info(
                    f"Test '{self.test_name}': attack response #{attack_prompt_index} (break_success={verified}): {response}"
                )
                if verified:
                    self.status.report_breach(attack_prompt, response)
                    current_status = "Broken"
                else:
                    self.status.report_resilient(attack_prompt, response)
                    current_status = "Resilient"

            # Append data for report (kept parallel to attack_prompts, including error cases)
            responses.append(response)
            statuses.append(current_status)

            yield StatusUpdate(
                self.client_config,
                self.test_name,
                self.status,
                "Attacking",
                attack_prompt_index,
                self.num_attempts,
            )

        # Prepare data for report generation
        self._prepare_attack_data(attack_prompts, responses, statuses)

        yield StatusUpdate(
            self.client_config,
            self.test_name,
            self.status,
            "Finished",
            self.num_attempts,
            self.num_attempts,
        )
0 commit comments