
Commit 4071181

Merge branch 'Jan-dataloader' into Jan-metrics
2 parents 177258b + b7bffa3 commit 4071181

File tree: 18 files changed, +537 -357 lines changed


.gitignore

Lines changed: 6 additions & 5 deletions
@@ -1,15 +1,16 @@
 __pycache__/
 .ipynb_checkpoints/
-Data/
-Results/
-Experiments/
+Data/*
+Results/*
+Experiments/*
 _build/
-bin/
-wandb/
+bin/*
+wandb/*
 wandb_api.py

 #Magnus specific
 docker/*
+job*

 # Byte-compiled / optimized / DLL files
 __pycache__/
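Switching `Data/` to `Data/*` (and likewise for `Results/`, `Experiments/`, `bin/`, `wandb/`) changes what is ignored: `Data/` ignores the directory itself, and git never descends into an ignored directory, so nothing inside it can ever be re-included; `Data/*` ignores only the contents, which leaves room for negation patterns. A minimal sketch of what this enables (the `.gitkeep` lines are hypothetical, not part of this commit):

# hypothetical: keep the empty folder tracked while ignoring its contents
Data/*
!Data/.gitkeep
Results/*
!Results/.gitkeep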

environment.yml

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@ dependencies:
   - ruff
   - scalene
   - tqdm
+  - scipy
   - pip:
       - torch
       - torchvision

main.py

Lines changed: 44 additions & 35 deletions
@@ -1,13 +1,11 @@
-from pathlib import Path
-
 import numpy as np
 import torch as th
 import torch.nn as nn
-import wandb
 from torch.utils.data import DataLoader
 from torchvision import transforms
 from tqdm import tqdm

+import wandb
 from utils import MetricWrapper, createfolders, get_args, load_data, load_model


@@ -27,35 +25,25 @@ def main():

     args = get_args()

-
     createfolders(args.datafolder, args.resultfolder, args.modelfolder)

     device = args.device

     if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
-        augmentations = transforms.Compose(
+        transform = transforms.Compose(
             [
                 transforms.Resize((16, 16)),
                 transforms.ToTensor(),
             ]
         )
     else:
-        augmentations = transforms.Compose([transforms.ToTensor()])
+        transform = transforms.Compose([transforms.ToTensor()])

-    # Dataset
-    traindata = load_data(
-        args.dataset,
-        train=True,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
-    )
-    validata = load_data(
+    traindata, validata, testdata = load_data(
         args.dataset,
-        train=False,
-        data_path=args.datafolder,
-        download=args.download_data,
-        transform=augmentations,
+        data_dir=args.datafolder,
+        transform=transform,
+        val_size=args.val_size,
     )

     metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes, macro_averaging=args.macro_averaging)
@@ -83,6 +71,9 @@ def main():
     valiloader = DataLoader(
         validata, batch_size=args.batchsize, shuffle=False, pin_memory=True
     )
+    testloader = DataLoader(
+        testdata, batch_size=args.batchsize, shuffle=False, pin_memory=True
+    )

     criterion = nn.CrossEntropyLoss()
     optimizer = th.optim.Adam(model.parameters(), lr=args.learning_rate)
@@ -107,18 +98,22 @@ def main():
             optimizer.step()
             optimizer.zero_grad(set_to_none=True)

-            preds = th.argmax(logits, dim=1)
-            metrics(y, preds)
+            metrics(y, logits)

             break
         print(metrics.accumulate())
         print("Dry run completed successfully.")
-        exit(0)
-
-    wandb.login(key=WANDB_API)
-    wandb.init(entity="ColabCode", project="Jan", tags=[args.modelname, args.dataset])
+        exit()
+
+    # wandb.login(key=WANDB_API)
+    wandb.init(
+        entity="ColabCode-org",
+        # entity="FYS-8805 Exam",
+        project="Test",
+        tags=[args.modelname, args.dataset]
+    )
     wandb.watch(model)
-
+    exit()
     for epoch in range(args.epoch):
         # Training loop start
         trainingloss = []
@@ -134,36 +129,50 @@
             optimizer.zero_grad(set_to_none=True)
             trainingloss.append(loss.item())

-            preds = th.argmax(logits, dim=1)
-            metrics(y, preds)
+            metrics(y, logits)

         wandb.log(metrics.accumulate(str_prefix="Train "))
         metrics.reset()

-        evalloss = []
-        # Eval loop start
+        valloss = []
+        # Validation loop start
         model.eval()
         with th.no_grad():
             for x, y in tqdm(valiloader, desc="Validation"):
                 x, y = x.to(device), y.to(device)
                 logits = model.forward(x)
                 loss = criterion(logits, y)
-                evalloss.append(loss.item())
+                valloss.append(loss.item())

-                preds = th.argmax(logits, dim=1)
-                metrics(y, preds)
+                metrics(y, logits)

-        wandb.log(metrics.accumulate(str_prefix="Evaluation "))
+        wandb.log(metrics.accumulate(str_prefix="Validation "))
         metrics.reset()

         wandb.log(
             {
                 "Epoch": epoch,
                 "Train loss": np.mean(trainingloss),
-                "Evaluation Loss": np.mean(evalloss),
+                "Validation loss": np.mean(valloss),
             }
         )

+    testloss = []
+    model.eval()
+    with th.no_grad():
+        for x, y in tqdm(testloader, desc="Testing"):
+            x, y = x.to(device), y.to(device)
+            logits = model.forward(x)
+            loss = criterion(logits, y)
+            testloss.append(loss.item())
+
+            preds = th.argmax(logits, dim=1)
+            metrics(y, preds)
+
+    wandb.log(metrics.accumulate(str_prefix="Test "))
+    metrics.reset()
+    wandb.log({"Test loss": np.mean(testloss)})
+

 if __name__ == "__main__":
     main()
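The headline change is the `load_data` contract: two calls returning bare train/validation datasets become one call returning `(traindata, validata, testdata)`, with a validation split carved out of the training data according to `--val_size`; the metric wrapper now also receives raw `logits` instead of argmax'd predictions. Note as well that the newly added `exit()` after `wandb.watch(model)` stops execution before the training loop ever runs, presumably a temporary debugging guard. Below is a minimal, self-contained sketch of the three-way split the new signature implies; names other than `val_size` are illustrative, and the repository's actual `load_data` appears to pass explicit `sample_ids` into the dataset constructors instead (see the updated test below):

import numpy as np
from torch.utils.data import Dataset, Subset

def split_train_val(train_dataset: Dataset, val_size: float, seed: int = 0):
    # Illustrative only; utils.load_data may split differently.
    assert 0.0 < val_size < 1.0, "val_size must be within (0, 1)"
    ids = np.random.default_rng(seed).permutation(len(train_dataset))
    n_val = int(len(ids) * val_size)
    # Subsets share the underlying storage; only the index lists differ.
    return (
        Subset(train_dataset, ids[n_val:].tolist()),
        Subset(train_dataset, ids[:n_val].tolist()),
    )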

tests/test_dataloaders.py

Lines changed: 11 additions & 4 deletions
@@ -17,18 +17,25 @@ def test_uspsdataset0_6():

     # Create a h5 file
     with h5py.File(tf, "w") as f:
+        targets = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
+        indices = np.arange(len(targets))
         # Populate the file with data
         f["train/data"] = np.random.rand(10, 16 * 16)
-        f["train/target"] = np.array([6, 5, 4, 3, 2, 1, 0, 0, 0, 0])
+        f["train/target"] = targets

     trans = transforms.Compose(
         [
-            transforms.Resize((16, 16)),  # At least for USPS
+            transforms.Resize((16, 16)),
             transforms.ToTensor(),
         ]
     )
-    dataset = USPSDataset0_6(data_path=tempdir, train=True, transform=trans)
+    dataset = USPSDataset0_6(
+        data_path=tempdir,
+        sample_ids=indices,
+        train=True,
+        transform=trans,
+    )
     assert len(dataset) == 10
     data, target = dataset[0]
     assert data.shape == (1, 16, 16)
-    assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
+    assert target == 6
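Two API shifts show up in this test: `USPSDataset0_6` now takes an explicit `sample_ids` array selecting which rows of the h5 file belong to the dataset (which is what makes an index-based train/validation split possible), and targets come back as plain integer class labels (`assert target == 6`) rather than one-hot vectors, matching the integer targets `nn.CrossEntropyLoss` expects in main.py. A hedged sketch of the indexing contract this implies; the class below is hypothetical, not the repository's implementation:

class SampleIdDataset:
    # Hypothetical illustration of sample_ids-based indexing.
    def __init__(self, data, targets, sample_ids, transform=None):
        self.data, self.targets = data, targets
        self.sample_ids = sample_ids        # explicit subset of storage rows
        self.transform = transform

    def __len__(self):
        return len(self.sample_ids)         # dataset size == subset size

    def __getitem__(self, i):
        row = self.sample_ids[i]            # map subset index -> storage row
        x = self.data[row].reshape(16, 16)
        if self.transform is not None:
            x = self.transform(x)
        return x, int(self.targets[row])    # integer label, not one-hot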

tests/test_models.py

Lines changed: 0 additions & 1 deletion
@@ -32,4 +32,3 @@ def test_jan_model(image_shape, num_classes):
     y = model(x)

     assert y.shape == (n, num_classes), f"Shape: {y.shape}"
-
utils/arg_parser.py

Lines changed: 31 additions & 10 deletions
@@ -33,28 +33,33 @@ def get_args():
         help="Whether model should be saved or not.",
     )

-    parser.add_argument(
-        "--download-data",
-        action="store_true",
-        help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
-    )
-
     # Data/Model specific values
     parser.add_argument(
         "--modelname",
         type=str,
         default="MagnusModel",
-        choices=["MagnusModel", "ChristianModel", "SolveigModel", "JanModel"],
+        choices=[
+            "MagnusModel",
+            "ChristianModel",
+            "SolveigModel",
+            "JanModel",
+            "JohanModel",
+        ],
         help="Model which to be trained on",
     )
     parser.add_argument(
         "--dataset",
         type=str,
         default="svhn",
-        choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"],
+        choices=["svhn", "usps_0-6", "usps_7-9", "mnist_0-3", "mnist_4-9"],
         help="Which dataset to train the model on.",
     )
-
+    parser.add_argument(
+        "--val_size",
+        type=float,
+        default=0.2,
+        help="Percentage of training dataset to be used as validation dataset - must be within (0,1).",
+    )
     parser.add_argument(
         "--metric",
         type=str,
@@ -70,6 +75,16 @@ def get_args():
     )


+    parser.add_argument("--imagesize", type=int, default=28, help="Imagesize")
+
+    parser.add_argument(
+        "--nr_channels",
+        type=int,
+        default=1,
+        choices=[1, 3],
+        help="Number of image channels",
+    )
+
     # Training specific values
     parser.add_argument(
         "--epoch",
@@ -101,4 +116,10 @@ def get_args():
         action="store_true",
         help="If the flag is included, the code will not run the training loop.",
     )
-    return parser.parse_args()
+    args = parser.parse_args()
+
+    assert args.epoch > 0, "Epoch should be a positive integer."
+    assert args.learning_rate > 0, "Learning rate should be a positive float."
+    assert args.batchsize > 0, "Batch size should be a positive integer."
+
+    return args
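The new asserts guard `--epoch`, `--learning_rate`, and `--batchsize`, but the (0,1) constraint stated in `--val_size`'s help text (which, despite saying "percentage", takes a fraction such as `0.2`) is not checked. A check in the same style would be a one-liner; this is hypothetical, not part of this commit:

assert 0.0 < args.val_size < 1.0, "val_size should be a float within (0, 1)."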

utils/dataloaders/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -1,5 +1,13 @@
-__all__ = ["USPSDataset0_6", "USPSH5_Digit_7_9_Dataset", "MNISTDataset0_3"]
+__all__ = [
+    "USPSDataset0_6",
+    "USPSH5_Digit_7_9_Dataset",
+    "MNISTDataset0_3",
+    "Downloader",
+    "SVHNDataset",
+]

+from .download import Downloader
 from .mnist_0_3 import MNISTDataset0_3
+from .svhn import SVHNDataset
 from .usps_0_6 import USPSDataset0_6
 from .uspsh5_7_9 import USPSH5_Digit_7_9_Dataset

utils/dataloaders/datasources.py

Lines changed: 23 additions & 0 deletions
@@ -17,3 +17,26 @@
         "8ea070ee2aca1ac39742fdd1ef5ed118",
     ],
 }
+
+MNIST_SOURCE = {
+    "train_images": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz",
+        "train-images-idx3-ubyte",
+        None,
+    ],
+    "train_labels": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz",
+        "train-labels-idx1-ubyte",
+        None,
+    ],
+    "test_images": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz",
+        "t10k-images-idx3-ubyte",
+        None,
+    ],
+    "test_labels": [
+        "https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz",
+        "t10k-labels-idx1-ubyte",
+        None,
+    ],
+}
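Each `MNIST_SOURCE` entry appears to follow the same triple layout as the existing source dict above it: download URL, the filename to store the unpacked idx file under, and an optional checksum (the `8ea070…` string in the earlier entry looks like an MD5; the MNIST entries use `None`). A hedged sketch of how one entry might be consumed; the repository's actual `Downloader` in utils/dataloaders/download.py may work quite differently:

import gzip
import hashlib
import urllib.request
from pathlib import Path

def fetch(entry, data_dir):
    # entry layout assumed from this file: [url, out_name, md5_or_None]
    url, out_name, md5 = entry
    out = Path(data_dir) / out_name
    if out.exists():
        return out                                   # already downloaded
    raw, _ = urllib.request.urlretrieve(url)         # temp .gz archive
    blob = Path(raw).read_bytes()
    if md5 is not None:                              # optional integrity check
        assert hashlib.md5(blob).hexdigest() == md5, f"checksum mismatch: {url}"
    out.write_bytes(gzip.decompress(blob))           # unpack the idx file
    return out

For example, `fetch(MNIST_SOURCE["train_images"], "Data/")` would yield `Data/train-images-idx3-ubyte`.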
