Skip to content

Commit 4ab5bd7

Browse files
authored
Merge branch 'main' into christian/sphinx-autoapi
2 parents 2e202c9 + 891f09b commit 4ab5bd7

File tree

14 files changed

+287
-187
lines changed

14 files changed

+287
-187
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ Results/
55
Experiments/
66
_build/
77
bin/
8+
wandb/
9+
wandb_api.py
810

911
# Byte-compiled / optimized / DLL files
1012
__pycache__/

main.py

Lines changed: 38 additions & 113 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import argparse
21
from pathlib import Path
32

43
import numpy as np
@@ -9,7 +8,7 @@
98
from torchvision import transforms
109
from tqdm import tqdm
1110

12-
from utils import MetricWrapper, createfolders, load_data, load_model
11+
from utils import MetricWrapper, createfolders, get_args, load_data, load_model
1312

1413

1514
def main():
@@ -25,113 +24,21 @@ def main():
2524
------
2625
2726
"""
28-
parser = argparse.ArgumentParser(
29-
prog="",
30-
description="",
31-
epilog="",
32-
)
33-
# Structuture related values
34-
parser.add_argument(
35-
"--datafolder",
36-
type=Path,
37-
default="Data",
38-
help="Path to where data will be saved during training.",
39-
)
40-
parser.add_argument(
41-
"--resultfolder",
42-
type=Path,
43-
default="Results",
44-
help="Path to where results will be saved during evaluation.",
45-
)
46-
parser.add_argument(
47-
"--modelfolder",
48-
type=Path,
49-
default="Experiments",
50-
help="Path to where model weights will be saved at the end of training.",
51-
)
52-
parser.add_argument(
53-
"--savemodel",
54-
action="store_true",
55-
help="Whether model should be saved or not.",
56-
)
57-
58-
parser.add_argument(
59-
"--download-data",
60-
action="store_true",
61-
help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
62-
)
63-
64-
# Data/Model specific values
65-
parser.add_argument(
66-
"--modelname",
67-
type=str,
68-
default="MagnusModel",
69-
choices=["MagnusModel", "ChristianModel", "SolveigModel"],
70-
help="Model which to be trained on",
71-
)
72-
parser.add_argument(
73-
"--dataset",
74-
type=str,
75-
default="svhn",
76-
choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"],
77-
help="Which dataset to train the model on.",
78-
)
79-
80-
parser.add_argument(
81-
"--metric",
82-
type=str,
83-
default=["entropy"],
84-
choices=["entropy", "f1", "recall", "precision", "accuracy"],
85-
nargs="+",
86-
help="Which metric to use for evaluation",
87-
)
88-
89-
# Training specific values
90-
parser.add_argument(
91-
"--epoch",
92-
type=int,
93-
default=20,
94-
help="Amount of training epochs the model will do.",
95-
)
96-
parser.add_argument(
97-
"--learning_rate",
98-
type=float,
99-
default=0.001,
100-
help="Learning rate parameter for model training.",
101-
)
102-
parser.add_argument(
103-
"--batchsize",
104-
type=int,
105-
default=64,
106-
help="Amount of training images loaded in one go",
107-
)
108-
parser.add_argument(
109-
"--device",
110-
type=str,
111-
default="cpu",
112-
choices=["cuda", "cpu", "mps"],
113-
help="Which device to run the training on.",
114-
)
115-
parser.add_argument(
116-
"--dry_run",
117-
action="store_true",
118-
help="If true, the code will not run the training loop.",
119-
)
120-
121-
args = parser.parse_args()
27+
args = get_args()
12228

12329
createfolders(args.datafolder, args.resultfolder, args.modelfolder)
12430

12531
device = args.device
12632

127-
metrics = MetricWrapper(*args.metric)
128-
129-
augmentations = transforms.Compose(
130-
[
131-
transforms.Resize((16, 16)), # At least for USPS
132-
transforms.ToTensor(),
133-
]
134-
)
33+
if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
34+
augmentations = transforms.Compose(
35+
[
36+
transforms.Resize((16, 16)),
37+
transforms.ToTensor(),
38+
]
39+
)
40+
else:
41+
augmentations = transforms.Compose([transforms.ToTensor()])
13542

13643
# Dataset
13744
traindata = load_data(
@@ -149,6 +56,8 @@ def main():
14956
transform=augmentations,
15057
)
15158

59+
metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
60+
15261
# Find the shape of the data, if is 2D, add a channel dimension
15362
data_shape = traindata[0][0].shape
15463
if len(data_shape) == 2:
@@ -180,28 +89,32 @@ def main():
18089
if args.dry_run:
18190
dry_run_loader = DataLoader(
18291
traindata,
183-
batch_size=1,
92+
batch_size=20,
18493
shuffle=True,
18594
pin_memory=True,
18695
drop_last=True,
18796
)
18897

18998
for x, y in tqdm(dry_run_loader, desc="Dry run", total=1):
19099
x, y = x.to(device), y.to(device)
191-
pred = model.forward(x)
100+
logits = model.forward(x)
192101

193-
loss = criterion(y, pred)
102+
loss = criterion(logits, y)
194103
loss.backward()
195104

196105
optimizer.step()
197106
optimizer.zero_grad(set_to_none=True)
198107

199-
break
108+
preds = th.argmax(logits, dim=1)
109+
metrics(y, preds)
200110

111+
break
112+
print(metrics.__getmetrics__())
201113
print("Dry run completed successfully.")
202114
exit(0)
203115

204-
wandb.init(project="", tags=[])
116+
wandb.login(key=WANDB_API)
117+
wandb.init(entity="ColabCode", project="Jan", tags=[args.modelname, args.dataset])
205118
wandb.watch(model)
206119

207120
for epoch in range(args.epoch):
@@ -210,25 +123,37 @@ def main():
210123
model.train()
211124
for x, y in tqdm(trainloader, desc="Training"):
212125
x, y = x.to(device), y.to(device)
213-
pred = model.forward(x)
126+
logits = model.forward(x)
214127

215-
loss = criterion(y, pred)
128+
loss = criterion(logits, y)
216129
loss.backward()
217130

218131
optimizer.step()
219132
optimizer.zero_grad(set_to_none=True)
220133
trainingloss.append(loss.item())
221134

135+
preds = th.argmax(logits, dim=1)
136+
metrics(y, preds)
137+
138+
wandb.log(metrics.__getmetrics__(str_prefix="Train "))
139+
metrics.__resetvalues__()
140+
222141
evalloss = []
223142
# Eval loop start
224143
model.eval()
225144
with th.no_grad():
226145
for x, y in tqdm(valiloader, desc="Validation"):
227146
x, y = x.to(device), y.to(device)
228-
pred = model.forward(x)
229-
loss = criterion(y, pred)
147+
logits = model.forward(x)
148+
loss = criterion(logits, y)
230149
evalloss.append(loss.item())
231150

151+
preds = th.argmax(logits, dim=1)
152+
metrics(y, preds)
153+
154+
wandb.log(metrics.__getmetrics__(str_prefix="Evaluation "))
155+
metrics.__resetvalues__()
156+
232157
wandb.log(
233158
{
234159
"Epoch": epoch,

tests/test_metrics.py

Lines changed: 68 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from utils.metrics import F1Score, Recall
1+
from utils.metrics import Accuracy, F1Score, Precision, Recall
22

33

44
def test_recall():
@@ -30,3 +30,70 @@ def test_f1score():
3030
assert f1_metric.tp.sum().item() > 0, "Expected some true positives."
3131
assert f1_metric.fp.sum().item() > 0, "Expected some false positives."
3232
assert f1_metric.fn.sum().item() > 0, "Expected some false negatives."
33+
34+
35+
def test_precision_case1():
36+
import torch
37+
38+
for boolean, true_precision in zip([True, False], [25.0 / 36, 7.0 / 10]):
39+
true1 = torch.tensor([0, 1, 2, 1, 0, 2, 1, 0, 2, 1])
40+
pred1 = torch.tensor([0, 2, 1, 1, 0, 2, 0, 0, 2, 1])
41+
P = Precision(3, use_mean=boolean)
42+
precision1 = P(true1, pred1)
43+
assert precision1.allclose(torch.tensor(true_precision), atol=1e-5), (
44+
f"Precision Score: {precision1.item()}"
45+
)
46+
47+
48+
def test_precision_case2():
49+
import torch
50+
51+
for boolean, true_precision in zip([True, False], [8.0 / 15, 6.0 / 15]):
52+
true2 = torch.tensor([0, 1, 2, 3, 4, 0, 1, 2, 3, 4, 0, 1, 2, 3, 4])
53+
pred2 = torch.tensor([0, 0, 4, 3, 4, 0, 4, 4, 2, 3, 4, 1, 2, 4, 0])
54+
P = Precision(5, use_mean=boolean)
55+
precision2 = P(true2, pred2)
56+
assert precision2.allclose(torch.tensor(true_precision), atol=1e-5), (
57+
f"Precision Score: {precision2.item()}"
58+
)
59+
60+
61+
def test_precision_case3():
62+
import torch
63+
64+
for boolean, true_precision in zip([True, False], [3.0 / 4, 4.0 / 5]):
65+
true3 = torch.tensor([0, 0, 0, 1, 0])
66+
pred3 = torch.tensor([1, 0, 0, 1, 0])
67+
P = Precision(2, use_mean=boolean)
68+
precision3 = P(true3, pred3)
69+
assert precision3.allclose(torch.tensor(true_precision), atol=1e-5), (
70+
f"Precision Score: {precision3.item()}"
71+
)
72+
73+
74+
def test_for_zero_denominator():
75+
import torch
76+
77+
for boolean in [True, False]:
78+
true4 = torch.tensor([1, 1, 1, 1, 1])
79+
pred4 = torch.tensor([0, 0, 0, 0, 0])
80+
P = Precision(2, use_mean=boolean)
81+
precision4 = P(true4, pred4)
82+
assert precision4.allclose(torch.tensor(0.0), atol=1e-5), (
83+
f"Precision Score: {precision4.item()}"
84+
)
85+
86+
87+
def test_accuracy():
88+
import torch
89+
90+
accuracy = Accuracy(num_classes=5)
91+
92+
y_true = torch.tensor([0, 3, 2, 3, 4])
93+
y_pred = torch.tensor([0, 1, 2, 3, 4])
94+
95+
accuracy_score = accuracy(y_true, y_pred)
96+
97+
assert torch.abs(torch.tensor(accuracy_score - 0.8)) < 1e-5, (
98+
f"Accuracy Score: {accuracy_score.item()}"
99+
)

tests/test_models.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import torch
33

4-
from utils.models import ChristianModel
4+
from utils.models import ChristianModel, JanModel
55

66

77
@pytest.mark.parametrize(
@@ -17,6 +17,19 @@ def test_christian_model(image_shape, num_classes):
1717
y = model(x)
1818

1919
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
20-
assert y.sum(dim=1).allclose(torch.ones(n), atol=1e-5), (
21-
f"Softmax output should sum to 1, but got: {y.sum()}"
22-
)
20+
21+
22+
@pytest.mark.parametrize(
23+
"image_shape, num_classes",
24+
[((1, 28, 28), 4), ((3, 16, 16), 10)],
25+
)
26+
def test_jan_model(image_shape, num_classes):
27+
n, c, h, w = 5, *image_shape
28+
29+
model = JanModel(image_shape, num_classes)
30+
31+
x = torch.randn(n, c, h, w)
32+
y = model(x)
33+
34+
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
35+

utils/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
__all__ = ["createfolders", "load_data", "load_model", "MetricWrapper"]
1+
__all__ = ["createfolders", "load_data", "load_model", "MetricWrapper", "get_args"]
22

3+
from .arg_parser import get_args
34
from .createfolders import createfolders
45
from .load_data import load_data
56
from .load_metric import MetricWrapper

0 commit comments

Comments
 (0)