Skip to content

Commit 15c99ea

Browse files
committed
added MNIST downloader, adjusted minor thinks for the code to run
1 parent ad15940 commit 15c99ea

File tree

5 files changed

+101
-121
lines changed

5 files changed

+101
-121
lines changed

main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,8 @@ def main():
4141

4242
traindata, validata, testdata = load_data(
4343
args.dataset,
44-
data_path=args.datafolder,
44+
data_dir=args.datafolder,
4545
transform=transform,
46-
download=args.download_data,
4746
val_size=args.val_size,
4847
)
4948

utils/dataloaders/datasources.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,22 @@
1717
"8ea070ee2aca1ac39742fdd1ef5ed118",
1818
],
1919
}
20+
21+
MNIST_SOURCE = {
22+
"train_images": ["https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
23+
"train-images-idx3-ubyte",
24+
None
25+
],
26+
"train_labels": ["https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
27+
"train-labels-idx1-ubyte",
28+
None
29+
],
30+
"test_images": ["https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
31+
"t10k-images-idx3-ubyte",
32+
None
33+
],
34+
"test_labels": ["https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
35+
"t10k-labels-idx1-ubyte",
36+
None
37+
],
38+
}

utils/dataloaders/download.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import bz2
22
import hashlib
3+
import os
4+
import gzip
35
from pathlib import Path
46
from tempfile import TemporaryDirectory
57
from urllib.request import urlretrieve
68

79
import h5py as h5
810
import numpy as np
911

10-
from .datasources import USPS_SOURCE
12+
from .datasources import USPS_SOURCE, MNIST_SOURCE
1113

1214

1315
class Downloader:
@@ -38,7 +40,47 @@ class Downloader:
3840
"""
3941

4042
def mnist(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
41-
raise NotImplementedError("MNIST download not implemented yet")
43+
def _chech_is_downloaded(path: Path) -> bool:
44+
path = path / "MNIST"
45+
if path.exists():
46+
required_files = [MNIST_SOURCE[key][1] for key in MNIST_SOURCE.keys()]
47+
if all([(path / file).exists() for file in required_files]):
48+
print("MNIST Dataset already downloaded.")
49+
return True
50+
else:
51+
return False
52+
else:
53+
path.mkdir(parents=True, exist_ok=True)
54+
return False
55+
56+
def _download_data(path: Path) -> None:
57+
urls = {key: MNIST_SOURCE[key][0] for key in MNIST_SOURCE.keys()}
58+
59+
for name, url in urls.items():
60+
file_path = os.path.join(path, url.split("/")[-1])
61+
if not os.path.exists(file_path.replace(".gz", "")): # Avoid re-downloading
62+
urlretrieve(url, file_path)
63+
with gzip.open(file_path, "rb") as f_in:
64+
with open(file_path.replace(".gz", ""), "wb") as f_out:
65+
f_out.write(f_in.read())
66+
os.remove(file_path) # Remove compressed file
67+
68+
def _get_labels(path: Path) -> np.ndarray:
69+
with open(path, "rb") as f:
70+
data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
71+
return data
72+
73+
if not _chech_is_downloaded(data_dir):
74+
_download_data(data_dir)
75+
76+
train_labels_path = data_dir / "MNIST" / MNIST_SOURCE["train_labels"][1]
77+
test_labels_path = data_dir / "MNIST" / MNIST_SOURCE["test_labels"][1]
78+
79+
train_labels = _get_labels(train_labels_path)
80+
test_labels = _get_labels(test_labels_path)
81+
82+
return train_labels, test_labels
83+
4284

4385
def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
4486
raise NotImplementedError("SVHN download not implemented yet")

utils/dataloaders/mnist_0_3.py

Lines changed: 25 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -1,154 +1,72 @@
1-
import gzip
2-
import os
3-
import urllib.request
41
from pathlib import Path
52

63
import numpy as np
7-
import torch
8-
from torch.utils.data import Dataset, random_split
4+
from torch.utils.data import Dataset
5+
from .datasources import MNIST_SOURCE
96

107

118
class MNISTDataset0_3(Dataset):
129
"""
13-
A custom dataset class for loading MNIST data, specifically for digits 0 through 3.
14-
10+
A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3.
1511
Parameters
1612
----------
1713
data_path : Path
18-
The root directory where the MNIST data is stored or will be downloaded.
14+
The root directory where the MNIST data is stored.
15+
sample_ids : list
16+
A list of indices specifying which samples to load.
1917
train : bool, optional
20-
If True, loads the training data, otherwise loads the test data. Default is False.
18+
If True, load training data, otherwise load test data. Default is False.
2119
transform : callable, optional
22-
A function/transform that takes in an image and returns a transformed version. Default is None.
23-
download : bool, optional
24-
If True, downloads the dataset if it is not already present in the specified data_path. Default is False.
25-
20+
A function/transform to apply to the images. Default is None.
2621
Attributes
2722
----------
2823
data_path : Path
2924
The root directory where the MNIST data is stored.
3025
mnist_path : Path
31-
The directory where the MNIST data files are stored.
26+
The directory where the MNIST dataset is located within the root directory.
27+
idx : list
28+
A list of indices specifying which samples to load.
3229
train : bool
33-
Indicates whether the training data or test data is being used.
30+
Indicates whether to load training data or test data.
3431
transform : callable
35-
A function/transform that takes in an image and returns a transformed version.
36-
download : bool
37-
Indicates whether the dataset should be downloaded if not present.
32+
A function/transform to apply to the images.
33+
num_classes : int
34+
The number of classes in the dataset (0 to 3).
3835
images_path : Path
39-
The path to the image file (training or test) based on the `train` flag.
36+
The path to the image file (train or test) based on the `train` flag.
4037
labels_path : Path
41-
The path to the label file (training or test) based on the `train` flag.
42-
idx : numpy.ndarray
43-
Indices of the labels that are less than 4.
38+
The path to the label file (train or test) based on the `train` flag.
4439
length : int
4540
The number of samples in the dataset.
46-
4741
Methods
4842
-------
49-
_parse_labels(train)
50-
Parses the labels from the label file.
51-
_chech_is_downloaded()
52-
Checks if the dataset is already downloaded.
53-
_download_data()
54-
Downloads and extracts the MNIST dataset.
5543
__len__()
5644
Returns the number of samples in the dataset.
5745
__getitem__(index)
58-
Returns the image and label at the specified index.
46+
Retrieves the image and label at the specified index.
5947
"""
6048

6149
def __init__(
6250
self,
63-
split: str,
64-
split_percentage: float,
6551
data_path: Path,
66-
download: bool = False,
52+
sample_ids: list,
53+
train: bool = False,
6754
transform=None,
6855
):
6956
super().__init__()
7057

7158
self.data_path = data_path
7259
self.mnist_path = self.data_path / "MNIST"
73-
self.split = split
74-
self.split_percentage = split_percentage
60+
self.idx = sample_ids
61+
self.train = train
7562
self.transform = transform
76-
self.download = download
7763
self.num_classes = 4
7864

79-
if self.split == "test":
80-
train = False # used to decide whether to load training or test dataset
81-
else:
82-
train = True
83-
84-
if not self.download and not self._chech_is_downloaded():
85-
raise ValueError(
86-
"Data not found. Set --download-data=True to download the data."
87-
)
88-
if self.download and not self._chech_is_downloaded():
89-
self._download_data()
90-
91-
self.images_path = self.mnist_path / (
92-
"train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte"
93-
)
94-
self.labels_path = self.mnist_path / (
95-
"train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte"
96-
)
97-
98-
labels = self._parse_labels()
99-
100-
self.idx = np.where(labels < 4)[0]
101-
102-
if self.split != "test":
103-
generator1 = torch.Generator().manual_seed(42)
104-
tr, val = random_split(
105-
self.idx,
106-
[1 - self.split_percentage, self.split_percentage],
107-
generator=generator1,
108-
)
109-
self.idx = tr if self.split == "train" else val
65+
self.images_path = self.mnist_path / (MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1])
66+
self.labels_path = self.mnist_path / (MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1])
11067

11168
self.length = len(self.idx)
112-
113-
def _parse_labels(self):
114-
with open(self.labels_path, "rb") as f:
115-
data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
116-
return data
117-
118-
def _chech_is_downloaded(self):
119-
if self.mnist_path.exists():
120-
required_files = [
121-
"train-images-idx3-ubyte",
122-
"train-labels-idx1-ubyte",
123-
"t10k-images-idx3-ubyte",
124-
"t10k-labels-idx1-ubyte",
125-
]
126-
if all([(self.mnist_path / file).exists() for file in required_files]):
127-
print("MNIST Dataset already downloaded.")
128-
return True
129-
else:
130-
return False
131-
else:
132-
self.mnist_path.mkdir(parents=True, exist_ok=True)
133-
return False
134-
135-
def _download_data(self):
136-
urls = {
137-
"train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
138-
"train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
139-
"test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
140-
"test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
141-
}
142-
143-
for name, url in urls.items():
144-
file_path = os.path.join(self.mnist_path, url.split("/")[-1])
145-
if not os.path.exists(file_path.replace(".gz", "")): # Avoid re-downloading
146-
urllib.request.urlretrieve(url, file_path)
147-
with gzip.open(file_path, "rb") as f_in:
148-
with open(file_path.replace(".gz", ""), "wb") as f_out:
149-
f_out.write(f_in.read())
150-
os.remove(file_path) # Remove compressed file
151-
69+
15270
def __len__(self):
15371
return self.length
15472

utils/load_data.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -43,19 +43,21 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
4343
>>> len(train), len(val), len(test)
4444
(4914, 546, 1782)
4545
"""
46-
46+
downloader = Downloader()
47+
data_dir = kwargs.get("data_dir")
48+
transform = kwargs.get("transform")
4749
match dataset.lower():
4850
case "usps_0-6":
4951
dataset = USPSDataset0_6
50-
train_labels, test_labels = Downloader.usps(*args, **kwargs)
52+
train_labels, test_labels = downloader.usps(data_dir=data_dir)
5153
labels = np.arange(7)
5254
case "usps_7-9":
5355
dataset = USPSH5_Digit_7_9_Dataset
54-
train_labels, test_labels = Downloader.usps(*args, **kwargs)
56+
train_labels, test_labels = downloader.usps(data_dir=data_dir)
5557
labels = np.arange(7, 10)
5658
case "mnist_0-3":
5759
dataset = MNISTDataset0_3
58-
train_labels, test_labels = Downloader.mnist(*args, **kwargs)
60+
train_labels, test_labels = downloader.mnist(data_dir=data_dir)
5961
labels = np.arange(4)
6062
case _:
6163
raise NotImplementedError(f"Dataset: {dataset} not implemented.")
@@ -73,24 +75,24 @@ def load_data(dataset: str, *args, **kwargs) -> tuple:
7375
train_samples, val_samples = random_split(train_samples, [1 - val_size, val_size])
7476

7577
train = dataset(
76-
*args,
78+
data_path=data_dir,
7779
sample_ids=train_samples,
7880
train=True,
79-
**kwargs,
81+
transform=transform,
8082
)
8183

8284
val = dataset(
83-
*args,
85+
data_path=data_dir,
8486
sample_ids=val_samples,
8587
train=True,
86-
**kwargs,
88+
transform=transform,
8789
)
8890

8991
test = dataset(
90-
*args,
92+
data_path=data_dir,
9193
sample_ids=test_samples,
9294
train=False,
93-
**kwargs,
95+
transform=transform,
9496
)
9597

9698
return train, val, test

0 commit comments

Comments
 (0)