Skip to content

Commit 6972701

Browse files
committed
Made one MNIST dataloader since we pass in sample_ids
1 parent daf82d6 commit 6972701

File tree

3 files changed

+4
-69
lines changed

3 files changed

+4
-69
lines changed

utils/dataloaders/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
__all__ = [
22
"USPSDataset0_6",
33
"USPSH5_Digit_7_9_Dataset",
4-
"MNISTDataset0_3",
4+
"MNISTDataset",
55
"Downloader",
66
"SVHNDataset",
77
]
88

99
from .download import Downloader
10-
from .mnist_0_3 import MNISTDataset0_3
10+
from .mnist import MNISTDataset
1111
from .svhn import SVHNDataset
1212
from .usps_0_6 import USPSDataset0_6
1313
from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,9 @@
66
from .datasources import MNIST_SOURCE
77

88

9-
class MNISTDataset0_3(Dataset):
9+
class MNISTDataset(Dataset):
1010
"""
11-
A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3.
11+
A custom Dataset class for loading a subset of the MNIST dataset
1212
Parameters
1313
----------
1414
data_path : Path
@@ -61,7 +61,6 @@ def __init__(
6161
self.idx = sample_ids
6262
self.train = train
6363
self.transform = transform
64-
self.num_classes = 4
6564

6665
self.images_path = self.mnist_path / (
6766
MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1]

utils/dataloaders/mnist_4_9.py

Lines changed: 0 additions & 64 deletions
This file was deleted.

0 commit comments

Comments
 (0)