Skip to content

Commit 68fa348

Browse files
riyosharlundeen2
and authored
FEAT: Added VLSU Multimodal Dataset (#1309)
Co-authored-by: Richard Lundeen <[email protected]>
1 parent 7a7f8b6 commit 68fa348

File tree

3 files changed

+721
-20
lines changed

3 files changed

+721
-20
lines changed

pyrit/datasets/seed_datasets/remote/__init__.py

Lines changed: 83 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,36 +7,99 @@
77
Import concrete implementations to trigger registration.
88
"""
99

10-
from pyrit.datasets.seed_datasets.remote.aegis_ai_content_safety_dataset import _AegisContentSafetyDataset # noqa: F401
11-
from pyrit.datasets.seed_datasets.remote.aya_redteaming_dataset import _AyaRedteamingDataset # noqa: F401
12-
from pyrit.datasets.seed_datasets.remote.babelscape_alert_dataset import _BabelscapeAlertDataset # noqa: F401
13-
from pyrit.datasets.seed_datasets.remote.ccp_sensitive_prompts_dataset import _CCPSensitivePromptsDataset # noqa: F401
14-
from pyrit.datasets.seed_datasets.remote.darkbench_dataset import _DarkBenchDataset # noqa: F401
15-
from pyrit.datasets.seed_datasets.remote.equitymedqa_dataset import _EquityMedQADataset # noqa: F401
16-
from pyrit.datasets.seed_datasets.remote.forbidden_questions_dataset import _ForbiddenQuestionsDataset # noqa: F401
17-
from pyrit.datasets.seed_datasets.remote.harmbench_dataset import _HarmBenchDataset # noqa: F401
18-
from pyrit.datasets.seed_datasets.remote.harmbench_multimodal_dataset import _HarmBenchMultimodalDataset # noqa: F401
19-
from pyrit.datasets.seed_datasets.remote.jbb_behaviors_dataset import _JBBBehaviorsDataset # noqa: F401
20-
from pyrit.datasets.seed_datasets.remote.librai_do_not_answer_dataset import _LibrAIDoNotAnswerDataset # noqa: F401
10+
from pyrit.datasets.seed_datasets.remote.aegis_ai_content_safety_dataset import (
11+
_AegisContentSafetyDataset,
12+
) # noqa: F401
13+
from pyrit.datasets.seed_datasets.remote.aya_redteaming_dataset import (
14+
_AyaRedteamingDataset,
15+
) # noqa: F401
16+
from pyrit.datasets.seed_datasets.remote.babelscape_alert_dataset import (
17+
_BabelscapeAlertDataset,
18+
) # noqa: F401
19+
from pyrit.datasets.seed_datasets.remote.ccp_sensitive_prompts_dataset import (
20+
_CCPSensitivePromptsDataset,
21+
) # noqa: F401
22+
from pyrit.datasets.seed_datasets.remote.darkbench_dataset import (
23+
_DarkBenchDataset,
24+
) # noqa: F401
25+
from pyrit.datasets.seed_datasets.remote.equitymedqa_dataset import (
26+
_EquityMedQADataset,
27+
) # noqa: F401
28+
from pyrit.datasets.seed_datasets.remote.forbidden_questions_dataset import (
29+
_ForbiddenQuestionsDataset,
30+
) # noqa: F401
31+
from pyrit.datasets.seed_datasets.remote.harmbench_dataset import (
32+
_HarmBenchDataset,
33+
) # noqa: F401
34+
from pyrit.datasets.seed_datasets.remote.harmbench_multimodal_dataset import (
35+
_HarmBenchMultimodalDataset,
36+
) # noqa: F401
37+
from pyrit.datasets.seed_datasets.remote.jbb_behaviors_dataset import (
38+
_JBBBehaviorsDataset,
39+
) # noqa: F401
40+
from pyrit.datasets.seed_datasets.remote.librai_do_not_answer_dataset import (
41+
_LibrAIDoNotAnswerDataset,
42+
) # noqa: F401
2143
from pyrit.datasets.seed_datasets.remote.llm_latent_adversarial_training_dataset import ( # noqa: F401
2244
_LLMLatentAdversarialTrainingDataset,
2345
)
24-
from pyrit.datasets.seed_datasets.remote.medsafetybench_dataset import _MedSafetyBenchDataset # noqa: F401
25-
from pyrit.datasets.seed_datasets.remote.mlcommons_ailuminate_dataset import _MLCommonsAILuminateDataset # noqa: F401
46+
from pyrit.datasets.seed_datasets.remote.medsafetybench_dataset import (
47+
_MedSafetyBenchDataset,
48+
) # noqa: F401
49+
from pyrit.datasets.seed_datasets.remote.mlcommons_ailuminate_dataset import (
50+
_MLCommonsAILuminateDataset,
51+
) # noqa: F401
2652
from pyrit.datasets.seed_datasets.remote.multilingual_vulnerability_dataset import ( # noqa: F401
2753
_MultilingualVulnerabilityDataset,
2854
)
29-
from pyrit.datasets.seed_datasets.remote.pku_safe_rlhf_dataset import _PKUSafeRLHFDataset # noqa: F401
30-
from pyrit.datasets.seed_datasets.remote.red_team_social_bias_dataset import _RedTeamSocialBiasDataset # noqa: F401
31-
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import _RemoteDatasetLoader
32-
from pyrit.datasets.seed_datasets.remote.sorry_bench_dataset import _SorryBenchDataset # noqa: F401
33-
from pyrit.datasets.seed_datasets.remote.sosbench_dataset import _SOSBenchDataset # noqa: F401
34-
from pyrit.datasets.seed_datasets.remote.tdc23_redteaming_dataset import _TDC23RedteamingDataset # noqa: F401
55+
from pyrit.datasets.seed_datasets.remote.pku_safe_rlhf_dataset import (
56+
_PKUSafeRLHFDataset,
57+
) # noqa: F401
58+
from pyrit.datasets.seed_datasets.remote.red_team_social_bias_dataset import (
59+
_RedTeamSocialBiasDataset,
60+
) # noqa: F401
61+
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
62+
_RemoteDatasetLoader,
63+
)
64+
from pyrit.datasets.seed_datasets.remote.sorry_bench_dataset import (
65+
_SorryBenchDataset,
66+
) # noqa: F401
67+
from pyrit.datasets.seed_datasets.remote.sosbench_dataset import (
68+
_SOSBenchDataset,
69+
) # noqa: F401
70+
from pyrit.datasets.seed_datasets.remote.tdc23_redteaming_dataset import (
71+
_TDC23RedteamingDataset,
72+
) # noqa: F401
3573
from pyrit.datasets.seed_datasets.remote.transphobia_awareness_dataset import ( # noqa: F401
3674
_TransphobiaAwarenessDataset,
3775
)
38-
from pyrit.datasets.seed_datasets.remote.xstest_dataset import _XSTestDataset # noqa: F401
76+
from pyrit.datasets.seed_datasets.remote.vlsu_multimodal_dataset import (
77+
_VLSUMultimodalDataset,
78+
) # noqa: F401
79+
from pyrit.datasets.seed_datasets.remote.xstest_dataset import (
80+
_XSTestDataset,
81+
) # noqa: F401
3982

4083
__all__ = [
4184
"_RemoteDatasetLoader",
85+
"_AegisContentSafetyDataset",
86+
"_AyaRedteamingDataset",
87+
"_BabelscapeAlertDataset",
88+
"_CCPSensitivePromptsDataset",
89+
"_DarkBenchDataset",
90+
"_EquityMedQADataset",
91+
"_ForbiddenQuestionsDataset",
92+
"_HarmBenchDataset",
93+
"_HarmBenchMultimodalDataset",
94+
"_JBBBehaviorsDataset",
95+
"_LibrAIDoNotAnswerDataset",
96+
"_MedSafetyBenchDataset",
97+
"_MLCommonsAILuminateDataset",
98+
"_PKUSafeRLHFDataset",
99+
"_RedTeamSocialBiasDataset",
100+
"_SorryBenchDataset",
101+
"_SOSBenchDataset",
102+
"_TDC23RedteamingDataset",
103+
"_VLSUMultimodalDataset",
104+
"_XSTestDataset",
42105
]
Lines changed: 264 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,264 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT license.
3+
4+
import logging
5+
import uuid
6+
from enum import Enum
7+
from typing import List, Literal, Optional
8+
9+
from pyrit.common.net_utility import make_request_and_raise_if_error_async
10+
from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
11+
_RemoteDatasetLoader,
12+
)
13+
from pyrit.models import SeedDataset, SeedPrompt, data_serializer_factory
14+
15+
logger = logging.getLogger(__name__)
16+
17+
18+
class VLSUCategory(Enum):
    """Harm categories (C1-C15) annotated in the ML-VLSU dataset.

    Each member's value is the exact category label string used in the dataset's
    CSV, so members can be matched directly against the ``combined_category``
    column (e.g. ``VLSUCategory("C13: Jailbreaks")``).
    """

    SLURS_HATE_SPEECH = "C1: Slurs, Hate Speech, Hate Symbols"
    DISCRIMINATION = "C2: Discrimination and Unequal Treatment"
    DRUG_ABUSE = "C3: Drug Abuse"
    SELF_HARM = "C4: Self-Harm and Suicide"
    ANIMAL_VIOLENCE = "C5: Animal Violence and Gore"
    ADULT_EXPLICIT = "C6: Adult Explicit Sexual Material"
    ADULT_RACY = "C7: Adult Racy Material"
    WARFARE = "C8: Warfare and Armed Conflicts"
    INTERPERSONAL_VIOLENCE = "C9: Interpersonal Violence"
    WEAPONS = "C10: Weapons and Dangerous Objects"
    GORE_GRAPHIC = "C11: Gore and Graphic Content"
    TERRORISM = "C12: Terrorism and Violent Extremism"
    JAILBREAKS = "C13: Jailbreaks"
    INAUTHENTIC_PRACTICES = "C14: Inauthentic Practices/Fraud"
    HUMAN_EXPLOITATION = "C15: Human Exploitation"
36+
37+
38+
class _VLSUMultimodalDataset(_RemoteDatasetLoader):
    """
    Loader for the ML-VLSU multimodal dataset.

    The ML-VLSU dataset contains image-text pairs with human annotations across three
    dimensions: image-only, text-only, and joint safety. This dataset is specifically
    designed to test emergent multimodal safety issues where harm arises from the
    combination of safe text and safe images.

    Note: The first call may be slow as images need to be downloaded from remote URLs.
    Subsequent calls will be faster since images are cached locally.

    Reference: https://arxiv.org/abs/2501.01151
    """

    def __init__(
        self,
        *,
        source: str = "https://raw.githubusercontent.com/apple/ml-vlsu/main/data/VLSU.csv",
        source_type: Literal["public_url", "file"] = "public_url",
        categories: Optional[List[VLSUCategory]] = None,
        unsafe_grades: Optional[List[str]] = None,
    ):
        """
        Initialize the ML-VLSU multimodal dataset loader.

        Args:
            source: URL or file path to the VLSU CSV file. Defaults to official repository.
            source_type: The type of source ('public_url' or 'file').
            categories: List of VLSU categories (members of ``VLSUCategory`` or their
                string values) used to filter examples. If None, all categories are
                included (default).
            unsafe_grades: List of grades considered unsafe (e.g., ['unsafe', 'borderline']).
                Prompts are created only when the combined grade matches one of these values.
                Defaults to ['unsafe', 'borderline']. Possible options further include
                'safe' and 'not_sure'.

        Raises:
            ValueError: If any of the specified categories are invalid.
        """
        self.source = source
        self.source_type: Literal["public_url", "file"] = source_type
        self.categories = categories
        # Avoid the mutable-default-argument pitfall: copy the caller's list, or fall
        # back to the documented default when nothing (or None) was supplied.
        self.unsafe_grades: List[str] = list(unsafe_grades) if unsafe_grades is not None else ["unsafe", "borderline"]

        # Validate and normalize categories once. Normalizing to string values here
        # keeps fetch_dataset() working whether callers pass VLSUCategory members or
        # raw category strings (both forms pass validation).
        self._category_values: Optional[set] = None
        if categories is not None:
            valid_categories = {category.value for category in VLSUCategory}
            normalized = {cat.value if isinstance(cat, VLSUCategory) else cat for cat in categories}
            invalid_categories = normalized - valid_categories
            if invalid_categories:
                raise ValueError(f"Invalid VLSU categories: {', '.join(invalid_categories)}")
            self._category_values = normalized

    @property
    def dataset_name(self) -> str:
        """Return the dataset name."""
        return "ml_vlsu"

    async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset:
        """
        Fetch ML-VLSU multimodal examples and return as SeedDataset.

        For every example whose combined grade is in ``unsafe_grades``, two prompts
        sharing one ``prompt_group_id`` are created:
        - Text prompt (sequence=0, sent first)
        - Image prompt (sequence=1, sent second)

        Args:
            cache: Whether to cache the fetched dataset. Defaults to True.

        Returns:
            SeedDataset: A SeedDataset containing the multimodal examples.

        Raises:
            ValueError: If any example is missing required keys.
        """
        logger.info(f"Loading ML-VLSU dataset from {self.source}")

        required_keys = {
            "prompt",
            "web_path",
            "combined_category",
            "uuid",
            "image_grade",
            "consensus_text_grade",
            "consensus_combined_grade",
            "image_category",
            "text_category",
        }
        examples = self._fetch_from_url(
            source=self.source,
            source_type=self.source_type,
            cache=cache,
        )

        prompts = []
        failed_image_count = 0

        for example in examples:
            missing_keys = required_keys - example.keys()
            if missing_keys:
                raise ValueError(f"Missing keys in example: {', '.join(missing_keys)}")

            text = example.get("prompt")
            image_url = example.get("web_path")

            # Get grades; guard against None values so .lower() cannot crash.
            text_grade = (example.get("consensus_text_grade") or "").lower()
            image_grade = (example.get("image_grade") or "").lower()
            combined_grade = (example.get("consensus_combined_grade") or "").lower()

            combined_category = example.get("combined_category")

            # Filter by categories if specified (values normalized in __init__).
            if self._category_values is not None and combined_category not in self._category_values:
                continue

            # Only create prompts if the combined grade is considered unsafe.
            if combined_grade not in self.unsafe_grades:
                continue

            # Generate a shared group_id linking the text+image pair.
            group_id = uuid.uuid4()

            try:
                local_image_path = await self._fetch_and_save_image_async(image_url, str(group_id))

                # Create text prompt (sequence=0, sent first)
                text_prompt = SeedPrompt(
                    value=text,
                    data_type="text",
                    name="ML-VLSU Text",
                    dataset_name=self.dataset_name,
                    harm_categories=[combined_category],
                    description="Text component of ML-VLSU multimodal prompt.",
                    source=self.source,
                    prompt_group_id=group_id,
                    sequence=0,
                    metadata={
                        "category": combined_category,
                        "text_grade": text_grade,
                        "image_grade": image_grade,
                        "combined_grade": combined_grade,
                    },
                )

                # Create image prompt (sequence=1, sent second)
                image_prompt = SeedPrompt(
                    value=local_image_path,
                    data_type="image_path",
                    name="ML-VLSU Image",
                    dataset_name=self.dataset_name,
                    harm_categories=[combined_category],
                    description="Image component of ML-VLSU multimodal prompt.",
                    source=self.source,
                    prompt_group_id=group_id,
                    sequence=1,
                    metadata={
                        "category": combined_category,
                        "text_grade": text_grade,
                        "image_grade": image_grade,
                        "combined_grade": combined_grade,
                        "original_image_url": image_url,
                    },
                )

                prompts.append(text_prompt)
                prompts.append(image_prompt)

            except Exception as e:
                # Best-effort: skip examples whose image cannot be fetched instead of
                # failing the whole dataset load.
                failed_image_count += 1
                logger.warning(f"Failed to fetch image for combined prompt {group_id}: {e}")

        if failed_image_count > 0:
            logger.warning(f"[ML-VLSU] Skipped {failed_image_count} image(s) due to fetch failures")

        logger.info(f"Successfully loaded {len(prompts)} prompts from ML-VLSU dataset")

        return SeedDataset(seeds=prompts, dataset_name=self.dataset_name)

    async def _fetch_and_save_image_async(self, image_url: str, group_id: str) -> str:
        """
        Fetch and save an image from the ML-VLSU dataset.

        Args:
            image_url: URL to the image.
            group_id: Group ID for naming the cached file.

        Returns:
            Local path to the saved image.
        """
        filename = f"ml_vlsu_{group_id}.png"
        serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png")

        # Return the existing path if the image is already cached. The expected cache
        # location is derived from the deterministic filename built above.
        serializer.value = str(serializer._memory.results_path + serializer.data_sub_directory + f"/{filename}")
        try:
            if await serializer._memory.results_storage_io.path_exists(serializer.value):
                return serializer.value
        except Exception as e:
            logger.warning(f"[ML-VLSU] Failed to check if image for {group_id} exists in cache: {e}")

        # Add browser-like headers for better success rate when fetching images
        # hosted on arbitrary web servers.
        headers = {
            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
            "Accept": "image/webp,image/apng,image/*,*/*;q=0.8",
            "Accept-Language": "en-US,en;q=0.9",
            "Accept-Encoding": "gzip, deflate, br",
            "DNT": "1",
            "Connection": "keep-alive",
            "Upgrade-Insecure-Requests": "1",
        }

        # NOTE(review): 2-second timeout is aggressive for arbitrary image hosts;
        # failures are tolerated by the caller, but consider raising it.
        response = await make_request_and_raise_if_error_async(
            endpoint_uri=image_url,
            method="GET",
            headers=headers,
            timeout=2.0,
            follow_redirects=True,
        )
        # save_data appends the extension itself, so strip ".png" from the stem.
        await serializer.save_data(data=response.content, output_filename=filename.replace(".png", ""))

        return str(serializer.value)

0 commit comments

Comments
 (0)