
Commit 34539b3

load_data now downloads the data, splits it, and returns all splits
1 parent a9e2cad commit 34539b3

File tree

2 files changed (main.py, utils/load_data.py): +65 -41 lines changed


main.py

Lines changed: 5 additions & 27 deletions
@@ -3,11 +3,11 @@
 import numpy as np
 import torch as th
 import torch.nn as nn
+import wandb
 from torch.utils.data import DataLoader
 from torchvision import transforms
 from tqdm import tqdm

-import wandb
 from utils import MetricWrapper, createfolders, get_args, load_data, load_model


@@ -32,42 +32,20 @@ def main():
     device = args.device

     if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
-        augmentations = transforms.Compose(
+        transform = transforms.Compose(
             [
                 transforms.Resize((16, 16)),
                 transforms.ToTensor(),
             ]
         )
     else:
-        augmentations = transforms.Compose([transforms.ToTensor()])
+        transform = transforms.Compose([transforms.ToTensor()])

-    # Dataset
-    assert (
-        args.validation_split_percentage < 1.0 and args.validation_split_percentage > 0
-    ), "Validation split should be in interval (0,1)"
-    traindata = load_data(
-        args.dataset,
-        split="train",
-        split_percentage=args.validation_split_percentage,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
-    )
-    validata = load_data(
-        args.dataset,
-        split="validation",
-        split_percentage=args.validation_split_percentage,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
-    )
-    testdata = load_data(
+    traindata, validata, testdata = load_data(
         args.dataset,
-        split="test",
-        split_percentage=args.validation_split_percentage,
         data_path=args.datafolder,
+        transform=transform,
         download=args.download_data,
-        transform=augmentations,
     )

     metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
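For context, a minimal sketch of the new call pattern, not taken from the commit itself: the dataset name, folder, and transform mirror the arguments used above, while the DataLoader settings are illustrative assumptions.

# Hedged sketch: one load_data call now returns all three splits,
# replacing the previous per-split calls.
from torch.utils.data import DataLoader
from torchvision import transforms

from utils import load_data

transform = transforms.Compose([transforms.ToTensor()])

traindata, validata, testdata = load_data(
    "usps_0-6",            # matched case-insensitively inside load_data
    data_path="data",      # folder the raw files are read from / downloaded to
    transform=transform,
    download=True,         # fetch the files if they are missing
)

# Illustrative loader settings; only the train split is shuffled.
trainloader = DataLoader(traindata, batch_size=64, shuffle=True)
valloader = DataLoader(validata, batch_size=64, shuffle=False)
testloader = DataLoader(testdata, batch_size=64, shuffle=False)

Note that the old call site's assert that the split percentage lies in (0, 1) has no counterpart here: the split fraction is now read inside load_data as val_size, with a 0.1 default.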

utils/load_data.py

Lines changed: 60 additions & 14 deletions
@@ -1,11 +1,20 @@
-from torch.utils.data import Dataset
+from torch.utils.data import Dataset, random_split

-from .dataloaders import MNISTDataset0_3, USPSDataset0_6, USPSH5_Digit_7_9_Dataset
+from .dataloaders import (
+    Downloader,
+    MNISTDataset0_3,
+    USPSDataset0_6,
+    USPSH5_Digit_7_9_Dataset,
+)


-def load_data(dataset: str, *args, **kwargs) -> Dataset:
+def filter_labels(samples: list, wanted_labels: list) -> list:
+    return list(filter(lambda x: x in wanted_labels, samples))
+
+
+def load_data(dataset: str, *args, **kwargs) -> tuple:
     """
-    Load the dataset based on the dataset name.
+    Load the dataset based on the dataset name.

     Args
     ----
@@ -18,8 +27,8 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:

     Returns
     -------
-    dataset : torch.utils.data.Dataset
-        Dataset object.
+    tuple
+        Tuple of train, validation and test datasets.

     Raises
     ------
@@ -28,17 +37,54 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:

     Examples
     --------
-    >>> from utils import load_data
-    >>> dataset = load_data("usps_0-6", data_path="data", train=True, download=True)
-    >>> len(dataset)
-    5460
+    >>> from utils import load_data
+    >>> train, val, test = load_data("usps_0-6", data_path="data", train=True, download=True)
+    >>> len(train), len(val), len(test)
+    (4914, 546, 1782)
     """
+
     match dataset.lower():
         case "usps_0-6":
-            return USPSDataset0_6(*args, **kwargs)
-        case "mnist_0-3":
-            return MNISTDataset0_3(*args, **kwargs)
+            dataset = USPSDataset0_6
+            train_samples, test_samples = Downloader.usps(*args, **kwargs)
+            labels = range(7)
         case "usps_7-9":
-            return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
+            dataset = USPSH5_Digit_7_9_Dataset
+            train_samples, test_samples = Downloader.usps(*args, **kwargs)
+            labels = range(7, 10)
+        case "mnist_0-3":
+            dataset = MNISTDataset0_3
+            train_samples, test_samples = Downloader.mnist(*args, **kwargs)
+            labels = range(4)
         case _:
             raise NotImplementedError(f"Dataset: {dataset} not implemented.")
+
+    val_size = kwargs.get("val_size", 0.1)
+
+    train_samples = filter_labels(train_samples, labels)
+    test_samples = filter_labels(test_samples, labels)
+
+    train_samples, val_samples = random_split(train_samples, [1 - val_size, val_size])
+
+    train = dataset(
+        *args,
+        sample_ids=train_samples,
+        train=True,
+        **kwargs,
+    )
+
+    val = dataset(
+        *args,
+        sample_ids=val_samples,
+        train=True,
+        **kwargs,
+    )
+
+    test = dataset(
+        *args,
+        sample_ids=test_samples,
+        train=False,
+        **kwargs,
+    )
+
+    return train, val, test
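To see what the new split logic does in isolation, here is a self-contained sketch with toy data, not from the repository. random_split accepts a plain list, since it only needs len() and indexing, and returns Subset objects, which is what ends up in sample_ids; fractional lengths require PyTorch 1.13 or newer.

from torch.utils.data import random_split

def filter_labels(samples: list, wanted_labels: list) -> list:
    # Keep only the ids whose label falls in the wanted range.
    return list(filter(lambda x: x in wanted_labels, samples))

# Toy stand-in for the label stream returned by the downloader.
samples = [0, 3, 7, 9, 2, 8, 1, 4, 5, 6]

kept = filter_labels(samples, range(7))  # drops 7, 9, 8 -> 7 ids remain

# Fractional lengths (PyTorch 1.13+): 90% train, 10% validation.
# Sizes are floored and the remainder goes to the earliest split,
# so a tiny input can leave the validation split empty.
train_ids, val_ids = random_split(kept, [0.9, 0.1])
print(len(train_ids), len(val_ids))  # -> 7 0 for this toy input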
