
Commit 0043e11

Wrote the dataset, linked it to main, not tested
1 parent 7ff097a commit 0043e11

File tree

environment.yml
main.py
utils/dataloaders/__init__.py
utils/dataloaders/mnist_0_3.py
utils/load_data.py

5 files changed: +136 -2 lines changed


environment.yml

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ dependencies:
   - pytest
   - ruff
   - scalene
+  - pickle
   - pip:
     - torch
     - torchvision

main.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def main():
         "--dataset",
         type=str,
         default="svhn",
-        choices=["svhn", "usps_0-6"],
+        choices=["svhn", "usps_0-6", "mnist_0-3"],
         help="Which dataset to train the model on.",
     )
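
For reference, a minimal sketch of how the widened choices list behaves at parse time. The argument definition is copied from the hunk above; the bare parser object and the example invocations are illustrative, not part of this commit.

import argparse

# Mirror of the --dataset argument defined in main(); the standalone parser is illustrative.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--dataset",
    type=str,
    default="svhn",
    choices=["svhn", "usps_0-6", "mnist_0-3"],
    help="Which dataset to train the model on.",
)

print(parser.parse_args(["--dataset", "mnist_0-3"]).dataset)  # mnist_0-3
print(parser.parse_args([]).dataset)                          # svhn (the default)
# parser.parse_args(["--dataset", "cifar10"]) would exit with an
# "invalid choice: 'cifar10'" error, since argparse enforces the choices list.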

utils/dataloaders/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 __all__ = ["USPSDataset0_6"]

 from .usps_0_6 import USPSDataset0_6
+from .mnist_0_3 import MNISTDataset0_3

utils/dataloaders/mnist_0_3.py

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
import gzip
import os
import urllib.request
from pathlib import Path

import numpy as np
from torch.utils.data import Dataset


class MNISTDataset0_3(Dataset):
    """
    A custom dataset class for loading MNIST data, specifically for digits 0 through 3.

    Parameters
    ----------
    data_path : Path
        The root directory where the MNIST data is stored or will be downloaded.
    train : bool, optional
        If True, loads the training data, otherwise loads the test data. Default is False.
    transform : callable, optional
        A function/transform that takes in an image and returns a transformed version. Default is None.
    download : bool, optional
        If True, downloads the dataset if it is not already present in the specified data_path. Default is False.

    Attributes
    ----------
    data_path : Path
        The root directory where the MNIST data is stored.
    mnist_path : Path
        The directory where the MNIST data files are stored.
    train : bool
        Indicates whether the training data or test data is being used.
    transform : callable
        A function/transform that takes in an image and returns a transformed version.
    download : bool
        Indicates whether the dataset should be downloaded if not present.
    images_path : Path
        The path to the image file (training or test) based on the `train` flag.
    labels_path : Path
        The path to the label file (training or test) based on the `train` flag.
    idx : numpy.ndarray
        Indices of the labels that are less than 4.
    length : int
        The number of samples in the dataset.

    Methods
    -------
    _parse_labels(train)
        Parses the labels from the label file.
    _check_is_downloaded()
        Checks if the dataset is already downloaded.
    _download_data()
        Downloads and extracts the MNIST dataset.
    __len__()
        Returns the number of samples in the dataset.
    __getitem__(index)
        Returns the image and label at the specified index.
    """

    def __init__(
        self,
        data_path: Path,
        train: bool = False,
        transform=None,
        download: bool = False,
    ):
        super().__init__()

        self.data_path = data_path
        self.mnist_path = self.data_path / "MNIST"
        self.train = train
        self.transform = transform
        self.download = download

        if self.download and not self._check_is_downloaded():
            self._download_data()

        self.images_path = self.mnist_path / (
            "train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte"
        )
        self.labels_path = self.mnist_path / (
            "train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte"
        )

        labels = self._parse_labels(train=self.train)

        # Keep only the samples whose label is 0, 1, 2 or 3.
        self.idx = np.where(labels < 4)[0]
        self.length = len(self.idx)

    def _parse_labels(self, train):
        with open(self.labels_path, "rb") as f:
            data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
        return data

    def _check_is_downloaded(self):
        if self.mnist_path.exists():
            required_files = [
                "train-images-idx3-ubyte",
                "train-labels-idx1-ubyte",
                "t10k-images-idx3-ubyte",
                "t10k-labels-idx1-ubyte",
            ]
            if all((self.mnist_path / file).exists() for file in required_files):
                print("Data already downloaded.")
                return True
            return False

        self.mnist_path.mkdir(parents=True, exist_ok=True)
        return False

    def _download_data(self):
        urls = {
            "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
            "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
            "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
            "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
        }

        for url in urls.values():
            file_path = os.path.join(self.mnist_path, url.split("/")[-1])
            if not os.path.exists(file_path.replace(".gz", "")):  # Avoid re-downloading
                urllib.request.urlretrieve(url, file_path)
                with gzip.open(file_path, "rb") as f_in:
                    with open(file_path.replace(".gz", ""), "wb") as f_out:
                        f_out.write(f_in.read())
                os.remove(file_path)  # Remove the compressed file

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        # Map the dataset index to the raw-file index of a sample whose label is 0-3.
        file_index = self.idx[index]

        with open(self.labels_path, "rb") as f:
            f.seek(8 + file_index)  # Skip the 8-byte header, then jump to the label
            label = int.from_bytes(f.read(1), byteorder="big")  # Read 1 byte for the label

        with open(self.images_path, "rb") as f:
            f.seek(16 + file_index * 28 * 28)  # Skip the 16-byte header, then jump to the image
            # Read one 28x28 image; copy so the buffer-backed array is writable for transforms.
            image = np.frombuffer(f.read(28 * 28), dtype=np.uint8).reshape(28, 28).copy()

        if self.transform:
            image = self.transform(image)

        return image, label
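
The commit message flags the class as untested; the following is a minimal usage sketch, assuming the MNIST files may be downloaded into a local data/ directory (the path, batch size, and ToTensor transform are illustrative, not part of this commit).

from pathlib import Path

from torch.utils.data import DataLoader
from torchvision import transforms

from utils.dataloaders import MNISTDataset0_3

# Download (if needed) and open the training split, restricted to digits 0-3.
dataset = MNISTDataset0_3(
    data_path=Path("data"),           # illustrative location; any writable directory works
    train=True,
    transform=transforms.ToTensor(),  # 28x28 uint8 array -> 1x28x28 float tensor in [0, 1]
    download=True,
)

loader = DataLoader(dataset, batch_size=64, shuffle=True)
images, labels = next(iter(loader))
print(images.shape, labels.unique())  # e.g. torch.Size([64, 1, 28, 28]) tensor([0, 1, 2, 3])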

utils/load_data.py

Lines changed: 3 additions & 1 deletion
@@ -1,11 +1,13 @@
 from torch.utils.data import Dataset

-from .dataloaders import USPSDataset0_6
+from .dataloaders import USPSDataset0_6, MNISTDataset0_3


 def load_data(dataset: str, *args, **kwargs) -> Dataset:
     match dataset.lower():
         case "usps_0-6":
             return USPSDataset0_6(*args, **kwargs)
+        case "mnist_0-3":
+            return MNISTDataset0_3(*args, **kwargs)
         case _:
             raise ValueError(f"Dataset: {dataset} not implemented.")
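
A short sketch of the new dispatch path through load_data: the name is lower-cased before matching, and every keyword argument is forwarded untouched to MNISTDataset0_3.__init__. The data_path value here is illustrative.

from pathlib import Path

from utils.load_data import load_data

# "MNIST_0-3" also resolves because of dataset.lower(); the kwargs pass straight through.
train_set = load_data("MNIST_0-3", data_path=Path("data"), train=True, download=True)
test_set = load_data("mnist_0-3", data_path=Path("data"), download=True)

print(len(train_set), len(test_set))  # counts of digit 0-3 samples in each split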
