|
| 1 | +import numpy as np |
1 | 2 | from torch.utils.data import Dataset, random_split |
2 | 3 |
|
3 | 4 | from .dataloaders import ( |
@@ -46,23 +47,28 @@ def load_data(dataset: str, *args, **kwargs) -> tuple: |
46 | 47 | match dataset.lower(): |
47 | 48 | case "usps_0-6": |
48 | 49 | dataset = USPSDataset0_6 |
49 | | - train_samples, test_samples = Downloader.usps(*args, **kwargs) |
50 | | - labels = range(7) |
| 50 | + train_labels, test_labels = Downloader.usps(*args, **kwargs) |
| 51 | + labels = np.arange(7) |
51 | 52 | case "usps_7-9": |
52 | 53 | dataset = USPSH5_Digit_7_9_Dataset |
53 | | - train_samples, test_samples = Downloader.usps(*args, **kwargs) |
54 | | - labels = range(7, 10) |
| 54 | + train_labels, test_labels = Downloader.usps(*args, **kwargs) |
| 55 | + labels = np.arange(7, 10) |
55 | 56 | case "mnist_0-3": |
56 | 57 | dataset = MNISTDataset0_3 |
57 | | - train_samples, test_samples = Downloader.mnist(*args, **kwargs) |
58 | | - labels = range(4) |
| 58 | + train_labels, test_labels = Downloader.mnist(*args, **kwargs) |
| 59 | + labels = np.arange(4) |
59 | 60 | case _: |
60 | 61 | raise NotImplementedError(f"Dataset: {dataset} not implemented.") |
61 | 62 |
|
62 | | - val_size = kwargs.get("val_size", 0.1) |
| 63 | + val_size = kwargs.get("val_size", 0.2) |
63 | 64 |
|
64 | | - train_samples = filter_labels(train_samples, labels) |
65 | | - test_samples = filter_labels(test_samples, labels) |
| 65 | + # Get the indices of the samples |
| 66 | + train_indices = np.arange(len(train_labels)) |
| 67 | + test_indices = np.arange(len(test_labels)) |
| 68 | + |
| 69 | + # Filter the labels to only get indices of the wanted labels |
| 70 | + train_samples = train_indices[np.isin(train_labels, labels)] |
| 71 | + test_samples = test_indices[np.isin(test_labels, labels)] |
66 | 72 |
|
67 | 73 | train_samples, val_samples = random_split(train_samples, [1 - val_size, val_size]) |
68 | 74 |
|
|
0 commit comments