Skip to content

Commit 22df0a0

Browse files
committed
Updated dataloader to fit with MNIST 4-9
1 parent daf82d6 commit 22df0a0

File tree

1 file changed

+45
-55
lines changed

1 file changed

+45
-55
lines changed

utils/dataloaders/mnist_4_9.py

Lines changed: 45 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,54 @@
1-
import gzip
2-
import os
3-
import urllib.request as ur
41
from pathlib import Path
52

3+
import numpy as np
64
from torch.utils.data import Dataset
75

6+
from .datasources import MNIST_SOURCE
87

98
class MNIST_4_9(Dataset):
10-
def __init__(self, datapath: Path, train: bool = False, download: bool = False):
9+
"""
10+
MNIST dataset of numbers 4-9.
11+
12+
Parameters
13+
----------
14+
data_path : Path
15+
Root directory where MNIST dataset is stored
16+
sample_ids : np.ndarray
17+
Array of indices spcifying which samples to load. This determines the samples used by the dataloader.
18+
train : bool, optional
19+
Whether to train the model or not, by default False
20+
"""
21+
def __init__(self, data_path: Path, sample_ids: np.ndarray, train: bool = False):
1122
super.__init__()
12-
self.datapath = datapath
13-
self.mnist_path = self.datapath / "MNIST"
23+
self.data_path = data_path
24+
self.mnist_path = self.data_path / "MNIST"
25+
self.samples = sample_ids
1426
self.train = train
15-
self.download = download
16-
self.num_classes: int = 6
17-
18-
if not self.download and not self._already_downloaded():
19-
raise FileNotFoundError(
20-
"Data files are not found. Set --download-data=True to download the data"
21-
)
22-
if self.download and not self._already_downloaded():
23-
self._download()
24-
25-
def _download(self):
26-
urls: dict = {
27-
"train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
28-
"train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
29-
"test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
30-
"test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
31-
}
32-
33-
for url in urls.values():
34-
file_path: Path = os.path.join(self.mnist_path, url.split("/")[-1])
35-
file_name: Path = file_path.replace(".gz", "")
36-
if os.path.exists(file_name):
37-
print(f"File: {file_name} already downloaded")
38-
else:
39-
print(f"File: {file_name} is downloading...")
40-
ur.urlretrieve(url, file_path) # Download file
41-
with gzip.open(file_path, "rb") as infile:
42-
with open(file_name, "wb") as outfile:
43-
outfile.write(infile.read()) # Write from url to local file
44-
os.remove(file_path) # remove .gz file
45-
46-
def _already_downloaded(self):
47-
if self.mnist_path.exists():
48-
required_files: list = [
49-
"train-images-idx3-ubyte",
50-
"train-labels-idx1-ubyte",
51-
"t10k-images-idx3-ubyte",
52-
"t10k-labels-idx1-ubyte",
53-
]
54-
return all([(self.mnist_path / file).exists() for file in required_files])
55-
56-
else:
57-
self.mnist_path.mkdir(parents=True, exist_ok=True)
58-
return False
59-
27+
28+
self.images_path = self.mnist_path / (
29+
MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1]
30+
)
31+
self.labels_path = self.mnist_path / (
32+
MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1]
33+
)
34+
35+
6036
def __len__(self):
61-
pass
62-
63-
def __getitem__(self):
64-
pass
37+
return len(self.samples)
38+
39+
def __getitem__(self, idx):
40+
with open(self.labels_path, "rb") as labelfile:
41+
label_pos = 8 + self.sample[idx]
42+
labelfile.seek(label_pos)
43+
label = int.from_bytes(labelfile.read(1), byteorder="big")
44+
45+
with open(self.images_path, "rb") as imagefile:
46+
image_pos = 16 + self.samples[idx] * 28 * 28
47+
imagefile.seek(image_pos)
48+
image = np.frombuffer(imagefile.read(28 * 28), dtype=np.uint8).reshape(
49+
28, 28
50+
)
51+
52+
image = np.expand_dims(image, axis=0) # Channel
53+
54+
return image, label

0 commit comments

Comments
 (0)