Skip to content

Commit d7526bf

Browse files
committed
Actually send the indices, not labels to datasets
1 parent 20faa24 commit d7526bf

File tree

1 file changed

+15
-9
lines changed

1 file changed

+15
-9
lines changed

utils/load_data.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
from torch.utils.data import Dataset, random_split
23

34
from .dataloaders import (
@@ -46,23 +47,28 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
4647
match dataset.lower():
4748
case "usps_0-6":
4849
dataset = USPSDataset0_6
49-
train_samples, test_samples = Downloader.usps(*args, **kwargs)
50-
labels = range(7)
50+
train_labels, test_labels = Downloader.usps(*args, **kwargs)
51+
labels = np.arange(7)
5152
case "usps_7-9":
5253
dataset = USPSH5_Digit_7_9_Dataset
53-
train_samples, test_samples = Downloader.usps(*args, **kwargs)
54-
labels = range(7, 10)
54+
train_labels, test_labels = Downloader.usps(*args, **kwargs)
55+
labels = np.arange(7, 10)
5556
case "mnist_0-3":
5657
dataset = MNISTDataset0_3
57-
train_samples, test_samples = Downloader.mnist(*args, **kwargs)
58-
labels = range(4)
58+
train_labels, test_labels = Downloader.mnist(*args, **kwargs)
59+
labels = np.arange(4)
5960
case _:
6061
raise NotImplementedError(f"Dataset: {dataset} not implemented.")
6162

62-
val_size = kwargs.get("val_size", 0.1)
63+
val_size = kwargs.get("val_size", 0.2)
6364

64-
train_samples = filter_labels(train_samples, labels)
65-
test_samples = filter_labels(test_samples, labels)
65+
# Get the indices of the samples
66+
train_indices = np.arange(len(train_labels))
67+
test_indices = np.arange(len(test_labels))
68+
69+
# Filter the labels to only get indices of the wanted labels
70+
train_samples = train_indices[np.isin(train_labels, labels)]
71+
test_samples = test_indices[np.isin(test_labels, labels)]
6672

6773
train_samples, val_samples = random_split(train_samples, [1 - val_size, val_size])
6874

0 commit comments

Comments
 (0)