Skip to content

Commit eb96039

Browse files
committed
Added mnist_4-9 to load_data.py
1 parent dd96500 commit eb96039

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

CollaborativeCoding/load_data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .dataloaders import (
55
Downloader,
66
MNISTDataset0_3,
7+
MNISTDataset4_9,
78
SVHNDataset,
89
USPSDataset0_6,
910
USPSH5_Digit_7_9_Dataset,
@@ -65,7 +66,9 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
6566
train_labels, test_labels = downloader.svhn(data_dir=data_dir)
6667
labels = np.arange(10)
6768
case "mnist_4-9":
68-
raise NotImplementedError("MNIST 4-9 dataset not yet implemented.")
69+
dataset = MNISTDataset4_9
70+
train_labels, test_labels = downloader.mnist(data_dir=data_dir)
71+
labels = np.arange(4,10)
6972
case _:
7073
raise NotImplementedError(f"Dataset: {dataset} not implemented.")
7174

0 commit comments

Comments
 (0)