
Commit 3b1a442

Add download functionality
1 parent: 7ff097a

3 files changed: +105 -10 lines changed


tests/test_dataloaders.py

Lines changed: 10 additions & 3 deletions
@@ -3,17 +3,24 @@
 
 def test_uspsdataset0_6():
     from pathlib import Path
-    from tempfile import TemporaryFile
+    from tempfile import TemporaryDirectory
 
     import h5py
     import numpy as np
 
-    with TemporaryFile() as tf:
+    # Create a temporary directory (deleted after the test)
+    with TemporaryDirectory() as tempdir:
+        tempdir = Path(tempdir)
+
+        tf = tempdir / "usps.h5"
+
+        # Create a h5 file
         with h5py.File(tf, "w") as f:
+            # Populate the file with data
             f["train/data"] = np.random.rand(10, 16 * 16)
             f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
 
-        dataset = USPSDataset0_6(data_path=tf, train=True)
+        dataset = USPSDataset0_6(data_path=tempdir, train=True)
         assert len(dataset) == 10
         data, target = dataset[0]
         assert data.shape == (1, 16, 16)
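
Note: the file written by the test stores flattened 16x16 images under train/data, while the final assertion expects samples with a leading channel dimension. A minimal sketch of that relationship (the reshape is presumed to happen inside the dataset's __getitem__, which this hunk does not show):

    import numpy as np

    # train/data holds flattened 16x16 images ...
    flat = np.random.rand(10, 16 * 16).astype(np.float32)

    # ... and the dataset is expected to return (channel, height, width) samples.
    sample = flat[0].reshape(1, 16, 16)
    print(sample.shape)  # (1, 16, 16)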

utils/dataloaders/datasources.py

Lines changed: 12 additions & 0 deletions
@@ -0,0 +1,12 @@
+USPS_SOURCE = {
+    "train": [
+        "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
+        "usps.bz2",
+        "ec16c51db3855ca6c91edd34d0e9b197",
+    ],
+    "test": [
+        "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
+        "usps.t.bz2",
+        "8ea070ee2aca1ac39742fdd1ef5ed118",
+    ],
+}
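
Each entry in USPS_SOURCE is a (URL, archive name, MD5 checksum) triple keyed by split. A minimal sketch of consuming one entry, assuming the repository root is on sys.path and a hypothetical downloads/ directory holding a previously fetched copy:

    import hashlib
    from pathlib import Path

    from utils.dataloaders.datasources import USPS_SOURCE

    url, archive_name, md5 = USPS_SOURCE["train"]

    # Verify a local copy of the archive against the recorded checksum.
    local_copy = Path("downloads") / archive_name  # hypothetical location
    if local_copy.exists():
        digest = hashlib.md5(local_copy.read_bytes()).hexdigest()
        print(f"{archive_name}: checksum {'ok' if digest == md5 else 'mismatch'}")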

utils/dataloaders/usps_0_6.py

Lines changed: 83 additions & 7 deletions
@@ -4,12 +4,18 @@
 This module contains the Dataset class for the USPS dataset with labels 0-6.
 """
 
+import bz2
+import hashlib
 from pathlib import Path
+from tempfile import TemporaryDirectory, TemporaryFile
+from urllib.request import urlretrieve
 
 import h5py as h5
 import numpy as np
 from torch.utils.data import Dataset
 
+from .datasources import USPS_SOURCE
+
 
 class USPSDataset0_6(Dataset):
     """
@@ -28,7 +34,7 @@ class USPSDataset0_6(Dataset):
 
     Attributes
     ----------
-    path : pathlib.Path
+    filepath : pathlib.Path
         Path to the USPS dataset file.
     mode : str
         Mode of the dataset, either train or test.
@@ -63,6 +69,8 @@ class USPSDataset0_6(Dataset):
     6
     """
 
+    filename = "usps.h5"
+
     def __init__(
         self,
         data_path: Path,
@@ -71,18 +79,78 @@ def __init__(
         download: bool = False,
     ):
         super().__init__()
-        self.path = data_path
+
+        path = data_path if isinstance(data_path, Path) else Path(data_path)
+        self.filepath = path / self.filename
         self.transform = transform
-        self.num_classes = 7
+        self.num_classes = 7  # 0-6
+        self.mode = "train" if train else "test"
 
+        # Download the dataset if it does not exist in a temporary directory
+        # to automatically clean up the downloaded file
         if download:
-            raise NotImplementedError("Download functionality not implemented.")
+            url, _, checksum = USPS_SOURCE[self.mode]
+
+            print(f"Downloading USPS dataset ({self.mode})...")
+            self.download(url, self.filepath, checksum, self.mode)
 
-        self.mode = "train" if train else "test"
         self.idx = self._index()
 
+    def download(self, url, filepath, checksum, mode):
+        """Download the USPS dataset."""
+
+        def reporthook(blocknum, blocksize, totalsize):
+            denom = 1024 * 1024
+            readsofar = blocknum * blocksize
+            if totalsize > 0:
+                percent = readsofar * 1e2 / totalsize
+                s = f"\r{int(percent):^3}% {readsofar / denom:.2f} of {totalsize / denom:.2f} MB"
+                print(s, end="", flush=True)
+                if readsofar >= totalsize:
+                    print()
+
+        with TemporaryDirectory() as tmpdir:
+            tmpdir = Path(tmpdir)
+            tmpfile = tmpdir / "usps.bz2"
+            urlretrieve(
+                url,
+                tmpfile,
+                reporthook=reporthook,
+            )
+
+            # For fun we can check the integrity of the downloaded file
+            if not self.check_integrity(tmpfile, checksum):
+                errmsg = (
+                    "The checksum of the downloaded file does "
+                    "not match the expected checksum."
+                )
+                raise ValueError(errmsg)
+
+            # Load the dataset and save it as an HDF5 file
+            with bz2.open(tmpfile) as fp:
+                raw = [line.decode().split() for line in fp.readlines()]
+
+            tmp = [[x.split(":")[-1] for x in data[1:]] for data in raw]
+
+            imgs = np.asarray(tmp, dtype=np.float32)
+            imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
+
+            targets = [int(d[0]) - 1 for d in raw]
+
+            with h5.File(self.filepath, "w") as f:
+                f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32)
+                f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32)
+
+    @staticmethod
+    def check_integrity(filepath, checksum):
+        """Check the integrity of the USPS dataset file."""
+
+        file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()
+
+        return checksum == file_hash
+
     def _index(self):
-        with h5.File(self.path, "r") as f:
+        with h5.File(self.filepath, "r") as f:
             labels = f[self.mode]["target"][:]
 
         # Get indices of samples with labels 0-6
@@ -92,7 +160,7 @@ def _index(self):
         return idx
 
     def _load_data(self, idx):
-        with h5.File(self.path, "r") as f:
+        with h5.File(self.filepath, "r") as f:
             data = f[self.mode]["data"][idx]
             label = f[self.mode]["target"][idx]
 
@@ -116,3 +184,11 @@ def __getitem__(self, idx):
             data = self.transform(data)
 
         return data, target
+
+
+if __name__ == "__main__":
+    dataset = USPSDataset0_6(data_path="data", train=True, download=True)
+    print(len(dataset))
+    data, target = dataset[0]
+    print(data.shape)
+    print(target)
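
For context, the bz2 archives from the LIBSVM site store one sample per line as "label index:value index:value ...", with USPS labels encoded as 1-10 and pixel values roughly in [-1, 1]; that is why download() shifts labels down by one and rescales pixels to 0-255. A minimal sketch of the per-line parsing, using an illustrative truncated line:

    import numpy as np

    # Illustrative (truncated) LIBSVM line: label first, then index:value pairs.
    line = "7 1:-0.8125 2:-0.6094 3:0.2188"

    parts = line.split()
    label = int(parts[0]) - 1  # labels 1-10 -> 0-9
    values = [float(x.split(":")[-1]) for x in parts[1:]]

    # Rescale from roughly [-1, 1] to 0-255, as in download().
    pixels = ((np.asarray(values, dtype=np.float32) + 1) / 2 * 255).astype(np.uint8)
    print(label, pixels)  # 6 [ 23  49 155]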
