|
| 1 | +# Copyright (c) Microsoft Corporation. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +import logging |
| 5 | +import uuid |
| 6 | +from enum import Enum |
| 7 | +from typing import List, Literal, Optional |
| 8 | + |
| 9 | +from pyrit.common.net_utility import make_request_and_raise_if_error_async |
| 10 | +from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import ( |
| 11 | + _RemoteDatasetLoader, |
| 12 | +) |
| 13 | +from pyrit.models import SeedDataset, SeedPrompt, data_serializer_factory |
| 14 | + |
| 15 | +logger = logging.getLogger(__name__) |
| 16 | + |
| 17 | + |
class VLSUCategory(Enum):
    """Harm categories annotated in the ML-VLSU dataset.

    Each member's value is the exact category string used in the dataset
    CSV's ``*_category`` columns (e.g. ``combined_category``), so values can
    be compared directly against CSV fields when filtering examples.
    """

    SLURS_HATE_SPEECH = "C1: Slurs, Hate Speech, Hate Symbols"
    DISCRIMINATION = "C2: Discrimination and Unequal Treatment"
    DRUG_ABUSE = "C3: Drug Abuse"
    SELF_HARM = "C4: Self-Harm and Suicide"
    ANIMAL_VIOLENCE = "C5: Animal Violence and Gore"
    ADULT_EXPLICIT = "C6: Adult Explicit Sexual Material"
    ADULT_RACY = "C7: Adult Racy Material"
    WARFARE = "C8: Warfare and Armed Conflicts"
    INTERPERSONAL_VIOLENCE = "C9: Interpersonal Violence"
    WEAPONS = "C10: Weapons and Dangerous Objects"
    GORE_GRAPHIC = "C11: Gore and Graphic Content"
    TERRORISM = "C12: Terrorism and Violent Extremism"
    JAILBREAKS = "C13: Jailbreaks"
    INAUTHENTIC_PRACTICES = "C14: Inauthentic Practices/Fraud"
    HUMAN_EXPLOITATION = "C15: Human Exploitation"
| 36 | + |
| 37 | + |
class _VLSUMultimodalDataset(_RemoteDatasetLoader):
    """
    Loader for the ML-VLSU multimodal dataset.

    The ML-VLSU dataset contains image-text pairs with human annotations across three dimensions:
    image-only, text-only, and joint safety. This dataset is specifically designed to test
    emergent multimodal safety issues where harm arises from the combination of safe text
    and safe images.

    Note: The first call may be slow as images need to be downloaded from remote URLs.
    Subsequent calls will be faster since images are cached locally.

    Reference: https://arxiv.org/abs/2501.01151
    """

    # Default grades treated as unsafe; kept as a class-level tuple so the
    # __init__ default can be a None sentinel instead of a mutable list.
    _DEFAULT_UNSAFE_GRADES = ("unsafe", "borderline")

    def __init__(
        self,
        *,
        source: str = "https://raw.githubusercontent.com/apple/ml-vlsu/main/data/VLSU.csv",
        source_type: Literal["public_url", "file"] = "public_url",
        categories: Optional[List[VLSUCategory]] = None,
        unsafe_grades: Optional[List[str]] = None,
    ):
        """
        Initialize the ML-VLSU multimodal dataset loader.

        Args:
            source: URL or file path to the VLSU CSV file. Defaults to official repository.
            source_type: The type of source ('public_url' or 'file').
            categories: List of VLSU categories to filter examples.
                If None, all categories are included (default).
            unsafe_grades: List of grades considered unsafe (e.g., ['unsafe', 'borderline']).
                Prompts are created only when the combined grade matches one of these values.
                Defaults to ['unsafe', 'borderline'] when None. Possible options further
                include 'safe' and 'not_sure'.

        Raises:
            ValueError: If any of the specified categories are invalid.
        """
        self.source = source
        self.source_type: Literal["public_url", "file"] = source_type
        self.categories = categories
        # Fix: the original signature used a mutable list (["unsafe", "borderline"])
        # as a default argument, which is shared across calls. Use a None sentinel
        # and copy so callers' lists are never aliased.
        self.unsafe_grades: List[str] = (
            list(unsafe_grades) if unsafe_grades is not None else list(self._DEFAULT_UNSAFE_GRADES)
        )

        # Validate categories if provided; accept both enum members and raw strings.
        if categories is not None:
            valid_categories = {category.value for category in VLSUCategory}
            invalid_categories = {
                cat.value if isinstance(cat, VLSUCategory) else cat for cat in categories
            } - valid_categories
            if invalid_categories:
                raise ValueError(f"Invalid VLSU categories: {', '.join(invalid_categories)}")

    @property
    def dataset_name(self) -> str:
        """Return the dataset name."""
        return "ml_vlsu"

    async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset:
        """
        Fetch ML-VLSU multimodal examples and return as SeedDataset.

        For every CSV row whose combined (text+image) grade is in ``self.unsafe_grades``
        (and, if configured, whose combined category is in ``self.categories``), a pair of
        prompts sharing one ``prompt_group_id`` is produced: a text prompt (sequence=0)
        followed by an image prompt (sequence=1). Note that only the combined grade gates
        prompt creation; the per-dimension text/image grades are recorded as metadata only.
        Rows whose image cannot be downloaded are skipped with a warning instead of
        failing the whole load.

        Args:
            cache: Whether to cache the fetched dataset. Defaults to True.

        Returns:
            SeedDataset: A SeedDataset containing the multimodal examples.

        Raises:
            ValueError: If any example is missing required keys.
        """
        logger.info(f"Loading ML-VLSU dataset from {self.source}")

        required_keys = {
            "prompt",
            "web_path",
            "combined_category",
            "uuid",
            "image_grade",
            "consensus_text_grade",
            "consensus_combined_grade",
            "image_category",
            "text_category",
        }
        examples = self._fetch_from_url(
            source=self.source,
            source_type=self.source_type,
            cache=cache,
        )

        prompts = []
        failed_image_count = 0

        for example in examples:
            missing_keys = required_keys - example.keys()
            if missing_keys:
                raise ValueError(f"Missing keys in example: {', '.join(missing_keys)}")

            text = example.get("prompt")
            image_url = example.get("web_path")
            # Stable per-row identifier from the dataset itself; used as the image
            # cache key so repeated loads reuse previously downloaded images.
            example_uuid = str(example.get("uuid"))

            # Get grades (lower-cased for comparison with unsafe_grades)
            text_grade = example.get("consensus_text_grade", "").lower()
            image_grade = example.get("image_grade", "").lower()
            combined_grade = example.get("consensus_combined_grade", "").lower()

            combined_category = example.get("combined_category")

            # Filter by categories if specified
            if self.categories is not None:
                category_values = {cat.value for cat in self.categories}
                if combined_category not in category_values:
                    continue

            # Only create prompts if combined_grade is unsafe or borderline
            if combined_grade not in self.unsafe_grades:
                continue

            # Generate a shared group_id for the text+image pair
            group_id = uuid.uuid4()

            try:
                # Fix: key the cached image file on the row's stable dataset uuid,
                # not on group_id — group_id is a fresh uuid4 every run, so the
                # cache-existence check in _fetch_and_save_image_async could never
                # hit across calls, defeating the documented local caching.
                local_image_path = await self._fetch_and_save_image_async(image_url, example_uuid)

                # Create text prompt (sequence=0, sent first)
                text_prompt = SeedPrompt(
                    value=text,
                    data_type="text",
                    name="ML-VLSU Text",
                    dataset_name=self.dataset_name,
                    harm_categories=[combined_category],
                    description="Text component of ML-VLSU multimodal prompt.",
                    source=self.source,
                    prompt_group_id=group_id,
                    sequence=0,
                    metadata={
                        "category": combined_category,
                        "text_grade": text_grade,
                        "image_grade": image_grade,
                        "combined_grade": combined_grade,
                    },
                )

                # Create image prompt (sequence=1, sent second)
                image_prompt = SeedPrompt(
                    value=local_image_path,
                    data_type="image_path",
                    name="ML-VLSU Image",
                    dataset_name=self.dataset_name,
                    harm_categories=[combined_category],
                    description="Image component of ML-VLSU multimodal prompt.",
                    source=self.source,
                    prompt_group_id=group_id,
                    sequence=1,
                    metadata={
                        "category": combined_category,
                        "text_grade": text_grade,
                        "image_grade": image_grade,
                        "combined_grade": combined_grade,
                        "original_image_url": image_url,
                    },
                )

                prompts.append(text_prompt)
                prompts.append(image_prompt)

            except Exception as e:
                # Best-effort: a single unreachable image should not abort the load.
                failed_image_count += 1
                logger.warning(f"Failed to fetch image for combined prompt {group_id}: {e}")

        if failed_image_count > 0:
            logger.warning(f"[ML-VLSU] Skipped {failed_image_count} image(s) due to fetch failures")

        logger.info(f"Successfully loaded {len(prompts)} prompts from ML-VLSU dataset")

        return SeedDataset(seeds=prompts, dataset_name=self.dataset_name)

    async def _fetch_and_save_image_async(self, image_url: str, group_id: str) -> str:
        """
        Fetch and save an image from the ML-VLSU dataset, reusing a cached copy if present.

        Args:
            image_url: URL to the image.
            group_id: Stable identifier used to name the cached file.

        Returns:
            Local path to the saved image.
        """
        filename = f"ml_vlsu_{group_id}.png"
        serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png")

        # Return existing path if image already exists.
        # Fix: the cache path must end in the actual cache filename (the original
        # line contained a garbled placeholder instead of the filename).
        serializer.value = str(serializer._memory.results_path + serializer.data_sub_directory + f"/{filename}")
        try:
            if await serializer._memory.results_storage_io.path_exists(serializer.value):
                return serializer.value
        except Exception as e:
            logger.warning(f"[ML-VLSU] Failed to check if image for {group_id} exists in cache: {e}")

        # Add browser-like headers for better success rate
        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
            "Accept": "image/webp,image/apng,image/*,*/*;q=0.8",
            "Accept-Language": "en-US,en;q=0.9",
            "Accept-Encoding": "gzip, deflate, br",
            "DNT": "1",
            "Connection": "keep-alive",
            "Upgrade-Insecure-Requests": "1",
        }

        response = await make_request_and_raise_if_error_async(
            endpoint_uri=image_url,
            method="GET",
            headers=headers,
            timeout=2.0,
            follow_redirects=True,
        )
        # NOTE(review): save_data presumably appends the configured "png" extension
        # itself, hence the stem without ".png" — confirm against the serializer API.
        await serializer.save_data(data=response.content, output_filename=filename.replace(".png", ""))

        return str(serializer.value)