|
| 1 | +import pytest |
| 2 | +import torch |
| 3 | +from torch.utils.data import DataLoader, Subset |
| 4 | +from torchvision.datasets import CIFAR10, CIFAR100 |
| 5 | +from torchvision.transforms import ToTensor |
| 6 | +from aros import ( |
| 7 | + LabelChangedDataset, |
| 8 | + get_subsampled_subset, |
| 9 | + get_loaders, |
| 10 | +) |
| 11 | + |
| 12 | +# Set up transformations and datasets for tests |
| 13 | +transform_tensor = ToTensor() |
| 14 | + |
| 15 | +@pytest.fixture |
| 16 | +def cifar10_datasets(): |
| 17 | + trainset = CIFAR10(root='./data', train=True, download=True, transform=transform_tensor) |
| 18 | + testset = CIFAR10(root='./data', train=False, download=True, transform=transform_tensor) |
| 19 | + return trainset, testset |
| 20 | + |
| 21 | +@pytest.fixture |
| 22 | +def cifar100_datasets(): |
| 23 | + trainset = CIFAR100(root='./data', train=True, download=True, transform=transform_tensor) |
| 24 | + testset = CIFAR100(root='./data', train=False, download=True, transform=transform_tensor) |
| 25 | + return trainset, testset |
| 26 | + |
| 27 | +def test_label_changed_dataset(cifar10_datasets): |
| 28 | + _, testset = cifar10_datasets |
| 29 | + new_label = 99 |
| 30 | + relabeled_dataset = LabelChangedDataset(testset, new_label) |
| 31 | + |
| 32 | + assert len(relabeled_dataset) == len(testset), "Relabeled dataset should match the original dataset length" |
| 33 | + |
| 34 | + for img, label in relabeled_dataset: |
| 35 | + assert label == new_label, "All labels should be changed to the new label" |
| 36 | + |
| 37 | +def test_get_subsampled_subset(cifar10_datasets): |
| 38 | + trainset, _ = cifar10_datasets |
| 39 | + subset_ratio = 0.1 |
| 40 | + subset = get_subsampled_subset(trainset, subset_ratio=subset_ratio) |
| 41 | + |
| 42 | + expected_size = int(len(trainset) * subset_ratio) |
| 43 | + assert len(subset) == expected_size, f"Subset size should be {expected_size}" |
| 44 | + |
| 45 | +def test_get_loaders_cifar10(cifar10_datasets): |
| 46 | + train_loader, test_loader, test_set, test_loader_vs_other = get_loaders('cifar10') |
| 47 | + |
| 48 | + assert isinstance(train_loader, DataLoader) |
| 49 | + assert isinstance(test_loader, DataLoader) |
| 50 | + assert isinstance(test_loader_vs_other, DataLoader) |
| 51 | + |
| 52 | + for images, labels in test_loader: |
| 53 | + assert images.shape[0] == 16, "Test loader batch size should be 16" |
| 54 | + break |
| 55 | + |
| 56 | +def test_get_loaders_cifar100(cifar100_datasets): |
| 57 | + train_loader, test_loader, test_set, test_loader_vs_other = get_loaders('cifar100') |
| 58 | + |
| 59 | + assert isinstance(train_loader, DataLoader) |
| 60 | + assert isinstance(test_loader, DataLoader) |
| 61 | + assert isinstance(test_loader_vs_other, DataLoader) |
| 62 | + |
| 63 | + for images, labels in test_loader: |
| 64 | + assert images.shape[0] == 16, "Test loader batch size should be 16" |
| 65 | + break |
| 66 | + |
| 67 | +def test_get_loaders_invalid_dataset(): |
| 68 | + with pytest.raises(ValueError, match="Dataset 'invalid_dataset' is not supported."): |
| 69 | + get_loaders('invalid_dataset') |
0 commit comments