|
| 1 | +import gzip |
| 2 | +import os |
| 3 | +import urllib.request |
| 4 | +from pathlib import Path |
| 5 | + |
| 6 | +import numpy as np |
| 7 | +from torch.utils.data import Dataset |
| 8 | + |
| 9 | + |
| 10 | +class MNISTDataset0_3(Dataset): |
| 11 | + """ |
| 12 | + A custom dataset class for loading MNIST data, specifically for digits 0 through 3. |
| 13 | + Parameters |
| 14 | + ---------- |
| 15 | + data_path : Path |
| 16 | + The root directory where the MNIST data is stored or will be downloaded. |
| 17 | + train : bool, optional |
| 18 | + If True, loads the training data, otherwise loads the test data. Default is False. |
| 19 | + transform : callable, optional |
| 20 | + A function/transform that takes in an image and returns a transformed version. Default is None. |
| 21 | + download : bool, optional |
| 22 | + If True, downloads the dataset if it is not already present in the specified data_path. Default is False. |
| 23 | + Attributes |
| 24 | + ---------- |
| 25 | + data_path : Path |
| 26 | + The root directory where the MNIST data is stored. |
| 27 | + mnist_path : Path |
| 28 | + The directory where the MNIST data files are stored. |
| 29 | + train : bool |
| 30 | + Indicates whether the training data or test data is being used. |
| 31 | + transform : callable |
| 32 | + A function/transform that takes in an image and returns a transformed version. |
| 33 | + download : bool |
| 34 | + Indicates whether the dataset should be downloaded if not present. |
| 35 | + images_path : Path |
| 36 | + The path to the image file (training or test) based on the `train` flag. |
| 37 | + labels_path : Path |
| 38 | + The path to the label file (training or test) based on the `train` flag. |
| 39 | + idx : numpy.ndarray |
| 40 | + Indices of the labels that are less than 4. |
| 41 | + length : int |
| 42 | + The number of samples in the dataset. |
| 43 | + Methods |
| 44 | + ------- |
| 45 | + _parse_labels(train) |
| 46 | + Parses the labels from the label file. |
| 47 | + _chech_is_downloaded() |
| 48 | + Checks if the dataset is already downloaded. |
| 49 | + _download_data() |
| 50 | + Downloads and extracts the MNIST dataset. |
| 51 | + __len__() |
| 52 | + Returns the number of samples in the dataset. |
| 53 | + __getitem__(index) |
| 54 | + Returns the image and label at the specified index. |
| 55 | + """ |
| 56 | + |
| 57 | + def __init__( |
| 58 | + self, |
| 59 | + data_path: Path, |
| 60 | + train: bool = False, |
| 61 | + transform=None, |
| 62 | + download: bool = False, |
| 63 | + ): |
| 64 | + super().__init__() |
| 65 | + |
| 66 | + self.data_path = data_path |
| 67 | + self.mnist_path = self.data_path / "MNIST" |
| 68 | + self.train = train |
| 69 | + self.transform = transform |
| 70 | + self.download = download |
| 71 | + self.num_classes = 4 |
| 72 | + |
| 73 | + if not self.download and not self._chech_is_downloaded(): |
| 74 | + raise ValueError( |
| 75 | + "Data not found. Set --download-data=True to download the data." |
| 76 | + ) |
| 77 | + if self.download and not self._chech_is_downloaded(): |
| 78 | + self._download_data() |
| 79 | + |
| 80 | + self.images_path = self.mnist_path / ( |
| 81 | + "train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte" |
| 82 | + ) |
| 83 | + self.labels_path = self.mnist_path / ( |
| 84 | + "train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte" |
| 85 | + ) |
| 86 | + |
| 87 | + labels = self._parse_labels(train=self.train) |
| 88 | + |
| 89 | + self.idx = np.where(labels < 4)[0] |
| 90 | + |
| 91 | + self.length = len(self.idx) |
| 92 | + |
| 93 | + def _parse_labels(self, train): |
| 94 | + with open(self.labels_path, "rb") as f: |
| 95 | + data = np.frombuffer(f.read(), dtype=np.uint8, offset=8) |
| 96 | + return data |
| 97 | + |
| 98 | + def _chech_is_downloaded(self): |
| 99 | + if self.mnist_path.exists(): |
| 100 | + required_files = [ |
| 101 | + "train-images-idx3-ubyte", |
| 102 | + "train-labels-idx1-ubyte", |
| 103 | + "t10k-images-idx3-ubyte", |
| 104 | + "t10k-labels-idx1-ubyte", |
| 105 | + ] |
| 106 | + if all([(self.mnist_path / file).exists() for file in required_files]): |
| 107 | + print("MNIST Dataset already downloaded.") |
| 108 | + return True |
| 109 | + else: |
| 110 | + return False |
| 111 | + else: |
| 112 | + self.mnist_path.mkdir(parents=True, exist_ok=True) |
| 113 | + return False |
| 114 | + |
| 115 | + def _download_data(self): |
| 116 | + urls = { |
| 117 | + "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz", |
| 118 | + "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz", |
| 119 | + "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz", |
| 120 | + "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz", |
| 121 | + } |
| 122 | + |
| 123 | + for name, url in urls.items(): |
| 124 | + file_path = os.path.join(self.mnist_path, url.split("/")[-1]) |
| 125 | + if not os.path.exists(file_path.replace(".gz", "")): # Avoid re-downloading |
| 126 | + urllib.request.urlretrieve(url, file_path) |
| 127 | + with gzip.open(file_path, "rb") as f_in: |
| 128 | + with open(file_path.replace(".gz", ""), "wb") as f_out: |
| 129 | + f_out.write(f_in.read()) |
| 130 | + os.remove(file_path) # Remove compressed file |
| 131 | + |
| 132 | + def __len__(self): |
| 133 | + return self.length |
| 134 | + |
| 135 | + def __getitem__(self, index): |
| 136 | + with open(self.labels_path, "rb") as f: |
| 137 | + f.seek(8 + index) # Jump to the label position |
| 138 | + label = int.from_bytes(f.read(1), byteorder="big") # Read 1 byte for label |
| 139 | + |
| 140 | + with open(self.images_path, "rb") as f: |
| 141 | + f.seek(16 + index * 28 * 28) # Jump to image position |
| 142 | + image = np.frombuffer(f.read(28 * 28), dtype=np.uint8).reshape( |
| 143 | + 28, 28 |
| 144 | + ) # Read image data |
| 145 | + |
| 146 | + image = np.expand_dims(image, axis=0) # Add channel dimension |
| 147 | + |
| 148 | + if self.transform: |
| 149 | + image = self.transform(image) |
| 150 | + |
| 151 | + return image, label |
0 commit comments