diff --git a/pyrit/datasets/seed_datasets/remote/__init__.py b/pyrit/datasets/seed_datasets/remote/__init__.py
index a4a8098dd..7e1159fd0 100644
--- a/pyrit/datasets/seed_datasets/remote/__init__.py
+++ b/pyrit/datasets/seed_datasets/remote/__init__.py
@@ -7,36 +7,99 @@ Import concrete implementations to trigger registration.
 """
 
-from pyrit.datasets.seed_datasets.remote.aegis_ai_content_safety_dataset import _AegisContentSafetyDataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.aya_redteaming_dataset import _AyaRedteamingDataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.babelscape_alert_dataset import _BabelscapeAlertDataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.ccp_sensitive_prompts_dataset import _CCPSensitivePromptsDataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.darkbench_dataset import _DarkBenchDataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.equitymedqa_dataset import _EquityMedQADataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.forbidden_questions_dataset import _ForbiddenQuestionsDataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.harmbench_dataset import _HarmBenchDataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.harmbench_multimodal_dataset import _HarmBenchMultimodalDataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.jbb_behaviors_dataset import _JBBBehaviorsDataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.librai_do_not_answer_dataset import _LibrAIDoNotAnswerDataset  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.aegis_ai_content_safety_dataset import (
+    _AegisContentSafetyDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.aya_redteaming_dataset import (
+    _AyaRedteamingDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.babelscape_alert_dataset import (
+    _BabelscapeAlertDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.ccp_sensitive_prompts_dataset import (
+    _CCPSensitivePromptsDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.darkbench_dataset import (
+    _DarkBenchDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.equitymedqa_dataset import (
+    _EquityMedQADataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.forbidden_questions_dataset import (
+    _ForbiddenQuestionsDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.harmbench_dataset import (
+    _HarmBenchDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.harmbench_multimodal_dataset import (
+    _HarmBenchMultimodalDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.jbb_behaviors_dataset import (
+    _JBBBehaviorsDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.librai_do_not_answer_dataset import (
+    _LibrAIDoNotAnswerDataset,
+)  # noqa: F401
 from pyrit.datasets.seed_datasets.remote.llm_latent_adversarial_training_dataset import (  # noqa: F401
     _LLMLatentAdversarialTrainingDataset,
 )
-from pyrit.datasets.seed_datasets.remote.medsafetybench_dataset import _MedSafetyBenchDataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.mlcommons_ailuminate_dataset import _MLCommonsAILuminateDataset  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.medsafetybench_dataset import (
+    _MedSafetyBenchDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.mlcommons_ailuminate_dataset import (
+    _MLCommonsAILuminateDataset,
+)  # noqa: F401
 from pyrit.datasets.seed_datasets.remote.multilingual_vulnerability_dataset import (  # noqa: F401
     _MultilingualVulnerabilityDataset,
 )
-from pyrit.datasets.seed_datasets.remote.pku_safe_rlhf_dataset import _PKUSafeRLHFDataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.red_team_social_bias_dataset import _RedTeamSocialBiasDataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import _RemoteDatasetLoader
-from pyrit.datasets.seed_datasets.remote.sorry_bench_dataset import _SorryBenchDataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.sosbench_dataset import _SOSBenchDataset  # noqa: F401
-from pyrit.datasets.seed_datasets.remote.tdc23_redteaming_dataset import _TDC23RedteamingDataset  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.pku_safe_rlhf_dataset import (
+    _PKUSafeRLHFDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.red_team_social_bias_dataset import (
+    _RedTeamSocialBiasDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
+    _RemoteDatasetLoader,
+)
+from pyrit.datasets.seed_datasets.remote.sorry_bench_dataset import (
+    _SorryBenchDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.sosbench_dataset import (
+    _SOSBenchDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.tdc23_redteaming_dataset import (
+    _TDC23RedteamingDataset,
+)  # noqa: F401
 from pyrit.datasets.seed_datasets.remote.transphobia_awareness_dataset import (  # noqa: F401
     _TransphobiaAwarenessDataset,
 )
-from pyrit.datasets.seed_datasets.remote.xstest_dataset import _XSTestDataset  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.vlsu_multimodal_dataset import (
+    _VLSUMultimodalDataset,
+)  # noqa: F401
+from pyrit.datasets.seed_datasets.remote.xstest_dataset import (
+    _XSTestDataset,
+)  # noqa: F401
 
 __all__ = [
     "_RemoteDatasetLoader",
+    "_AegisContentSafetyDataset",
+    "_AyaRedteamingDataset",
+    "_BabelscapeAlertDataset",
+    "_CCPSensitivePromptsDataset",
+    "_DarkBenchDataset",
+    "_EquityMedQADataset",
+    "_ForbiddenQuestionsDataset",
+    "_HarmBenchDataset",
+    "_HarmBenchMultimodalDataset",
+    "_JBBBehaviorsDataset",
+    "_LibrAIDoNotAnswerDataset",
+    "_MedSafetyBenchDataset",
+    "_MLCommonsAILuminateDataset",
+    "_PKUSafeRLHFDataset",
+    "_RedTeamSocialBiasDataset",
+    "_SorryBenchDataset",
+    "_SOSBenchDataset",
+    "_TDC23RedteamingDataset",
+    "_VLSUMultimodalDataset",
+    "_XSTestDataset",
 ]
diff --git a/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py
new file mode 100644
index 000000000..be992fb59
--- /dev/null
+++ b/pyrit/datasets/seed_datasets/remote/vlsu_multimodal_dataset.py
@@ -0,0 +1,264 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import logging
+import uuid
+from enum import Enum
+from typing import List, Literal, Optional
+
+from pyrit.common.net_utility import make_request_and_raise_if_error_async
+from pyrit.datasets.seed_datasets.remote.remote_dataset_loader import (
+    _RemoteDatasetLoader,
+)
+from pyrit.models import SeedDataset, SeedPrompt, data_serializer_factory
+
+logger = logging.getLogger(__name__)
+
+
+class VLSUCategory(Enum):
+    """Categories in the ML-VLSU dataset."""
+
+    SLURS_HATE_SPEECH = "C1: Slurs, Hate Speech, Hate Symbols"
+    DISCRIMINATION = "C2: Discrimination and Unequal Treatment"
+    DRUG_ABUSE = "C3: Drug Abuse"
+    SELF_HARM = "C4: Self-Harm and Suicide"
+    ANIMAL_VIOLENCE = "C5: Animal Violence and Gore"
+    ADULT_EXPLICIT = "C6: Adult Explicit Sexual Material"
+    ADULT_RACY = "C7: Adult Racy Material"
+    WARFARE = "C8: Warfare and Armed Conflicts"
+    INTERPERSONAL_VIOLENCE = "C9: Interpersonal Violence"
+    WEAPONS = "C10: Weapons and Dangerous Objects"
+    GORE_GRAPHIC = "C11: Gore and Graphic Content"
+    TERRORISM = "C12: Terrorism and Violent Extremism"
+    JAILBREAKS = "C13: Jailbreaks"
+    INAUTHENTIC_PRACTICES = "C14: Inauthentic Practices/Fraud"
+    HUMAN_EXPLOITATION = "C15: Human Exploitation"
+
+
+class _VLSUMultimodalDataset(_RemoteDatasetLoader):
+    """
+    Loader for the ML-VLSU multimodal dataset.
+
+    The ML-VLSU dataset contains image-text pairs with human annotations across three dimensions:
+    image-only, text-only, and joint safety. It is specifically designed to test emergent
+    multimodal safety issues where harm arises from the combination of safe text and safe images.
+
+    Note: The first call may be slow because images are downloaded from remote URLs.
+    Subsequent calls are faster since images are cached locally.
+
+    Reference: https://arxiv.org/abs/2501.01151
+    """
+
+    def __init__(
+        self,
+        *,
+        source: str = "https://raw.githubusercontent.com/apple/ml-vlsu/main/data/VLSU.csv",
+        source_type: Literal["public_url", "file"] = "public_url",
+        categories: Optional[List[VLSUCategory]] = None,
+        unsafe_grades: Optional[List[str]] = ["unsafe", "borderline"],
+    ):
+        """
+        Initialize the ML-VLSU multimodal dataset loader.
+
+        Args:
+            source: URL or file path to the VLSU CSV file. Defaults to the official repository.
+            source_type: The type of source ('public_url' or 'file').
+            categories: List of VLSU categories to filter examples.
+                If None, all categories are included (default).
+            unsafe_grades: List of grades considered unsafe (e.g., ['unsafe', 'borderline']).
+                A text+image pair is created only when the example's combined grade matches one
+                of these values. Defaults to ['unsafe', 'borderline']; other possible values are
+                'safe' and 'not_sure'.
+
+        Raises:
+            ValueError: If any of the specified categories are invalid.
+        """
+        self.source = source
+        self.source_type: Literal["public_url", "file"] = source_type
+        self.categories = categories
+        self.unsafe_grades = unsafe_grades
+
+        # Validate categories if provided
+        if categories is not None:
+            valid_categories = {category.value for category in VLSUCategory}
+            invalid_categories = (
+                set(cat.value if isinstance(cat, VLSUCategory) else cat for cat in categories) - valid_categories
+            )
+            if invalid_categories:
+                raise ValueError(f"Invalid VLSU categories: {', '.join(invalid_categories)}")
+
+    @property
+    def dataset_name(self) -> str:
+        """Return the dataset name."""
+        return "ml_vlsu"
+
+    async def fetch_dataset(self, *, cache: bool = True) -> SeedDataset:
+        """
+        Fetch ML-VLSU multimodal examples and return them as a SeedDataset.
+
+        For each example whose combined grade matches ``unsafe_grades`` (and whose combined
+        category passes the optional category filter), a text prompt (sequence=0) and an image
+        prompt (sequence=1) are created as a pair sharing the same prompt group id. Examples
+        whose image cannot be downloaded are skipped.
+
+        Args:
+            cache: Whether to cache the fetched dataset. Defaults to True.
+
+        Returns:
+            SeedDataset: A SeedDataset containing the multimodal examples.
+
+        Raises:
+            ValueError: If any example is missing required keys.
+        """
+        logger.info(f"Loading ML-VLSU dataset from {self.source}")
+
+        required_keys = {
+            "prompt",
+            "web_path",
+            "combined_category",
+            "uuid",
+            "image_grade",
+            "consensus_text_grade",
+            "consensus_combined_grade",
+            "image_category",
+            "text_category",
+        }
+        examples = self._fetch_from_url(
+            source=self.source,
+            source_type=self.source_type,
+            cache=cache,
+        )
+
+        prompts = []
+        failed_image_count = 0
+
+        for example in examples:
+            missing_keys = required_keys - example.keys()
+            if missing_keys:
+                raise ValueError(f"Missing keys in example: {', '.join(missing_keys)}")
+
+            text = example.get("prompt")
+            image_url = example.get("web_path")
+
+            # Get grades
+            text_grade = example.get("consensus_text_grade", "").lower()
+            image_grade = example.get("image_grade", "").lower()
+            combined_grade = example.get("consensus_combined_grade", "").lower()
+
+            # Get categories
+            text_category = example.get("text_category")
+            image_category = example.get("image_category")
+            combined_category = example.get("combined_category")
+
+            # Filter by categories if specified
+            if self.categories is not None:
+                category_values = {cat.value for cat in self.categories}
+                if combined_category not in category_values:
+                    continue
+
+            # Only create prompts if the combined grade is considered unsafe
+            if combined_grade not in self.unsafe_grades:
+                continue
+
+            # Generate a shared group_id for the text+image pair
+            group_id = uuid.uuid4()
+
+            try:
+                local_image_path = await self._fetch_and_save_image_async(image_url, str(group_id))
+
+                # Create text prompt (sequence=0, sent first)
+                text_prompt = SeedPrompt(
+                    value=text,
+                    data_type="text",
+                    name="ML-VLSU Text",
+                    dataset_name=self.dataset_name,
+                    harm_categories=[combined_category],
+                    description="Text component of ML-VLSU multimodal prompt.",
+                    source=self.source,
+                    prompt_group_id=group_id,
+                    sequence=0,
+                    metadata={
+                        "category": combined_category,
+                        "text_grade": text_grade,
+                        "image_grade": image_grade,
+                        "combined_grade": combined_grade,
+                    },
+                )
+
+                # Create image prompt (sequence=1, sent second)
+                image_prompt = SeedPrompt(
+                    value=local_image_path,
+                    data_type="image_path",
+                    name="ML-VLSU Image",
+                    dataset_name=self.dataset_name,
+                    harm_categories=[combined_category],
+                    description="Image component of ML-VLSU multimodal prompt.",
+                    source=self.source,
+                    prompt_group_id=group_id,
+                    sequence=1,
+                    metadata={
+                        "category": combined_category,
+                        "text_grade": text_grade,
+                        "image_grade": image_grade,
+                        "combined_grade": combined_grade,
+                        "original_image_url": image_url,
+                    },
+                )
+
+                prompts.append(text_prompt)
+                prompts.append(image_prompt)
+
+            except Exception as e:
+                failed_image_count += 1
+                logger.warning(f"Failed to fetch image for combined prompt {group_id}: {e}")
+
+        if failed_image_count > 0:
+            logger.warning(f"[ML-VLSU] Skipped {failed_image_count} image(s) due to fetch failures")
+
+        logger.info(f"Successfully loaded {len(prompts)} prompts from ML-VLSU dataset")
+
+        return SeedDataset(seeds=prompts, dataset_name=self.dataset_name)
+
+    async def _fetch_and_save_image_async(self, image_url: str, group_id: str) -> str:
+        """
+        Fetch and save an image from the ML-VLSU dataset.
+
+        Args:
+            image_url: URL to the image.
+            group_id: Group ID for naming the cached file.
+
+        Returns:
+            Local path to the saved image.
+        """
+        filename = f"ml_vlsu_{group_id}.png"
+        serializer = data_serializer_factory(category="seed-prompt-entries", data_type="image_path", extension="png")
+
+        # Return existing path if image already exists
+        serializer.value = str(serializer._memory.results_path + serializer.data_sub_directory + f"/{filename}")
+        try:
+            if await serializer._memory.results_storage_io.path_exists(serializer.value):
+                return serializer.value
+        except Exception as e:
+            logger.warning(f"[ML-VLSU] Failed to check if image for {group_id} exists in cache: {e}")
+
+        # Add browser-like headers for better success rate
+        headers = {
+            "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36",
+            "Accept": "image/webp,image/apng,image/*,*/*;q=0.8",
+            "Accept-Language": "en-US,en;q=0.9",
+            "Accept-Encoding": "gzip, deflate, br",
+            "DNT": "1",
+            "Connection": "keep-alive",
+            "Upgrade-Insecure-Requests": "1",
+        }
+
+        response = await make_request_and_raise_if_error_async(
+            endpoint_uri=image_url,
+            method="GET",
+            headers=headers,
+            timeout=2.0,
+            follow_redirects=True,
+        )
+        await serializer.save_data(data=response.content, output_filename=filename.replace(".png", ""))
+
+        return str(serializer.value)
diff --git a/tests/unit/datasets/test_vlsu_multimodal_dataset.py b/tests/unit/datasets/test_vlsu_multimodal_dataset.py
new file mode 100644
index 000000000..8af2ee54f
--- /dev/null
+++ b/tests/unit/datasets/test_vlsu_multimodal_dataset.py
@@ -0,0 +1,374 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+import uuid
+from unittest.mock import patch
+
+import pytest
+
+from pyrit.datasets.seed_datasets.remote.vlsu_multimodal_dataset import (
+    VLSUCategory,
+    _VLSUMultimodalDataset,
+)
+from pyrit.memory import SQLiteMemory
+from pyrit.memory.central_memory import CentralMemory
+from pyrit.models import SeedDataset
+
+
+class TestVLSUMultimodalDataset:
+    """Unit tests for _VLSUMultimodalDataset."""
+
+    @pytest.fixture(autouse=True)
+    def setup_memory(self):
+        """Set up memory instance for image downloads."""
+        memory = SQLiteMemory()
+        CentralMemory.set_memory_instance(memory)
+        yield
+        CentralMemory.set_memory_instance(None)
+
+    def test_dataset_name(self):
+        """Test that dataset_name property returns correct value."""
+        dataset = _VLSUMultimodalDataset()
+        assert dataset.dataset_name == "ml_vlsu"
+
+    def test_init_with_categories(self):
+        """Test initialization with category filtering."""
+        categories = [VLSUCategory.SLURS_HATE_SPEECH, VLSUCategory.DISCRIMINATION]
+        dataset = _VLSUMultimodalDataset(categories=categories)
+        assert dataset.categories == categories
+
+    def test_init_with_invalid_categories(self):
+        """Test that invalid categories raise ValueError."""
+        with pytest.raises(ValueError, match="Invalid VLSU categories"):
+            _VLSUMultimodalDataset(categories=["invalid_category"])
+
+    def test_init_with_unsafe_grades(self):
+        """Test initialization with custom unsafe grades."""
+        dataset = _VLSUMultimodalDataset(unsafe_grades=["unsafe"])
+        assert dataset.unsafe_grades == ["unsafe"]
+
+    @pytest.mark.asyncio
+    async def test_fetch_dataset_combined_unsafe_creates_pair(self):
+        """Test that a text+image pair is created when combined_grade is unsafe."""
+        test_uuid = str(uuid.uuid4())
+        mock_data = [
+            {
+                "prompt": "Text that becomes unsafe with image",
+                "web_path": "https://example.com/image.jpg",
+                "combined_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "text_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "image_category": "C2: Discrimination and Unequal Treatment",
+                "uuid": test_uuid,
+                "consensus_text_grade": "safe",
+                "image_grade": "safe",
+                "consensus_combined_grade": "unsafe",
+            }
+        ]
+
+        dataset_loader = _VLSUMultimodalDataset()
+
+        with (
+            patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data),
+            patch.object(
+                dataset_loader,
+                "_fetch_and_save_image_async",
+                return_value="/fake/path/image.png",
+            ),
+        ):
+            dataset = await dataset_loader.fetch_dataset(cache=False)
+
+        assert isinstance(dataset, SeedDataset)
+        assert len(dataset.seeds) == 2  # Text + Image pair
+
+        # Find text and image prompts
+        text_prompt = next(s for s in dataset.seeds if s.data_type == "text")
+        image_prompt = next(s for s in dataset.seeds if s.data_type == "image_path")
+
+        # Verify they share the same group_id
+        assert text_prompt.prompt_group_id == image_prompt.prompt_group_id
+
+        # Verify sequence order
+        assert text_prompt.sequence == 0
+        assert image_prompt.sequence == 1
+
+        # Verify text prompt
+        assert text_prompt.value == "Text that becomes unsafe with image"
+        assert text_prompt.name == "ML-VLSU Text"
+        assert text_prompt.harm_categories == ["C1: Slurs, Hate Speech, Hate Symbols"]
+        assert text_prompt.metadata["text_grade"] == "safe"
+        assert text_prompt.metadata["image_grade"] == "safe"
+        assert text_prompt.metadata["combined_grade"] == "unsafe"
+
+        # Verify image prompt
+        assert image_prompt.value == "/fake/path/image.png"
+        assert image_prompt.name == "ML-VLSU Image"
+        assert image_prompt.harm_categories == ["C1: Slurs, Hate Speech, Hate Symbols"]
+        assert image_prompt.metadata["original_image_url"] == "https://example.com/image.jpg"
+
+    @pytest.mark.asyncio
+    async def test_fetch_dataset_combined_borderline_creates_pair(self):
+        """Test that a text+image pair is created when combined_grade is borderline."""
+        mock_data = [
+            {
+                "prompt": "Borderline text",
+                "web_path": "https://example.com/image.jpg",
+                "combined_category": "C3: Drug Abuse",
+                "text_category": "C3: Drug Abuse",
+                "image_category": "C3: Drug Abuse",
+                "uuid": str(uuid.uuid4()),
+                "consensus_text_grade": "safe",
+                "image_grade": "safe",
+                "consensus_combined_grade": "borderline",
+            }
+        ]
+
+        dataset_loader = _VLSUMultimodalDataset()
+
+        with (
+            patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data),
+            patch.object(
+                dataset_loader,
+                "_fetch_and_save_image_async",
+                return_value="/fake/path/image.png",
+            ),
+        ):
+            dataset = await dataset_loader.fetch_dataset(cache=False)
+
+        assert len(dataset.seeds) == 2  # Text + Image pair
+
+        text_prompt = next(s for s in dataset.seeds if s.data_type == "text")
+        image_prompt = next(s for s in dataset.seeds if s.data_type == "image_path")
+
+        assert text_prompt.prompt_group_id == image_prompt.prompt_group_id
+        assert text_prompt.metadata["combined_grade"] == "borderline"
+
+    @pytest.mark.asyncio
+    async def test_fetch_dataset_combined_safe_no_prompts(self):
+        """Test that no prompts are created when combined_grade is safe."""
+        mock_data = [
+            {
+                "prompt": "Safe text",
+                "web_path": "https://example.com/image.jpg",
+                "combined_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "text_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "image_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "uuid": str(uuid.uuid4()),
+                "consensus_text_grade": "unsafe",  # Even if individual grades are unsafe
+                "image_grade": "unsafe",
+                "consensus_combined_grade": "safe",  # Combined is safe, so no prompts
+            }
+        ]
+
+        dataset_loader = _VLSUMultimodalDataset()
+
+        with patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data):
+            with pytest.raises(ValueError, match="SeedDataset cannot be empty"):
+                await dataset_loader.fetch_dataset(cache=False)
+
+    @pytest.mark.asyncio
+    async def test_fetch_dataset_multiple_pairs(self):
+        """Test that multiple text+image pairs are created correctly."""
+        mock_data = [
+            {
+                "prompt": "First unsafe prompt",
+                "web_path": "https://example.com/image1.jpg",
+                "combined_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "text_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "image_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "uuid": str(uuid.uuid4()),
+                "consensus_text_grade": "safe",
+                "image_grade": "safe",
+                "consensus_combined_grade": "unsafe",
+            },
+            {
+                "prompt": "Second unsafe prompt",
+                "web_path": "https://example.com/image2.jpg",
+                "combined_category": "C2: Discrimination and Unequal Treatment",
+                "text_category": "C2: Discrimination and Unequal Treatment",
+                "image_category": "C2: Discrimination and Unequal Treatment",
+                "uuid": str(uuid.uuid4()),
+                "consensus_text_grade": "safe",
+                "image_grade": "safe",
+                "consensus_combined_grade": "borderline",
+            },
+        ]
+
+        dataset_loader = _VLSUMultimodalDataset()
+
+        with (
+            patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data),
+            patch.object(
+                dataset_loader,
+                "_fetch_and_save_image_async",
+                return_value="/fake/path/image.png",
+            ),
+        ):
+            dataset = await dataset_loader.fetch_dataset(cache=False)
+
+        assert len(dataset.seeds) == 4  # 2 pairs of text + image
+
+        # Get unique group_ids
+        group_ids = set(s.prompt_group_id for s in dataset.seeds)
+        assert len(group_ids) == 2  # Two different pairs
+
+        # Verify each pair has one text and one image
+        for group_id in group_ids:
+            pair = [s for s in dataset.seeds if s.prompt_group_id == group_id]
+            assert len(pair) == 2
+            data_types = {s.data_type for s in pair}
+            assert data_types == {"text", "image_path"}
+
+    @pytest.mark.asyncio
+    async def test_category_filtering(self):
+        """Test that category filtering works correctly."""
+        mock_data = [
+            {
+                "prompt": "Slur prompt",
+                "web_path": "https://example.com/image1.jpg",
+                "combined_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "text_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "image_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "uuid": str(uuid.uuid4()),
+                "consensus_text_grade": "safe",
+                "image_grade": "safe",
+                "consensus_combined_grade": "unsafe",
+            },
+            {
+                "prompt": "Discrimination prompt",
+                "web_path": "https://example.com/image2.jpg",
+                "combined_category": "C2: Discrimination and Unequal Treatment",
+                "text_category": "C2: Discrimination and Unequal Treatment",
+                "image_category": "C2: Discrimination and Unequal Treatment",
+                "uuid": str(uuid.uuid4()),
+                "consensus_text_grade": "safe",
+                "image_grade": "safe",
+                "consensus_combined_grade": "unsafe",
+            },
+        ]
+
+        dataset_loader = _VLSUMultimodalDataset(categories=[VLSUCategory.SLURS_HATE_SPEECH])
+
+        with (
+            patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data),
+            patch.object(
+                dataset_loader,
+                "_fetch_and_save_image_async",
+                return_value="/fake/path/image.png",
+            ),
+        ):
+            dataset = await dataset_loader.fetch_dataset(cache=False)
+
+        # Only the slur category should be included (1 pair = 2 prompts)
+        assert len(dataset.seeds) == 2
+        for seed in dataset.seeds:
+            assert "C1: Slurs" in str(seed.harm_categories)
+
+    @pytest.mark.asyncio
+    async def test_handles_failed_image_downloads(self):
+        """Test that the entire pair is skipped when the image download fails."""
+        mock_data = [
+            {
+                "prompt": "Test prompt",
+                "web_path": "https://broken-url.com/image.jpg",
+                "combined_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "text_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "image_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "uuid": str(uuid.uuid4()),
+                "consensus_text_grade": "safe",
+                "image_grade": "safe",
+                "consensus_combined_grade": "unsafe",
+            }
+        ]
+
+        dataset_loader = _VLSUMultimodalDataset()
+
+        with (
+            patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data),
+            patch.object(
+                dataset_loader,
+                "_fetch_and_save_image_async",
+                side_effect=Exception("Download failed"),
+            ),
+        ):
+            # Both text and image should be skipped when the image fails
+            with pytest.raises(ValueError, match="SeedDataset cannot be empty"):
+                await dataset_loader.fetch_dataset(cache=False)
+
+    @pytest.mark.asyncio
+    async def test_custom_unsafe_grades(self):
+        """Test that the custom unsafe_grades parameter works correctly."""
+        mock_data = [
+            {
+                "prompt": "Unsafe prompt",
+                "web_path": "https://example.com/image.jpg",
+                "combined_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "text_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "image_category": "C1: Slurs, Hate Speech, Hate Symbols",
+                "uuid": str(uuid.uuid4()),
+                "consensus_text_grade": "safe",
+                "image_grade": "safe",
+                "consensus_combined_grade": "unsafe",
+            },
+            {
+                "prompt": "Borderline prompt",
+                "web_path": "https://example.com/image2.jpg",
+                "combined_category": "C2: Discrimination and Unequal Treatment",
"text_category": "C2: Discrimination and Unequal Treatment", + "image_category": "C2: Discrimination and Unequal Treatment", + "uuid": str(uuid.uuid4()), + "consensus_text_grade": "safe", + "image_grade": "safe", + "consensus_combined_grade": "borderline", + }, + ] + + # Only include "unsafe", not "borderline" + dataset_loader = _VLSUMultimodalDataset(unsafe_grades=["unsafe"]) + + with ( + patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data), + patch.object( + dataset_loader, + "_fetch_and_save_image_async", + return_value="/fake/path/image.png", + ), + ): + dataset = await dataset_loader.fetch_dataset(cache=False) + + # Only the "unsafe" pair should be included + assert len(dataset.seeds) == 2 + for seed in dataset.seeds: + assert seed.metadata["combined_grade"] == "unsafe" + + @pytest.mark.asyncio + async def test_both_prompts_use_combined_category(self): + """Test that both text and image prompts use the combined_category.""" + mock_data = [ + { + "prompt": "Test prompt", + "web_path": "https://example.com/image.jpg", + "combined_category": "C1: Slurs, Hate Speech, Hate Symbols", + "text_category": "C2: Discrimination and Unequal Treatment", + "image_category": "C3: Drug Abuse", + "uuid": str(uuid.uuid4()), + "consensus_text_grade": "safe", + "image_grade": "safe", + "consensus_combined_grade": "unsafe", + } + ] + + dataset_loader = _VLSUMultimodalDataset() + + with ( + patch.object(dataset_loader, "_fetch_from_url", return_value=mock_data), + patch.object( + dataset_loader, + "_fetch_and_save_image_async", + return_value="/fake/path/image.png", + ), + ): + dataset = await dataset_loader.fetch_dataset(cache=False) + + # Both should use combined_category, not their individual categories + for seed in dataset.seeds: + assert seed.harm_categories == ["C1: Slurs, Hate Speech, Hate Symbols"]