Skip to content

Commit 0b21d9d

Browse files
authored
Merge pull request #33 from SFI-Visual-Intelligence/christian/rework-usps_0-6-dataset
Christian/rework usps 0 6 dataset and add some functionality to main.py
2 parents 07a4ede + d6999aa commit 0b21d9d

File tree

7 files changed

+191
-26
lines changed

7 files changed

+191
-26
lines changed

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ dependencies:
1818
- pytest
1919
- ruff
2020
- scalene
21+
- tqdm
2122
- pip:
2223
- torch
2324
- torchvision

main.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import torch.nn as nn
77
import wandb
88
from torch.utils.data import DataLoader
9+
from torchvision import transforms
10+
from tqdm import tqdm
911

1012
from utils import MetricWrapper, createfolders, load_data, load_model
1113

@@ -49,15 +51,13 @@ def main():
4951
)
5052
parser.add_argument(
5153
"--savemodel",
52-
type=bool,
53-
default=False,
54+
action="store_true",
5455
help="Whether model should be saved or not.",
5556
)
5657

5758
parser.add_argument(
5859
"--download-data",
59-
type=bool,
60-
default=False,
60+
action="store_true",
6161
help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
6262
)
6363

@@ -126,17 +126,27 @@ def main():
126126

127127
metrics = MetricWrapper(*args.metric)
128128

129+
augmentations = transforms.Compose(
130+
[
131+
transforms.Resize((16, 16)), # At least for USPS
132+
transforms.ToTensor(),
133+
]
134+
)
135+
129136
# Dataset
130137
traindata = load_data(
131138
args.dataset,
132139
train=True,
133140
data_path=args.datafolder,
134141
download=args.download_data,
142+
transform=augmentations,
135143
)
136144
validata = load_data(
137145
args.dataset,
138146
train=False,
139147
data_path=args.datafolder,
148+
download=args.download_data,
149+
transform=augmentations,
140150
)
141151

142152
# Find the shape of the data, if is 2D, add a channel dimension
@@ -168,7 +178,27 @@ def main():
168178

169179
# This allows us to load all the components without running the training loop
170180
if args.dry_run:
171-
print("Dry run completed")
181+
dry_run_loader = DataLoader(
182+
traindata,
183+
batch_size=1,
184+
shuffle=True,
185+
pin_memory=True,
186+
drop_last=True,
187+
)
188+
189+
for x, y in tqdm(dry_run_loader, desc="Dry run", total=1):
190+
x, y = x.to(device), y.to(device)
191+
pred = model.forward(x)
192+
193+
loss = criterion(y, pred)
194+
loss.backward()
195+
196+
optimizer.step()
197+
optimizer.zero_grad(set_to_none=True)
198+
199+
break
200+
201+
print("Dry run completed successfully.")
172202
exit(0)
173203

174204
wandb.init(project="", tags=[])
@@ -178,7 +208,7 @@ def main():
178208
# Training loop start
179209
trainingloss = []
180210
model.train()
181-
for x, y in trainloader:
211+
for x, y in tqdm(trainloader, desc="Training"):
182212
x, y = x.to(device), y.to(device)
183213
pred = model.forward(x)
184214

@@ -193,7 +223,7 @@ def main():
193223
# Eval loop start
194224
model.eval()
195225
with th.no_grad():
196-
for x, y in valiloader:
226+
for x, y in tqdm(valiloader, desc="Validation"):
197227
x, y = x.to(device), y.to(device)
198228
pred = model.forward(x)
199229
loss = criterion(y, pred)

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[tool.isort]
2+
profile = "black"
3+
line_length = 88

tests/test_dataloaders.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,31 @@
33

44
def test_uspsdataset0_6():
55
from pathlib import Path
6-
from tempfile import TemporaryFile
6+
from tempfile import TemporaryDirectory
77

88
import h5py
99
import numpy as np
10+
from torchvision import transforms
1011

11-
with TemporaryFile() as tf:
12+
# Create a temporary directory (deleted after the test)
13+
with TemporaryDirectory() as tempdir:
14+
tempdir = Path(tempdir)
15+
16+
tf = tempdir / "usps.h5"
17+
18+
# Create a h5 file
1219
with h5py.File(tf, "w") as f:
20+
# Populate the file with data
1321
f["train/data"] = np.random.rand(10, 16 * 16)
1422
f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
1523

16-
dataset = USPSDataset0_6(data_path=tf, train=True)
24+
trans = transforms.Compose(
25+
[
26+
transforms.Resize((16, 16)), # At least for USPS
27+
transforms.ToTensor(),
28+
]
29+
)
30+
dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
1731
assert len(dataset) == 10
1832
data, target = dataset[0]
1933
assert data.shape == (1, 16, 16)

utils/dataloaders/datasources.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
USPS_SOURCE = {
2+
"train": [
3+
"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.bz2",
4+
"usps.bz2",
5+
"ec16c51db3855ca6c91edd34d0e9b197",
6+
],
7+
"test": [
8+
"https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/multiclass/usps.t.bz2",
9+
"usps.t.bz2",
10+
"8ea070ee2aca1ac39742fdd1ef5ed118",
11+
],
12+
}

utils/dataloaders/usps_0_6.py

Lines changed: 120 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,19 @@
44
This module contains the Dataset class for the USPS dataset with labels 0-6.
55
"""
66

7+
import bz2
8+
import hashlib
79
from pathlib import Path
10+
from tempfile import TemporaryDirectory
11+
from urllib.request import urlretrieve
812

913
import h5py as h5
1014
import numpy as np
15+
from PIL import Image
1116
from torch.utils.data import Dataset
17+
from torchvision import transforms
18+
19+
from .datasources import USPS_SOURCE
1220

1321

1422
class USPSDataset0_6(Dataset):
@@ -28,7 +36,7 @@ class USPSDataset0_6(Dataset):
2836
2937
Attributes
3038
----------
31-
path : pathlib.Path
39+
filepath : pathlib.Path
3240
Path to the USPS dataset file.
3341
mode : str
3442
Mode of the dataset, either train or test.
@@ -63,6 +71,8 @@ class USPSDataset0_6(Dataset):
6371
6
6472
"""
6573

74+
filename = "usps.h5"
75+
6676
def __init__(
6777
self,
6878
data_path: Path,
@@ -71,18 +81,97 @@ def __init__(
7181
download: bool = False,
7282
):
7383
super().__init__()
74-
self.path = data_path
84+
85+
path = data_path if isinstance(data_path, Path) else Path(data_path)
86+
self.filepath = path / self.filename
7587
self.transform = transform
76-
self.num_classes = 7
88+
self.num_classes = 7 # 0-6
89+
self.mode = "train" if train else "test"
7790

78-
if download:
79-
raise NotImplementedError("Download functionality not implemented.")
91+
# Download the dataset if it does not exist in a temporary directory
92+
# to automatically clean up the downloaded file
93+
if download and not self._dataset_ok():
94+
url, _, checksum = USPS_SOURCE[self.mode]
95+
96+
print(f"Downloading USPS dataset ({self.mode})...")
97+
self.download(url, self.filepath, checksum, self.mode)
8098

81-
self.mode = "train" if train else "test"
8299
self.idx = self._index()
83100

101+
def _dataset_ok(self):
102+
"""Check if the dataset file exists and contains the required datasets."""
103+
104+
if not self.filepath.exists():
105+
print(f"Dataset file {self.filepath} does not exist.")
106+
return False
107+
108+
with h5.File(self.filepath, "r") as f:
109+
for mode in ["train", "test"]:
110+
if mode not in f:
111+
print(
112+
f"Dataset file {self.filepath} is missing the {mode} dataset."
113+
)
114+
return False
115+
116+
return True
117+
118+
def download(self, url, filepath, checksum, mode):
119+
"""Download the USPS dataset."""
120+
121+
def reporthook(blocknum, blocksize, totalsize):
122+
"""Report download progress."""
123+
denom = 1024 * 1024
124+
readsofar = blocknum * blocksize
125+
if totalsize > 0:
126+
percent = readsofar * 1e2 / totalsize
127+
s = f"\r{int(percent):^3}% {readsofar / denom:.2f} of {totalsize / denom:.2f} MB"
128+
print(s, end="", flush=True)
129+
if readsofar >= totalsize:
130+
print()
131+
132+
# Download the dataset to a temporary file
133+
with TemporaryDirectory() as tmpdir:
134+
tmpdir = Path(tmpdir)
135+
tmpfile = tmpdir / "usps.bz2"
136+
urlretrieve(
137+
url,
138+
tmpfile,
139+
reporthook=reporthook,
140+
)
141+
142+
# For fun we can check the integrity of the downloaded file
143+
if not self.check_integrity(tmpfile, checksum):
144+
errmsg = (
145+
"The checksum of the downloaded file does "
146+
"not match the expected checksum."
147+
)
148+
raise ValueError(errmsg)
149+
150+
# Load the dataset and save it as an HDF5 file
151+
with bz2.open(tmpfile) as fp:
152+
raw = [line.decode().split() for line in fp.readlines()]
153+
154+
tmp = [[x.split(":")[-1] for x in data[1:]] for data in raw]
155+
156+
imgs = np.asarray(tmp, dtype=np.float32)
157+
imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)
158+
159+
targets = [int(d[0]) - 1 for d in raw]
160+
161+
with h5.File(self.filepath, "a") as f:
162+
f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32)
163+
f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32)
164+
165+
@staticmethod
166+
def check_integrity(filepath, checksum):
167+
"""Check the integrity of the USPS dataset file."""
168+
169+
file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()
170+
171+
return checksum == file_hash
172+
84173
def _index(self):
85-
with h5.File(self.path, "r") as f:
174+
with h5.File(self.filepath, "r") as f:
86175
labels = f[self.mode]["target"][:]
87176

88177
# Get indices of samples with labels 0-6
@@ -92,8 +181,8 @@ def _index(self):
92181
return idx
93182

94183
def _load_data(self, idx):
95-
with h5.File(self.path, "r") as f:
96-
data = f[self.mode]["data"][idx]
184+
with h5.File(self.filepath, "r") as f:
185+
data = f[self.mode]["data"][idx].astype(np.uint8)
97186
label = f[self.mode]["target"][idx]
98187

99188
return data, label
@@ -103,16 +192,33 @@ def __len__(self):
103192

104193
def __getitem__(self, idx):
105194
data, target = self._load_data(self.idx[idx])
106-
107-
data = data.reshape(16, 16)
195+
data = Image.fromarray(data, mode="L")
108196

109197
# one hot encode the target
110198
target = np.eye(self.num_classes, dtype=np.float32)[target]
111199

112-
# Add channel dimension
113-
data = np.expand_dims(data, axis=0)
114-
115200
if self.transform:
116201
data = self.transform(data)
117202

118203
return data, target
204+
205+
206+
if __name__ == "__main__":
207+
# Example usage:
208+
transform = transforms.Compose(
209+
[
210+
transforms.Resize((16, 16)),
211+
transforms.ToTensor(),
212+
]
213+
)
214+
215+
dataset = USPSDataset0_6(
216+
data_path="data",
217+
train=True,
218+
download=False,
219+
transform=transform,
220+
)
221+
print(len(dataset))
222+
data, target = dataset[0]
223+
print(data.shape)
224+
print(target)

utils/load_data.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from torch.utils.data import Dataset
22

3-
from .dataloaders import (MNISTDataset0_3, USPSDataset0_6,
4-
USPSH5_Digit_7_9_Dataset)
3+
from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset
54

65

76
def load_data(dataset: str, *args, **kwargs) -> Dataset:

0 commit comments

Comments
 (0)