Skip to content

Commit 5b22af0

Browse files
committed
Added several tests, undid some changes to the argparser file
1 parent 75b1801 commit 5b22af0

File tree

9 files changed

+158
-63
lines changed

9 files changed

+158
-63
lines changed

main.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@ def main():
2222
------
2323
2424
"""
25-
25+
2626
args = get_args()
27-
27+
2828
createfolders(args.datafolder, args.resultfolder, args.modelfolder)
29-
29+
3030
device = args.device
31-
31+
3232
if args.dataset.lower() in ["usps_0-6", "uspsh5_7_9"]:
3333
augmentations = transforms.Compose(
3434
[
@@ -38,7 +38,7 @@ def main():
3838
)
3939
else:
4040
augmentations = transforms.Compose([transforms.ToTensor()])
41-
41+
4242
# Dataset
4343
traindata = load_data(
4444
args.dataset,
@@ -54,22 +54,22 @@ def main():
5454
download=args.download_data,
5555
transform=augmentations,
5656
)
57-
58-
metrics = MetricWrapper(*args.metric, num_classes=traindata.num_classes)
59-
57+
58+
metrics = MetricWrapper(traindata.num_classes, *args.metric)
59+
6060
# Find the shape of the data, if is 2D, add a channel dimension
6161
data_shape = traindata[0][0].shape
6262
if len(data_shape) == 2:
6363
data_shape = (1, *data_shape)
64-
64+
6565
# load model
6666
model = load_model(
6767
args.modelname,
6868
image_shape=data_shape,
6969
num_classes=traindata.num_classes,
7070
)
7171
model.to(device)
72-
72+
7373
trainloader = DataLoader(
7474
traindata,
7575
batch_size=args.batchsize,

tests/test_dataloaders.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
from utils.dataloaders.usps_0_6 import USPSDataset0_6
2-
1+
from utils.dataloaders import USPSDataset0_6, SVHNDataset
32

43
def test_uspsdataset0_6():
54
from pathlib import Path
@@ -32,3 +31,33 @@ def test_uspsdataset0_6():
3231
data, target = dataset[0]
3332
assert data.shape == (1, 16, 16)
3433
assert all(target == np.array([0, 0, 0, 0, 0, 0, 1]))
34+
35+
36+
def test_svhn_dataset():
37+
import os
38+
from tempfile import TemporaryDirectory
39+
from torchvision import transforms
40+
41+
with TemporaryDirectory() as tempdir:
42+
43+
trans = transforms.Compose([
44+
transforms.Resize((28,28)),
45+
transforms.ToTensor()
46+
])
47+
48+
dataset = SVHNDataset(tempdir,
49+
train=True,
50+
transform=trans,
51+
download=True,
52+
nr_channels=1)
53+
54+
assert dataset.__len__() != 0
55+
assert os.path.exists(os.path.join(tempdir, 'train_32x32.mat'))
56+
57+
img, label = dataset.__getitem__(0)
58+
assert len(img.size()) == 3 and img.size() == (1,28,28) and img.size(0) == 1
59+
assert len(label.size()) == 1
60+
61+
62+
if __name__ == '__main__':
63+
test_svhn_dataset()

tests/test_metrics.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from utils.metrics import Accuracy, F1Score, Precision, Recall
1+
from utils.metrics import Accuracy, F1Score, Precision, Recall, EntropyPrediction
22

33

44
def test_recall():
@@ -97,3 +97,21 @@ def test_accuracy():
9797
assert torch.abs(torch.tensor(accuracy_score - 0.8)) < 1e-5, (
9898
f"Accuracy Score: {accuracy_score.item()}"
9999
)
100+
101+
def test_entropypred():
102+
import torch as th
103+
104+
metric = EntropyPrediction(averages='mean')
105+
106+
true_lab = th.Tensor([0,1,1,2,4,3]).reshape(6,1).type(th.LongTensor)
107+
pred_logits = th.nn.functional.one_hot(true_lab, 5)
108+
109+
#Test for log(0) errors and expected output
110+
assert th.abs((th.sum(metric(true_lab, pred_logits)) - 0.0)) < 1e-5
111+
112+
pred_logits = th.rand(6,5)
113+
metric2 = EntropyPrediction(averages='sum')
114+
115+
#Test for averaging metric consistency
116+
assert th.abs(th.sum(6*metric(true_lab, pred_logits) - metric2(true_lab, pred_logits))) < 1e-5
117+

tests/test_models.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22
import torch
33

4-
from utils.models import ChristianModel, JanModel
4+
from utils.models import ChristianModel, JanModel, MagnusModel
55

66

77
@pytest.mark.parametrize(
@@ -33,3 +33,20 @@ def test_jan_model(image_shape, num_classes):
3333

3434
assert y.shape == (n, num_classes), f"Shape: {y.shape}"
3535

36+
37+
@pytest.mark.parameterize(
38+
"image_shape",
39+
[(3,28,28)]
40+
)
41+
def test_magnus_model(image_shape):
42+
import torch as th
43+
44+
n, c, h, w = 5, *image_shape
45+
model = MagnusModel([h,w], 10, c)
46+
47+
x = th.rand((n, c, h, w))
48+
with th.no_grad():
49+
y = model(x)
50+
51+
assert y.shape == (n, 10), f"Shape: {y.shape}"
52+

utils/arg_parser.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,7 @@ def get_args():
3535

3636
parser.add_argument(
3737
"--download-data",
38-
type=bool,
39-
default=False,
38+
action="store_true",
4039
help="Whether the data should be downloaded or not. Might cause code to start a bit slowly.",
4140
)
4241

utils/dataloaders/svhn.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(
4040

4141
self.nr_channels = nr_channels
4242
self.transforms = transform
43+
self.num_classes = len(np.unique(self.labels))
4344

4445
def _download_data(self, path: str):
4546
"""

utils/load_metric.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import copy
2-
32
import numpy as np
43
import torch.nn as nn
5-
64
from .metrics import Accuracy, EntropyPrediction, F1Score, Precision, Recall
75

86

@@ -45,7 +43,7 @@ class MetricWrapper(nn.Module):
4543
{'entropy': [], 'f1': [], 'precision': []}
4644
"""
4745

48-
def __init__(self, *metrics, num_classes):
46+
def __init__(self, num_classes, *metrics):
4947
super().__init__()
5048
self.metrics = {}
5149
self.num_classes = num_classes
@@ -72,7 +70,8 @@ def _get_metric(self, key):
7270

7371
match key.lower():
7472
case "entropy":
75-
return EntropyPrediction(num_classes=self.num_classes)
73+
#Not dependent on knowing the number of classes
74+
return EntropyPrediction()
7675
case "f1":
7776
return F1Score(num_classes=self.num_classes)
7877
case "recall":

utils/metrics/EntropyPred.py

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,54 @@ def __init__(self, averages: str = "average"):
88
Initializes the EntropyPrediction module.
99
Args:
1010
averages (str): Specifies the method of aggregation for entropy values.
11-
Must be either 'average' or 'sum'.
11+
Must be either 'mean', 'sum' or 'none.
1212
Raises:
13-
AssertionError: If the averages parameter is not 'average' or 'sum'.
13+
AssertionError: If the averages parameter is not 'mean' or 'sum'.
1414
"""
1515
super().__init__()
1616

17-
assert averages == "average" or averages == "sum"
17+
assert averages == "mean" or averages == "sum"
1818
self.averages = averages
1919
self.stored_entropy_values = []
2020

21-
def __call__(self, y_true, y_false_logits):
21+
def __call__(self, y_true, y_logits):
2222
"""
23-
Computes the entropy between true labels and predicted logits, storing the results.
23+
Computes the Shannon Entropy of the predicted logits, storing the results.
2424
Args:
25-
y_true: The true labels.
26-
y_false_logits: The predicted logits.
27-
Side Effects:
28-
Appends the computed entropy values to the stored_entropy_values list.
25+
y_true: The true labels. Does nothing, but needed for compatability sake.
26+
y_logits: The predicted logits.
2927
"""
30-
entropy_values = entropy(y_true, qk=y_false_logits)
28+
entropy_values = entropy(y_logits, axis=1)
29+
entropy_values = th.from_numpy(entropy_values)
30+
31+
#Fix numerical errors for perfect guesses
32+
entropy_values[entropy_values == th.inf] = 0
33+
entropy_values = th.nan_to_num(entropy_values)
34+
35+
36+
if self.averages == 'mean':
37+
entropy_values = th.mean(entropy_values)
38+
39+
elif self.averages == 'sum':
40+
entropy_values = th.sum(entropy_values)
41+
42+
elif self.averages == 'none':
43+
return entropy_values
44+
45+
3146
return entropy_values
47+
48+
49+
if __name__ == '__main__':
50+
import torch as th
51+
52+
metric = EntropyPrediction(averages='mean')
53+
54+
true_lab = th.Tensor([0,1,1,2,4,3]).reshape(6,1)
55+
pred_logits = th.nn.functional.one_hot(true_lab, 5)
56+
57+
assert th.abs((th.sum(metric(true_lab, pred_logits)) - 0.0)) < 1e-5
58+
59+
pred_logits = th.rand(6,5)
60+
metric2 = EntropyPrediction(averages='sum')
61+
assert th.abs(th.sum(6*metric(true_lab, pred_logits) - metric2(true_lab, pred_logits))) < 1e-5

utils/models/magnus_model.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,25 @@
22

33

44
class MagnusModel(nn.Module):
5-
def __init__(self, image_shape: int, num_classes: int, imagechannels: int):
5+
def __init__(self, image_shape, num_classes: int, nr_channels: int):
66
"""
7-
Magnus model contains the model for Magnus' part of the homeexam.
8-
This class contains a neural network consisting of three linear layers of 133 neurons each,
9-
with ReLU activation between each layer.
10-
11-
Args
12-
----
13-
image_shape (int): Expected size of input image. This is needed to scale first layer input
14-
imagechannels (int): Expected number of image channels. This is needed to scale first layer input
15-
num_classes (int): Number of classes we are to provide.
16-
17-
Returns
18-
-------
19-
MagnusModel (nn.Module): Neural network as described above in this docstring.
7+
Initializes the MagnusModel, a neural network designed for image classification tasks.
8+
9+
The model consists of three linear layers, each with 133 neurons, and uses ReLU activation
10+
functions between the layers. The first layer's input size is determined by the image shape
11+
and number of channels, while the output layer's size is determined by the number of classes.
12+
Args:
13+
image_shape (tuple): A tuple representing the dimensions of the input image (Channels, Height, Width).
14+
num_classes (int): The number of output classes for classification.
15+
nr_channels (int): The number of channels in the input image.
16+
Returns:
17+
MagnusModel (nn.Module): An instance of the MagnusModel neural network.
2018
"""
2119
super().__init__()
22-
self.image_shape = image_shape
23-
self.imagechannels = imagechannels
24-
20+
_, H, W = image_shape
21+
2522
self.layer1 = nn.Sequential(*([
26-
nn.Linear(self.imagechannels * self.imagesize * self.imagesize, 133),
23+
nn.Linear(nr_channels * H * W, 133),
2724
nn.ReLU(),
2825
]))
2926
self.layer2 = nn.Sequential(*([
@@ -32,27 +29,32 @@ def __init__(self, image_shape: int, num_classes: int, imagechannels: int):
3229
]))
3330
self.layer3 = nn.Sequential(*([
3431
nn.Linear(133, num_classes),
35-
nn.ReLU()
3632
]))
37-
3833
def forward(self, x):
3934
"""
40-
Forward pass of MagnusModel
41-
42-
Args
43-
----
44-
x (th.Tensor): Four-dimensional tensor in the form (Batch Size x Channels x Image Height x Image Width)
45-
46-
Returns
47-
-------
48-
out (th.Tensor): Class-logits of network given input x
35+
Defines the forward pass of the MagnusModel.
36+
Args:
37+
x (torch.Tensor): A four-dimensional tensor with shape (Batch Size, Channels, Image Height, Image Width).
38+
Returns:
39+
torch.Tensor: The output tensor containing class logits for each input sample.
4940
"""
50-
assert len(x.size) == 4
51-
41+
assert len(x.size()) == 4
5242
x = x.view(x.size(0), -1)
53-
5443
x = self.layer1(x)
5544
x = self.layer2(x)
5645
out = self.layer3(x)
57-
5846
return out
47+
48+
49+
if __name__ == '__main__':
50+
import torch as th
51+
52+
data_shape = [28,28]
53+
54+
data_shape = (3, *data_shape)
55+
model = MagnusModel(data_shape, 10)
56+
57+
dummy_img = th.rand((5,*data_shape))
58+
print(dummy_img.size())
59+
with th.no_grad():
60+
print(model(dummy_img).size())

0 commit comments

Comments
 (0)