|
4 | 4 | This module contains the Dataset class for the USPS dataset with labels 0-6. |
5 | 5 | """ |
6 | 6 |
|
7 | | -import bz2 |
8 | | -import hashlib |
9 | 7 | from pathlib import Path |
10 | | -from tempfile import TemporaryDirectory |
11 | | -from urllib.request import urlretrieve |
12 | 8 |
|
13 | 9 | import h5py as h5 |
14 | 10 | import numpy as np |
15 | 11 | from PIL import Image |
16 | 12 | from torch.utils.data import Dataset |
17 | 13 | from torchvision import transforms |
18 | 14 |
|
19 | | -from .datasources import USPS_SOURCE |
20 | | - |
21 | 15 |
|
22 | 16 | class USPSDataset0_6(Dataset): |
23 | 17 | """ |
@@ -87,178 +81,31 @@ class USPSDataset0_6(Dataset): |
87 | 81 | def __init__( |
88 | 82 | self, |
89 | 83 | data_path: Path, |
| 84 | + sample_ids: list, |
90 | 85 | train: bool = False, |
91 | 86 | transform=None, |
92 | | - download: bool = False, |
93 | 87 | ): |
94 | 88 | super().__init__() |
95 | 89 |
|
96 | 90 | path = data_path if isinstance(data_path, Path) else Path(data_path) |
97 | 91 | self.filepath = path / self.filename |
98 | 92 | self.transform = transform |
99 | 93 | self.mode = "train" if train else "test" |
| 94 | + self.sample_ids = sample_ids |
100 | 95 |
|
101 | | - # Download the dataset if it does not exist in a temporary directory |
102 | | - # to automatically clean up the downloaded file |
103 | | - if download and not self._dataset_ok(): |
104 | | - url, _, checksum = USPS_SOURCE[self.mode] |
105 | | - |
106 | | - print(f"Downloading USPS dataset ({self.mode})...") |
107 | | - self.download(url, self.filepath, checksum, self.mode) |
108 | | - |
109 | | - self.idx = self._index() |
110 | | - |
111 | | - def _dataset_ok(self): |
112 | | - """Check if the dataset file exists and contains the required datasets.""" |
113 | | - |
114 | | - if not self.filepath.exists(): |
115 | | - print(f"Dataset file {self.filepath} does not exist.") |
116 | | - return False |
117 | | - |
118 | | - with h5.File(self.filepath, "r") as f: |
119 | | - for mode in ["train", "test"]: |
120 | | - if mode not in f: |
121 | | - print( |
122 | | - f"Dataset file {self.filepath} is missing the {mode} dataset." |
123 | | - ) |
124 | | - return False |
125 | | - |
126 | | - return True |
127 | | - |
128 | | - def download(self, url, filepath, checksum, mode): |
129 | | - """Download the USPS dataset, and save it as an HDF5 file. |
130 | | -
|
131 | | - Args |
132 | | - ---- |
133 | | - url : str |
134 | | - URL to download the dataset from. |
135 | | - filepath : pathlib.Path |
136 | | - Path to save the downloaded dataset. |
137 | | - checksum : str |
138 | | - MD5 checksum of the downloaded file. |
139 | | - mode : str |
140 | | - Mode of the dataset, either train or test. |
141 | | -
|
142 | | - Raises |
143 | | - ------ |
144 | | - ValueError |
145 | | - If the checksum of the downloaded file does not match the expected checksum. |
146 | | - """ |
147 | | - |
148 | | - def reporthook(blocknum, blocksize, totalsize): |
149 | | - """Report download progress.""" |
150 | | - denom = 1024 * 1024 |
151 | | - readsofar = blocknum * blocksize |
152 | | - if totalsize > 0: |
153 | | - percent = readsofar * 1e2 / totalsize |
154 | | - s = f"\r{int(percent):^3}% {readsofar / denom:.2f} of {totalsize / denom:.2f} MB" |
155 | | - print(s, end="", flush=True) |
156 | | - if readsofar >= totalsize: |
157 | | - print() |
158 | | - |
159 | | - # Download the dataset to a temporary file |
160 | | - with TemporaryDirectory() as tmpdir: |
161 | | - tmpdir = Path(tmpdir) |
162 | | - tmpfile = tmpdir / "usps.bz2" |
163 | | - urlretrieve( |
164 | | - url, |
165 | | - tmpfile, |
166 | | - reporthook=reporthook, |
167 | | - ) |
168 | | - |
169 | | - # For fun we can check the integrity of the downloaded file |
170 | | - if not self.check_integrity(tmpfile, checksum): |
171 | | - errmsg = ( |
172 | | - "The checksum of the downloaded file does " |
173 | | - "not match the expected checksum." |
174 | | - ) |
175 | | - raise ValueError(errmsg) |
176 | | - |
177 | | - # Load the dataset and save it as an HDF5 file |
178 | | - with bz2.open(tmpfile) as fp: |
179 | | - raw = [line.decode().split() for line in fp.readlines()] |
180 | | - |
181 | | - tmp = [[x.split(":")[-1] for x in data[1:]] for data in raw] |
182 | | - |
183 | | - imgs = np.asarray(tmp, dtype=np.float32) |
184 | | - imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8) |
185 | | - |
186 | | - targets = [int(d[0]) - 1 for d in raw] |
187 | | - |
188 | | - with h5.File(self.filepath, "a") as f: |
189 | | - f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32) |
190 | | - f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32) |
191 | | - |
192 | | - @staticmethod |
193 | | - def check_integrity(filepath, checksum): |
194 | | - """Check the integrity of the USPS dataset file. |
195 | | -
|
196 | | - Args |
197 | | - ---- |
198 | | - filepath : pathlib.Path |
199 | | - Path to the USPS dataset file. |
200 | | - checksum : str |
201 | | - MD5 checksum of the dataset file. |
202 | | -
|
203 | | - Returns |
204 | | - ------- |
205 | | - bool |
206 | | - True if the checksum of the file matches the expected checksum, False otherwise |
207 | | - """ |
208 | | - |
209 | | - file_hash = hashlib.md5(filepath.read_bytes()).hexdigest() |
210 | | - |
211 | | - return checksum == file_hash |
212 | | - |
213 | | - def _index(self): |
214 | | - with h5.File(self.filepath, "r") as f: |
215 | | - labels = f[self.mode]["target"][:] |
216 | | - |
217 | | - # Get indices of samples with labels 0-6 |
218 | | - mask = labels <= 6 |
219 | | - idx = np.where(mask)[0] |
| 96 | + def __len__(self): |
| 97 | + return len(self.sample_ids) |
220 | 98 |
|
221 | | - return idx |
| 99 | + def __getitem__(self, id): |
| 100 | + index = self.sample_ids[id] |
222 | 101 |
|
223 | | - def _load_data(self, idx): |
224 | 102 | with h5.File(self.filepath, "r") as f: |
225 | | - data = f[self.mode]["data"][idx].astype(np.uint8) |
226 | | - label = f[self.mode]["target"][idx] |
| 103 | + data = f[self.mode]["data"][index].astype(np.uint8) |
| 104 | + label = f[self.mode]["target"][index] |
227 | 105 |
|
228 | | - return data, label |
229 | | - |
230 | | - def __len__(self): |
231 | | - return len(self.idx) |
232 | | - |
233 | | - def __getitem__(self, idx): |
234 | | - data, target = self._load_data(self.idx[idx]) |
235 | 106 | data = Image.fromarray(data, mode="L") |
236 | 107 |
|
237 | | - # one hot encode the target |
238 | | - target = np.eye(self.num_classes, dtype=np.float32)[target] |
239 | | - |
240 | 108 | if self.transform: |
241 | 109 | data = self.transform(data) |
242 | 110 |
|
243 | | - return data, target |
244 | | - |
245 | | - |
246 | | -if __name__ == "__main__": |
247 | | - # Example usage: |
248 | | - transform = transforms.Compose( |
249 | | - [ |
250 | | - transforms.Resize((16, 16)), |
251 | | - transforms.ToTensor(), |
252 | | - ] |
253 | | - ) |
254 | | - |
255 | | - dataset = USPSDataset0_6( |
256 | | - data_path="data", |
257 | | - train=True, |
258 | | - download=False, |
259 | | - transform=transform, |
260 | | - ) |
261 | | - print(len(dataset)) |
262 | | - data, target = dataset[0] |
263 | | - print(data.shape) |
264 | | - print(target) |
| 111 | + return data, label |
0 commit comments