Skip to content

Commit 5dbe8ca

Browse files
committed
Updated load_data.py to fit with one MNIST dataloader, but to use both MNIST cases
1 parent 6972701 commit 5dbe8ca

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

utils/load_data.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from .dataloaders import (
55
Downloader,
6-
MNISTDataset0_3,
6+
MNISTDataset,
77
SVHNDataset,
88
USPSDataset0_6,
99
USPSH5_Digit_7_9_Dataset,
@@ -57,15 +57,17 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
5757
train_labels, test_labels = downloader.usps(data_dir=data_dir)
5858
labels = np.arange(7, 10)
5959
case "mnist_0-3":
60-
dataset = MNISTDataset0_3
60+
dataset = MNISTDataset
6161
train_labels, test_labels = downloader.mnist(data_dir=data_dir)
6262
labels = np.arange(4)
6363
case "svhn":
6464
dataset = SVHNDataset
6565
train_labels, test_labels = downloader.svhn(data_dir=data_dir)
6666
labels = np.arange(10)
6767
case "mnist_4-9":
68-
raise NotImplementedError("MNIST 4-9 dataset not yet implemented.")
68+
dataset = MNISTDataset
69+
train_labels, test_labels = downloader.mnist(data_dir=data_dir)
70+
labels = np.arange(4,10)
6971
case _:
7072
raise NotImplementedError(f"Dataset: {dataset} not implemented.")
7173

0 commit comments

Comments
 (0)