|
1 | | -import gzip |
2 | | -import os |
3 | | -import urllib.request |
4 | 1 | from pathlib import Path |
5 | 2 |
|
6 | 3 | import numpy as np |
7 | | -import torch |
8 | | -from torch.utils.data import Dataset, random_split |
| 4 | +from torch.utils.data import Dataset |
| 5 | +from .datasources import MNIST_SOURCE |
9 | 6 |
|
10 | 7 |
|
11 | 8 | class MNISTDataset0_3(Dataset): |
12 | 9 | """ |
13 | | - A custom dataset class for loading MNIST data, specifically for digits 0 through 3. |
14 | | -
|
| 10 | + A custom Dataset class for loading a subset of the MNIST dataset containing digits 0 to 3. |
15 | 11 | Parameters |
16 | 12 | ---------- |
17 | 13 | data_path : Path |
18 | | - The root directory where the MNIST data is stored or will be downloaded. |
| 14 | + The root directory where the MNIST data is stored. |
| 15 | + sample_ids : list |
| 16 | + A list of indices specifying which samples to load. |
19 | 17 | train : bool, optional |
20 | | - If True, loads the training data, otherwise loads the test data. Default is False. |
| 18 | + If True, load training data, otherwise load test data. Default is False. |
21 | 19 | transform : callable, optional |
22 | | - A function/transform that takes in an image and returns a transformed version. Default is None. |
23 | | - download : bool, optional |
24 | | - If True, downloads the dataset if it is not already present in the specified data_path. Default is False. |
25 | | -
|
| 20 | + A function/transform to apply to the images. Default is None. |
26 | 21 | Attributes |
27 | 22 | ---------- |
28 | 23 | data_path : Path |
29 | 24 | The root directory where the MNIST data is stored. |
30 | 25 | mnist_path : Path |
31 | | - The directory where the MNIST data files are stored. |
| 26 | + The directory where the MNIST dataset is located within the root directory. |
| 27 | + idx : list |
| 28 | + A list of indices specifying which samples to load. |
32 | 29 | train : bool |
33 | | - Indicates whether the training data or test data is being used. |
| 30 | + Indicates whether to load training data or test data. |
34 | 31 | transform : callable |
35 | | - A function/transform that takes in an image and returns a transformed version. |
36 | | - download : bool |
37 | | - Indicates whether the dataset should be downloaded if not present. |
| 32 | + A function/transform to apply to the images. |
| 33 | + num_classes : int |
| 34 | + The number of classes in the dataset (0 to 3). |
38 | 35 | images_path : Path |
39 | | - The path to the image file (training or test) based on the `train` flag. |
| 36 | + The path to the image file (train or test) based on the `train` flag. |
40 | 37 | labels_path : Path |
41 | | - The path to the label file (training or test) based on the `train` flag. |
42 | | - idx : numpy.ndarray |
43 | | - Indices of the labels that are less than 4. |
| 38 | + The path to the label file (train or test) based on the `train` flag. |
44 | 39 | length : int |
45 | 40 | The number of samples in the dataset. |
46 | | -
|
47 | 41 | Methods |
48 | 42 | ------- |
49 | | - _parse_labels(train) |
50 | | - Parses the labels from the label file. |
51 | | - _chech_is_downloaded() |
52 | | - Checks if the dataset is already downloaded. |
53 | | - _download_data() |
54 | | - Downloads and extracts the MNIST dataset. |
55 | 43 | __len__() |
56 | 44 | Returns the number of samples in the dataset. |
57 | 45 | __getitem__(index) |
58 | | - Returns the image and label at the specified index. |
| 46 | + Retrieves the image and label at the specified index. |
59 | 47 | """ |
60 | 48 |
|
61 | 49 | def __init__( |
62 | 50 | self, |
63 | | - split: str, |
64 | | - split_percentage: float, |
65 | 51 | data_path: Path, |
66 | | - download: bool = False, |
| 52 | + sample_ids: list, |
| 53 | + train: bool = False, |
67 | 54 | transform=None, |
68 | 55 | ): |
69 | 56 | super().__init__() |
70 | 57 |
|
71 | 58 | self.data_path = data_path |
72 | 59 | self.mnist_path = self.data_path / "MNIST" |
73 | | - self.split = split |
74 | | - self.split_percentage = split_percentage |
| 60 | + self.idx = sample_ids |
| 61 | + self.train = train |
75 | 62 | self.transform = transform |
76 | | - self.download = download |
77 | 63 | self.num_classes = 4 |
78 | 64 |
|
79 | | - if self.split == "test": |
80 | | - train = False # used to decide whether to load training or test dataset |
81 | | - else: |
82 | | - train = True |
83 | | - |
84 | | - if not self.download and not self._chech_is_downloaded(): |
85 | | - raise ValueError( |
86 | | - "Data not found. Set --download-data=True to download the data." |
87 | | - ) |
88 | | - if self.download and not self._chech_is_downloaded(): |
89 | | - self._download_data() |
90 | | - |
91 | | - self.images_path = self.mnist_path / ( |
92 | | - "train-images-idx3-ubyte" if train else "t10k-images-idx3-ubyte" |
93 | | - ) |
94 | | - self.labels_path = self.mnist_path / ( |
95 | | - "train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte" |
96 | | - ) |
97 | | - |
98 | | - labels = self._parse_labels() |
99 | | - |
100 | | - self.idx = np.where(labels < 4)[0] |
101 | | - |
102 | | - if self.split != "test": |
103 | | - generator1 = torch.Generator().manual_seed(42) |
104 | | - tr, val = random_split( |
105 | | - self.idx, |
106 | | - [1 - self.split_percentage, self.split_percentage], |
107 | | - generator=generator1, |
108 | | - ) |
109 | | - self.idx = tr if self.split == "train" else val |
| 65 | + self.images_path = self.mnist_path / (MNIST_SOURCE["train_images"][1] if train else MNIST_SOURCE["test_images"][1]) |
| 66 | + self.labels_path = self.mnist_path / (MNIST_SOURCE["train_labels"][1] if train else MNIST_SOURCE["test_labels"][1]) |
110 | 67 |
|
111 | 68 | self.length = len(self.idx) |
112 | | - |
113 | | - def _parse_labels(self): |
114 | | - with open(self.labels_path, "rb") as f: |
115 | | - data = np.frombuffer(f.read(), dtype=np.uint8, offset=8) |
116 | | - return data |
117 | | - |
118 | | - def _chech_is_downloaded(self): |
119 | | - if self.mnist_path.exists(): |
120 | | - required_files = [ |
121 | | - "train-images-idx3-ubyte", |
122 | | - "train-labels-idx1-ubyte", |
123 | | - "t10k-images-idx3-ubyte", |
124 | | - "t10k-labels-idx1-ubyte", |
125 | | - ] |
126 | | - if all([(self.mnist_path / file).exists() for file in required_files]): |
127 | | - print("MNIST Dataset already downloaded.") |
128 | | - return True |
129 | | - else: |
130 | | - return False |
131 | | - else: |
132 | | - self.mnist_path.mkdir(parents=True, exist_ok=True) |
133 | | - return False |
134 | | - |
135 | | - def _download_data(self): |
136 | | - urls = { |
137 | | - "train_images": "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz", |
138 | | - "train_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz", |
139 | | - "test_images": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz", |
140 | | - "test_labels": "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz", |
141 | | - } |
142 | | - |
143 | | - for name, url in urls.items(): |
144 | | - file_path = os.path.join(self.mnist_path, url.split("/")[-1]) |
145 | | - if not os.path.exists(file_path.replace(".gz", "")): # Avoid re-downloading |
146 | | - urllib.request.urlretrieve(url, file_path) |
147 | | - with gzip.open(file_path, "rb") as f_in: |
148 | | - with open(file_path.replace(".gz", ""), "wb") as f_out: |
149 | | - f_out.write(f_in.read()) |
150 | | - os.remove(file_path) # Remove compressed file |
151 | | - |
| 69 | + |
152 | 70 | def __len__(self): |
153 | 71 | return self.length |
154 | 72 |
|
|
0 commit comments