
Commit a9e2cad

Remove downloading logic from USPS dataset
1 parent a58e495 commit a9e2cad

File tree

1 file changed (+9 -162 lines changed)

utils/dataloaders/usps_0_6.py

Lines changed: 9 additions & 162 deletions
@@ -4,20 +4,14 @@
 This module contains the Dataset class for the USPS dataset with labels 0-6.
 """
 
-import bz2
-import hashlib
 from pathlib import Path
-from tempfile import TemporaryDirectory
-from urllib.request import urlretrieve
 
 import h5py as h5
 import numpy as np
 from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
 
-from .datasources import USPS_SOURCE
-
 
 class USPSDataset0_6(Dataset):
     """
@@ -87,178 +81,31 @@ class USPSDataset0_6(Dataset):
     def __init__(
         self,
         data_path: Path,
+        sample_ids: list,
         train: bool = False,
         transform=None,
-        download: bool = False,
     ):
         super().__init__()
 
         path = data_path if isinstance(data_path, Path) else Path(data_path)
         self.filepath = path / self.filename
         self.transform = transform
         self.mode = "train" if train else "test"
+        self.sample_ids = sample_ids
 
-        # Download the dataset if it does not exist in a temporary directory
-        # to automatically clean up the downloaded file
-        if download and not self._dataset_ok():
-            url, _, checksum = USPS_SOURCE[self.mode]
-
-            print(f"Downloading USPS dataset ({self.mode})...")
-            self.download(url, self.filepath, checksum, self.mode)
-
-        self.idx = self._index()
-
-    def _dataset_ok(self):
-        """Check if the dataset file exists and contains the required datasets."""
-
-        if not self.filepath.exists():
-            print(f"Dataset file {self.filepath} does not exist.")
-            return False
-
-        with h5.File(self.filepath, "r") as f:
-            for mode in ["train", "test"]:
-                if mode not in f:
-                    print(
-                        f"Dataset file {self.filepath} is missing the {mode} dataset."
-                    )
-                    return False
-
-        return True
-
-    def download(self, url, filepath, checksum, mode):
-        """Download the USPS dataset, and save it as an HDF5 file.
-
-        Args
-        ----
-        url : str
-            URL to download the dataset from.
-        filepath : pathlib.Path
-            Path to save the downloaded dataset.
-        checksum : str
-            MD5 checksum of the downloaded file.
-        mode : str
-            Mode of the dataset, either train or test.
-
-        Raises
-        ------
-        ValueError
-            If the checksum of the downloaded file does not match the expected checksum.
-        """
-
-        def reporthook(blocknum, blocksize, totalsize):
-            """Report download progress."""
-            denom = 1024 * 1024
-            readsofar = blocknum * blocksize
-            if totalsize > 0:
-                percent = readsofar * 1e2 / totalsize
-                s = f"\r{int(percent):^3}% {readsofar / denom:.2f} of {totalsize / denom:.2f} MB"
-                print(s, end="", flush=True)
-                if readsofar >= totalsize:
-                    print()
-
-        # Download the dataset to a temporary file
-        with TemporaryDirectory() as tmpdir:
-            tmpdir = Path(tmpdir)
-            tmpfile = tmpdir / "usps.bz2"
-            urlretrieve(
-                url,
-                tmpfile,
-                reporthook=reporthook,
-            )
-
-            # For fun we can check the integrity of the downloaded file
-            if not self.check_integrity(tmpfile, checksum):
-                errmsg = (
-                    "The checksum of the downloaded file does "
-                    "not match the expected checksum."
-                )
-                raise ValueError(errmsg)
-
-            # Load the dataset and save it as an HDF5 file
-            with bz2.open(tmpfile) as fp:
-                raw = [line.decode().split() for line in fp.readlines()]
-
-                tmp = [[x.split(":")[-1] for x in data[1:]] for data in raw]
-
-                imgs = np.asarray(tmp, dtype=np.float32)
-                imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
-
-                targets = [int(d[0]) - 1 for d in raw]
-
-                with h5.File(self.filepath, "a") as f:
-                    f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32)
-                    f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32)
-
-    @staticmethod
-    def check_integrity(filepath, checksum):
-        """Check the integrity of the USPS dataset file.
-
-        Args
-        ----
-        filepath : pathlib.Path
-            Path to the USPS dataset file.
-        checksum : str
-            MD5 checksum of the dataset file.
-
-        Returns
-        -------
-        bool
-            True if the checksum of the file matches the expected checksum, False otherwise
-        """
-
-        file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()
-
-        return checksum == file_hash
-
-    def _index(self):
-        with h5.File(self.filepath, "r") as f:
-            labels = f[self.mode]["target"][:]
-
-            # Get indices of samples with labels 0-6
-            mask = labels <= 6
-            idx = np.where(mask)[0]
+    def __len__(self):
+        return len(self.sample_ids)
 
-        return idx
+    def __getitem__(self, id):
+        index = self.sample_ids[id]
 
-    def _load_data(self, idx):
         with h5.File(self.filepath, "r") as f:
-            data = f[self.mode]["data"][idx].astype(np.uint8)
-            label = f[self.mode]["target"][idx]
+            data = f[self.mode]["data"][index].astype(np.uint8)
+            label = f[self.mode]["target"][index]
 
-        return data, label
-
-    def __len__(self):
-        return len(self.idx)
-
-    def __getitem__(self, idx):
-        data, target = self._load_data(self.idx[idx])
         data = Image.fromarray(data, mode="L")
 
-        # one hot encode the target
-        target = np.eye(self.num_classes, dtype=np.float32)[target]
-
         if self.transform:
             data = self.transform(data)
 
-        return data, target
-
-
-if __name__ == "__main__":
-    # Example usage:
-    transform = transforms.Compose(
-        [
-            transforms.Resize((16, 16)),
-            transforms.ToTensor(),
-        ]
-    )
-
-    dataset = USPSDataset0_6(
-        data_path="data",
-        train=True,
-        download=False,
-        transform=transform,
-    )
-    print(len(dataset))
-    data, target = dataset[0]
-    print(data.shape)
-    print(target)
+        return data, label
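
With the downloading and indexing logic gone, callers are now responsible for supplying sample_ids to the constructor. Below is a minimal sketch of what that could look like after this commit. It assumes the HDF5 file already exists under data/, that it is named usps.h5 (the real filename class attribute is defined outside this diff), that the repository root is on the Python path, and that the removed label filter (labels 0-6) is reproduced on the caller side.

import h5py as h5
import numpy as np
from torchvision import transforms

from utils.dataloaders.usps_0_6 import USPSDataset0_6

data_path = "data"

# Reproduce the removed _index() logic outside the dataset:
# keep only the samples whose label is in 0-6.
with h5.File(f"{data_path}/usps.h5", "r") as f:  # "usps.h5" is an assumed filename
    labels = f["train"]["target"][:]
sample_ids = np.where(labels <= 6)[0].tolist()

transform = transforms.Compose(
    [
        transforms.Resize((16, 16)),
        transforms.ToTensor(),
    ]
)

dataset = USPSDataset0_6(
    data_path=data_path,
    sample_ids=sample_ids,
    train=True,
    transform=transform,
)

print(len(dataset))
data, label = dataset[0]
print(data.shape)  # e.g. torch.Size([1, 16, 16]) after ToTensor
print(label)       # raw integer label; the one-hot encoding was removed in this commit

Note that __getitem__ now returns the raw label rather than a one-hot vector, so downstream code that expected one-hot targets has to encode them itself or switch to an index-based loss.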

0 commit comments
