Skip to content

Commit 8d3d7e0

Browse files
committed
Add torchvision transforms and avoid data download if exists
1 parent 3b1a442 commit 8d3d7e0

File tree

1 file changed

+39
-9
lines changed

1 file changed

+39
-9
lines changed

utils/dataloaders/usps_0_6.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,9 @@
1212

1313
import h5py as h5
1414
import numpy as np
15+
from PIL import Image
1516
from torch.utils.data import Dataset
17+
from torchvision import transforms
1618

1719
from .datasources import USPS_SOURCE
1820

@@ -88,18 +90,36 @@ def __init__(
8890

8991
# Download the dataset if it is not already present. The archive is fetched
9092
# into a temporary directory so that it is cleaned up automatically.
91-
if download:
93+
if download and not self._dataset_ok():
9294
url, _, checksum = USPS_SOURCE[self.mode]
9395

9496
print(f"Downloading USPS dataset ({self.mode})...")
9597
self.download(url, self.filepath, checksum, self.mode)
9698

9799
self.idx = self._index()
98100

101+
def _dataset_ok(self):
    """Check whether the dataset file already holds the data for this split.

    Only the datasets for ``self.mode`` are required. ``download`` stores a
    single split per call and appends it to the file, so a file produced by
    one download never contains the other split. Requiring every split here
    would trigger a second download of a split that is already stored, and
    ``create_dataset`` refuses to overwrite an existing dataset.

    Returns:
        bool: ``True`` if the file exists and contains both the ``data`` and
        ``target`` datasets for ``self.mode``, ``False`` otherwise.
    """
    if not self.filepath.exists():
        print(f"Dataset file {self.filepath} does not exist.")
        return False

    with h5.File(self.filepath, "r") as f:
        # These are the two datasets read back when loading samples.
        for name in (f"{self.mode}/data", f"{self.mode}/target"):
            if name not in f:
                print(
                    f"Dataset file {self.filepath} is missing the {name} dataset."
                )
                return False

    return True
117+
99118
def download(self, url, filepath, checksum, mode):
100119
"""Download the USPS dataset."""
101120

102121
def reporthook(blocknum, blocksize, totalsize):
122+
"""Report download progress."""
103123
denom = 1024 * 1024
104124
readsofar = blocknum * blocksize
105125
if totalsize > 0:
@@ -109,6 +129,7 @@ def reporthook(blocknum, blocksize, totalsize):
109129
if readsofar >= totalsize:
110130
print()
111131

132+
# Download the dataset to a temporary file
112133
with TemporaryDirectory() as tmpdir:
113134
tmpdir = Path(tmpdir)
114135
tmpfile = tmpdir / "usps.bz2"
@@ -137,7 +158,7 @@ def reporthook(blocknum, blocksize, totalsize):
137158

138159
targets = [int(d[0]) - 1 for d in raw]
139160

140-
with h5.File(self.filepath, "w") as f:
161+
with h5.File(self.filepath, "a") as f:
141162
f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32)
142163
f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32)
143164

@@ -161,7 +182,7 @@ def _index(self):
161182

162183
def _load_data(self, idx):
    """Read one sample and its label for the current split from the HDF5 file.

    Args:
        idx: Index into the ``data`` and ``target`` datasets of the
            ``self.mode`` group.

    Returns:
        tuple: ``(data, label)`` where ``data`` is the stored sample cast to
        ``np.uint8`` and ``label`` is the stored target value.
    """
    # NOTE(review): `download` stores the images as float32; casting to uint8
    # drops any fractional part. Confirm the stored pixel values are integral
    # and within 0-255.
    with h5.File(self.filepath, "r") as h5file:
        split = h5file[self.mode]
        sample = split["data"][idx].astype(np.uint8)
        target = split["target"][idx]

    return sample, target
@@ -171,23 +192,32 @@ def __len__(self):
171192

172193
def __getitem__(self, idx):
    """Return the sample at position ``idx`` together with its one-hot label.

    Args:
        idx: Position in ``self.idx``; the value stored there is the index
            used to read from the dataset file.

    Returns:
        tuple: ``(data, target)``. ``data`` is a grayscale (mode ``"L"``)
        PIL image, or whatever ``self.transform`` returns for that image
        when a transform is set. ``target`` is a float32 one-hot vector of
        length ``self.num_classes``.
    """
    raw, label = self._load_data(self.idx[idx])

    # NOTE(review): the image size is taken from the array shape, so this
    # presumably expects `_load_data` to return a 2-D (height, width) array.
    # Confirm against how the images are stored.
    image = Image.fromarray(raw, mode="L")

    # Select the row of the identity matrix that matches the class index.
    one_hot = np.eye(self.num_classes, dtype=np.float32)[label]

    if self.transform:
        image = self.transform(image)

    return image, one_hot
187204

188205

189206
if __name__ == "__main__":
190-
dataset = USPSDataset0_6(data_path="data", train=True, download=True)
207+
# Example usage:
208+
transform = transforms.Compose(
209+
[
210+
transforms.Resize((16, 16)),
211+
transforms.ToTensor(),
212+
]
213+
)
214+
215+
dataset = USPSDataset0_6(
216+
data_path="data",
217+
train=True,
218+
download=False,
219+
transform=transform,
220+
)
191221
print(len(dataset))
192222
data, target = dataset[0]
193223
print(data.shape)

0 commit comments

Comments
 (0)