Commit f2e14c4

Merge pull request #60 from SFI-Visual-Intelligence/christian/train-val-split
Implementing @salomaestro's changes to the downloading process.
2 parents efc78f3 + 601caca commit f2e14c4

10 files changed: +334 −323 lines

main.py

Lines changed: 7 additions & 31 deletions

@@ -1,13 +1,11 @@
-from pathlib import Path
-
 import numpy as np
 import torch as th
 import torch.nn as nn
+import wandb
 from torch.utils.data import DataLoader
 from torchvision import transforms
 from tqdm import tqdm

-import wandb
 from utils import MetricWrapper, createfolders, get_args, load_data, load_model


@@ -32,42 +30,20 @@ def main():
     device = args.device

     if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
-        augmentations = transforms.Compose(
+        transform = transforms.Compose(
             [
                 transforms.Resize((16, 16)),
                 transforms.ToTensor(),
             ]
         )
     else:
-        augmentations = transforms.Compose([transforms.ToTensor()])
+        transform = transforms.Compose([transforms.ToTensor()])

-    # Dataset
-    assert (
-        args.validation_split_percentage < 1.0 and args.validation_split_percentage > 0
-    ), "Validation split should be in interval (0,1)"
-    traindata = load_data(
-        args.dataset,
-        split="train",
-        split_percentage=args.validation_split_percentage,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
-    )
-    validata = load_data(
-        args.dataset,
-        split="validation",
-        split_percentage=args.validation_split_percentage,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
-    )
-    testdata = load_data(
+    traindata, validata, testdata = load_data(
         args.dataset,
-        split="test",
-        split_percentage=args.validation_split_percentage,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
+        data_dir=args.datafolder,
+        transform=transform,
+        val_size=args.val_size,
     )

     metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
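
For reference, a minimal sketch of how the consolidated call is intended to be used from a training script; the DataLoader wiring and batch size below are illustrative assumptions, not part of this diff:

    # Sketch only: assumes load_data returns three map-style datasets.
    from torch.utils.data import DataLoader
    from torchvision import transforms

    from utils import get_args, load_data

    args = get_args()
    transform = transforms.Compose([transforms.ToTensor()])

    traindata, validata, testdata = load_data(
        args.dataset,
        data_dir=args.datafolder,
        transform=transform,
        val_size=args.val_size,
    )

    # Illustrative loaders; the batch size is a placeholder.
    trainloader = DataLoader(traindata, batch_size=64, shuffle=True)
    valloader = DataLoader(validata, batch_size=64, shuffle=False)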

tests/test_dataloaders.py

Lines changed: 11 additions & 4 deletions

@@ -17,18 +17,25 @@ def test_uspsdataset0_6():

     # Create a h5 file
     with h5py.File(tf, "w") as f:
+        targets = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
+        indices = np.arange(len(targets))
         # Populate the file with data
         f["train/data"] = np.random.rand(10, 16 * 16)
-        f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
+        f["train/target"] = targets

     trans = transforms.Compose(
         [
-            transforms.Resize((16, 16)),  # At least for USPS
+            transforms.Resize((16, 16)),
             transforms.ToTensor(),
         ]
     )
-    dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
+    dataset = USPSDataset0_6(
+        data_path=tempdir,
+        sample_ids=indices,
+        train=True,
+        transform=trans,
+    )
     assert len(dataset) == 10
     data, target = dataset[0]
     assert data.shape == (1, 16, 16)
-    assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
+    assert target == 6
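
The new sample_ids argument is what lets a train/validation split hand each dataset a disjoint subset of indices. A hedged sketch of such a split; the val_size arithmetic and the data path are illustrative assumptions, not taken from the repository:

    import numpy as np
    from torchvision import transforms

    from utils.dataloaders import USPSDataset0_6

    trans = transforms.Compose([transforms.Resize((16, 16)), transforms.ToTensor()])

    # Hypothetical index split: the last val_size fraction of the ids goes to validation.
    val_size = 0.2
    ids = np.arange(10)
    cut = int(len(ids) * (1 - val_size))

    train_set = USPSDataset0_6(data_path="data", sample_ids=ids[:cut], train=True, transform=trans)
    val_set = USPSDataset0_6(data_path="data", sample_ids=ids[cut:], train=True, transform=trans)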

tests/test_models.py

Lines changed: 0 additions & 1 deletion

@@ -32,4 +32,3 @@ def test_jan_model(image_shape, num_classes):
     y = model(x)

     assert y.shape == (n, num_classes), f"Shape: {y.shape}"
-

utils/arg_parser.py

Lines changed: 1 addition & 7 deletions

@@ -33,12 +33,6 @@ def get_args():
         help="Whether model should be saved or not.",
     )

-    parser.add_argument(
-        "--download-data",
-        action="store_true",
-        help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
-    )
-
     # Data/Model specific values
     parser.add_argument(
         "--modelname",
@@ -55,7 +49,7 @@ def get_args():
         help="Which dataset to train the model on.",
     )
     parser.add_argument(
-        "--validation_split_percentage",
+        "--val_size",
         type=float,
         default=0.2,
         help="Percentage of training dataset to be used as validation dataset - must be within (0,1).",
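
With the old in-line assertion removed from main.py, a small sketch of an equivalent guard; the helper name is hypothetical and not part of this commit:

    def validate_val_size(val_size: float) -> float:
        # Hypothetical helper mirroring the removed assertion: the split must lie in (0, 1).
        if not 0.0 < val_size < 1.0:
            raise ValueError("val_size must be in the open interval (0, 1)")
        return val_size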

utils/dataloaders/__init__.py

Lines changed: 7 additions & 1 deletion

@@ -1,5 +1,11 @@
-__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3"]
+__all__ = [
+    "USPSDataset0_6",
+    "USPSH5_Digit_7_9_Dataset",
+    "MNISTDataset0_3",
+    "Downloader",
+]

+from .download import Downloader
 from .mnist_0_3 import MNISTDataset0_3
 from .usps_0_6 import USPSDataset0_6
 from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset
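
With the re-export in place, the downloader can be imported alongside the dataset classes, e.g.:

    from utils.dataloaders import Downloader, MNISTDataset0_3, USPSDataset0_6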

utils/dataloaders/datasources.py

Lines changed: 23 additions & 0 deletions

@@ -17,3 +17,26 @@
         "8ea070ee2aca1ac39742fdd1ef5ed118",
     ],
 }
+
+MNIST_SOURCE = {
+    "train_images": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
+        "train-images-idx3-ubyte",
+        None,
+    ],
+    "train_labels": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
+        "train-labels-idx1-ubyte",
+        None,
+    ],
+    "test_images": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
+        "t10k-images-idx3-ubyte",
+        None,
+    ],
+    "test_labels": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
+        "t10k-labels-idx1-ubyte",
+        None,
+    ],
+}
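
Each source entry is a [url, local_filename, md5] triple, with None standing in for a missing checksum. For illustration, the entries unpack as:

    from utils.dataloaders.datasources import MNIST_SOURCE, USPS_SOURCE

    url, local_name, md5 = MNIST_SOURCE["train_images"]  # md5 is None for the MNIST mirror
    usps_url, _, usps_md5 = USPS_SOURCE["train"]          # USPS entries carry an MD5 checksum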

utils/dataloaders/download.py (new file)

Lines changed: 183 additions & 0 deletions

import bz2
import gzip
import hashlib
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from urllib.request import urlretrieve

import h5py as h5
import numpy as np

from .datasources import MNIST_SOURCE, USPS_SOURCE


class Downloader:
    """
    Class to download and load the USPS dataset.

    Methods
    -------
    mnist(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
        Download the MNIST dataset and save it as an HDF5 file to `data_dir`.
    svhn(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
        Download the SVHN dataset and save it as an HDF5 file to `data_dir`.
    usps(data_dir: Path) -> tuple[np.ndarray, np.ndarray]
        Download the USPS dataset and save it as an HDF5 file to `data_dir`.

    Raises
    ------
    NotImplementedError
        If the download method is not implemented for the dataset.

    Examples
    --------
    >>> from pathlib import Path
    >>> from utils import Downloader
    >>> dir = Path('tmp')
    >>> dir.mkdir(exist_ok=True)
    >>> train, test = Downloader().usps(dir)
    """

    def mnist(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
        def _chech_is_downloaded(path: Path) -> bool:
            path = path / "MNIST"
            if path.exists():
                required_files = [MNIST_SOURCE[key][1] for key in MNIST_SOURCE.keys()]
                if all([(path / file).exists() for file in required_files]):
                    print("MNIST Dataset already downloaded.")
                    return True
                else:
                    return False
            else:
                path.mkdir(parents=True, exist_ok=True)
                return False

        def _download_data(path: Path) -> None:
            urls = {key: MNIST_SOURCE[key][0] for key in MNIST_SOURCE.keys()}

            for name, url in urls.items():
                file_path = os.path.join(path, url.split("/")[-1])
                if not os.path.exists(
                    file_path.replace(".gz", "")
                ):  # Avoid re-downloading
                    urlretrieve(url, file_path)
                    with gzip.open(file_path, "rb") as f_in:
                        with open(file_path.replace(".gz", ""), "wb") as f_out:
                            f_out.write(f_in.read())
                    os.remove(file_path)  # Remove compressed file

        def _get_labels(path: Path) -> np.ndarray:
            with open(path, "rb") as f:
                data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
            return data

        if not _chech_is_downloaded(data_dir):
            _download_data(data_dir)

        train_labels_path = data_dir / "MNIST" / MNIST_SOURCE["train_labels"][1]
        test_labels_path = data_dir / "MNIST" / MNIST_SOURCE["test_labels"][1]

        train_labels = _get_labels(train_labels_path)
        test_labels = _get_labels(test_labels_path)

        return train_labels, test_labels

    def svhn(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
        raise NotImplementedError("SVHN download not implemented yet")

    def usps(self, data_dir: Path) -> tuple[np.ndarray, np.ndarray]:
        """
        Download the USPS dataset and save it as an HDF5 file to `data_dir/usps.h5`.
        """

        def already_downloaded(path):
            if not path.exists() or not path.is_file():
                return False

            with h5.File(path, "r") as f:
                return "train" in f and "test" in f

        filename = data_dir / "usps.h5"

        if already_downloaded(filename):
            with h5.File(filename, "r") as f:
                return f["train"]["target"][:], f["test"]["target"][:]

        url_train, _, train_md5 = USPS_SOURCE["train"]
        url_test, _, test_md5 = USPS_SOURCE["test"]

        # Using temporary directory ensures temporary files are deleted after use
        with TemporaryDirectory() as tmp_dir:
            train_path = Path(tmp_dir) / "train"
            test_path = Path(tmp_dir) / "test"

            # Download the dataset and report the progress
            urlretrieve(url_train, train_path, reporthook=self.__reporthook)
            self.__check_integrity(train_path, train_md5)
            train_targets = self.__extract_usps(train_path, filename, "train")

            urlretrieve(url_test, test_path, reporthook=self.__reporthook)
            self.__check_integrity(test_path, test_md5)
            test_targets = self.__extract_usps(test_path, filename, "test")

        return train_targets, test_targets

    def __extract_usps(self, src: Path, dest: Path, mode: str):
        # Load the dataset and save it as an HDF5 file
        with bz2.open(src) as fp:
            raw = [line.decode().split() for line in fp.readlines()]

        tmp = [[x.split(":")[-1] for x in data[1:]] for data in raw]

        imgs = np.asarray(tmp, dtype=np.float32)
        imgs = ((imgs + 1) / 2 * 255).astype(dtype=np.uint8)

        targets = [int(d[0]) - 1 for d in raw]

        with h5.File(dest, "a") as f:
            f.create_dataset(f"{mode}/data", data=imgs, dtype=np.float32)
            f.create_dataset(f"{mode}/target", data=targets, dtype=np.int32)

        return targets

    @staticmethod
    def __reporthook(blocknum, blocksize, totalsize):
        """
        Use this function to report download progress
        for the urllib.request.urlretrieve function.
        """

        denom = 1024 * 1024
        readsofar = blocknum * blocksize

        if totalsize > 0:
            percent = readsofar * 1e2 / totalsize
            s = f"\r{int(percent):^3}% {readsofar / denom:.2f} of {totalsize / denom:.2f} MB"
            print(s, end="", flush=True)
            if readsofar >= totalsize:
                print()

    @staticmethod
    def __check_integrity(filepath, checksum):
        """Check the integrity of the USPS dataset file.

        Args
        ----
        filepath : pathlib.Path
            Path to the USPS dataset file.
        checksum : str
            MD5 checksum of the dataset file.

        Returns
        -------
        bool
            True if the checksum of the file matches the expected checksum, False otherwise
        """

        file_hash = hashlib.md5(filepath.read_bytes()).hexdigest()

        if not checksum == file_hash:
            raise ValueError(
                f"File integrity check failed. Expected {checksum}, got {file_hash}"
            )
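
Taken together with the docstring example above, a short hedged usage sketch; the data directory is a placeholder:

    from pathlib import Path

    from utils.dataloaders import Downloader

    data_dir = Path("data")  # placeholder location
    data_dir.mkdir(exist_ok=True)

    downloader = Downloader()
    usps_train, usps_test = downloader.usps(data_dir)      # writes data/usps.h5, returns the target labels
    mnist_train, mnist_test = downloader.mnist(data_dir)   # downloads the raw idx files, returns the label arrays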
