Skip to content

Commit 9baa17e

Browse files
committed
load_data - changed to accommodate train/val/test split, added test loop
1 parent b93ee66 commit 9baa17e

File tree

3 files changed

+61
-16
lines changed

3 files changed

+61
-16
lines changed

main.py

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,27 @@ def main():
4343
augmentations = transforms.Compose([transforms.ToTensor()])
4444

4545
# Dataset
46+
assert args.validation_split_percentage < 1.0 and args.validation_split_percentage > 0, "Validation split should be in interval (0,1)"
4647
traindata = load_data(
4748
args.dataset,
48-
train=True,
49+
split="train",
50+
split_percentage=args.validation_split_percentage,
4951
data_path=args.datafolder,
5052
download=args.download_data,
5153
transform=augmentations,
5254
)
5355
validata = load_data(
5456
args.dataset,
55-
train=False,
57+
split="validation",
58+
split_percentage=args.validation_split_percentage,
59+
data_path=args.datafolder,
60+
download=args.download_data,
61+
transform=augmentations,
62+
)
63+
testdata = load_data(
64+
args.dataset,
65+
split="test",
66+
split_percentage=args.validation_split_percentage,
5667
data_path=args.datafolder,
5768
download=args.download_data,
5869
transform=augmentations,
@@ -83,6 +94,9 @@ def main():
8394
valiloader = DataLoader(
8495
validata, batch_size=args.batchsize, shuffle=False, pin_memory=True
8596
)
97+
testloader = DataLoader(
98+
testdata, batch_size=args.batchsize, shuffle=False, pin_memory=True
99+
)
86100

87101
criterion = nn.CrossEntropyLoss()
88102
optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
@@ -140,30 +154,45 @@ def main():
140154
wandb.log(metrics.accumulate(str_prefix="Train "))
141155
metrics.reset()
142156

143-
evalloss = []
144-
# Eval loop start
157+
valloss = []
158+
# Validation loop start
145159
model.eval()
146160
with th.no_grad():
147161
for x, y in tqdm(valiloader, desc="Validation"):
148162
x, y = x.to(device), y.to(device)
149163
logits = model.forward(x)
150164
loss = criterion(logits, y)
151-
evalloss.append(loss.item())
165+
valloss.append(loss.item())
152166

153167
preds = th.argmax(logits, dim=1)
154168
metrics(y, preds)
155169

156-
wandb.log(metrics.accumulate(str_prefix="Evaluation "))
170+
wandb.log(metrics.accumulate(str_prefix="Validation "))
157171
metrics.reset()
158172

159173
wandb.log(
160174
{
161175
"Epoch": epoch,
162176
"Train loss": np.mean(trainingloss),
163-
"Evaluation Loss": np.mean(evalloss),
177+
"Validation loss": np.mean(valloss),
164178
}
165179
)
180+
181+
testloss = []
182+
model.eval()
183+
with th.no_grad():
184+
for x, y in tqdm(testloader, desc="Testing"):
185+
x, y = x.to(device), y.to(device)
186+
logits = model.forward(x)
187+
loss = criterion(logits, y)
188+
testloss.append(loss.item())
189+
190+
preds = th.argmax(logits, dim=1)
191+
metrics(y, preds)
166192

193+
wandb.log(metrics.accumulate(str_prefix="Test "))
194+
metrics.reset()
195+
wandb.log({"Test loss": np.mean(testloss)})
167196

168197
if __name__ == "__main__":
169198
main()

utils/arg_parser.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,12 @@ def get_args():
5454
choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"],
5555
help="Which dataset to train the model on.",
5656
)
57-
57+
parser.add_argument(
58+
"--validation_split_percentage",
59+
type=float,
60+
default=0.2,
61+
help="Percentage of training dataset to be used as validation dataset - must be within (0,1).",
62+
)
5863
parser.add_argument(
5964
"--metric",
6065
type=str,

utils/dataloaders/mnist_0_3.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,10 @@
22
import os
33
import urllib.request
44
from pathlib import Path
5+
import torch
56

67
import numpy as np
7-
from torch.utils.data import Dataset
8+
from torch.utils.data import Dataset, random_split
89

910

1011
class MNISTDataset0_3(Dataset):
@@ -59,20 +60,25 @@ class MNISTDataset0_3(Dataset):
5960

6061
def __init__(
6162
self,
63+
split: str,
64+
split_percentage: float,
6265
data_path: Path,
63-
train: bool = False,
64-
transform=None,
6566
download: bool = False,
67+
transform=None,
6668
):
6769
super().__init__()
6870

6971
self.data_path = data_path
7072
self.mnist_path = self.data_path / "MNIST"
71-
self.train = train
73+
self.split = split
74+
self.split_percentage = split_percentage
7275
self.transform = transform
7376
self.download = download
7477
self.num_classes = 4
7578

79+
if self.split == "train" or self.split == "validation":
80+
train = True # used to decide whether to load training or test dataset
81+
7682
if not self.download and not self._chech_is_downloaded():
7783
raise ValueError(
7884
"Data not found. Set --download-data=True to download the data."
@@ -87,13 +93,18 @@ def __init__(
8793
"train-labels-idx1-ubyte" if train else "t10k-labels-idx1-ubyte"
8894
)
8995

90-
labels = self._parse_labels(train=self.train)
91-
96+
labels = self._parse_labels()
97+
9298
self.idx = np.where(labels < 4)[0]
93-
99+
100+
if self.split != "test":
101+
generator1 = torch.Generator().manual_seed(42)
102+
tr, val = random_split(self.idx, [1-self.split_percentage, self.split_percentage], generator=generator1)
103+
self.idx = tr if self.split == "train" else val
104+
94105
self.length = len(self.idx)
95106

96-
def _parse_labels(self, train):
107+
def _parse_labels(self):
97108
with open(self.labels_path, "rb") as f:
98109
data = np.frombuffer(f.read(), dtype=np.uint8, offset=8)
99110
return data

0 commit comments

Comments
 (0)