Skip to content

Commit a20b338

Browse files
committed
Merge branch 'main' of github.com:SFI-Visual-Intelligence/Collaborative-Coding-Exam into mag-branch
2 parents 1a3561a + 23c09e2 commit a20b338

File tree

10 files changed

+76
-64
lines changed

10 files changed

+76
-64
lines changed

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ dependencies:
1919
- ruff
2020
- scalene
2121
- tqdm
22+
- scipy
2223
- pip:
2324
- torch
2425
- torchvision

main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import numpy as np
44
import torch as th
55
import torch.nn as nn
6-
import wandb
76
from torch.utils.data import DataLoader
87
from torchvision import transforms
98
from tqdm import tqdm
109

10+
import wandb
1111
from utils import MetricWrapper, createfolders, get_args, load_data, load_model
1212

1313

@@ -27,7 +27,6 @@ def main():
2727

2828
args = get_args()
2929

30-
3130
createfolders(args.datafolder, args.resultfolder, args.modelfolder)
3231

3332
device = args.device

utils/arg_parser.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,20 @@ def get_args():
4444
"--modelname",
4545
type=str,
4646
default="MagnusModel",
47-
choices=["MagnusModel", "ChristianModel", "SolveigModel", "JanModel"],
47+
choices=[
48+
"MagnusModel",
49+
"ChristianModel",
50+
"SolveigModel",
51+
"JanModel",
52+
"JohanModel",
53+
],
4854
help="Model which to be trained on",
4955
)
5056
parser.add_argument(
5157
"--dataset",
5258
type=str,
5359
default="svhn",
54-
choices=["svhn", "usps_0-6", "uspsh5_7_9", "mnist_0-3"],
60+
choices=["svhn", "usps_0-6", "usps_7-9", "mnist_0-3", "mnist_4-9"],
5561
help="Which dataset to train the model on.",
5662
)
5763

@@ -95,4 +101,10 @@ def get_args():
95101
action="store_true",
96102
help="If true, the code will not run the training loop.",
97103
)
98-
return parser.parse_args()
104+
args = parser.parse_args()
105+
106+
assert args.epoch > 0, "Epoch should be a positive integer."
107+
assert args.learning_rate > 0, "Learning rate should be a positive float."
108+
assert args.batchsize > 0, "Batch size should be a positive integer."
109+
110+
return args

utils/dataloaders/svhn.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
from torch.utils.data import Dataset
1+
import os
2+
23
from scipy.io import loadmat
3-
import os
4+
from torch.utils.data import Dataset
45
from torchvision.datasets import SVHN
56

7+
68
class SVHNDataset(Dataset):
7-
def __init__(self,
8-
datapath: str,
9-
transforms=None,
10-
download_data=True,
11-
split='train'):
9+
def __init__(
10+
self, datapath: str, transforms=None, download_data=True, split="train"
11+
):
1212
"""
1313
Initializes the SVHNDataset object.
1414
Args:
@@ -20,36 +20,38 @@ def __init__(self,
2020
AssertionError: If the split is not 'train' or 'test'.
2121
"""
2222
super().__init__()
23-
assert split == 'train' or split == 'test'
24-
23+
assert split == "train" or split == "test"
24+
2525
if download_data:
2626
self._download_data(datapath, split)
27-
28-
data = loadmat(os.path.join(datapath, f'{split}_32x32.mat'))
29-
27+
28+
data = loadmat(os.path.join(datapath, f"{split}_32x32.mat"))
29+
3030
# Images on the form N x H x W x C
31-
self.images = data['X'].transpose(3, 1, 0, 2)
32-
self.labels = data['y'].flatten()
31+
self.images = data["X"].transpose(3, 1, 0, 2)
32+
self.labels = data["y"].flatten()
3333
self.labels[self.labels == 10] = 0
34-
34+
3535
self.transforms = transforms
36+
3637
def _download_data(self, path: str, split: str):
3738
"""
3839
Downloads the SVHN dataset.
3940
Args:
4041
path (str): The directory where the dataset will be downloaded.
4142
split (str): The dataset split to download, either 'train' or 'test'.
4243
"""
43-
print(f'Downloading SVHN data into {path}')
44-
SVHN(path, split=split, download=True)
45-
44+
print(f"Downloading SVHN data into {path}")
45+
SVHN(path, split=split, download=True)
46+
4647
def __len__(self):
4748
"""
4849
Returns the number of samples in the dataset.
4950
Returns:
5051
int: The number of samples.
5152
"""
5253
return len(self.labels)
54+
5355
def __getitem__(self, index):
5456
"""
5557
Retrieves the image and label at the specified index.
@@ -59,8 +61,8 @@ def __getitem__(self, index):
5961
tuple: A tuple containing the image and its corresponding label.
6062
"""
6163
img, lab = self.images[index], self.labels[index]
62-
64+
6365
if self.transforms is not None:
6466
img = self.transforms(img)
65-
66-
return img, lab
67+
68+
return img, lab

utils/load_data.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,5 +40,9 @@ def load_data(dataset: str, *args, **kwargs) -> Dataset:
4040
return MNISTDataset0_3(*args, **kwargs)
4141
case "usps_7-9":
4242
return USPSH5_Digit_7_9_Dataset(*args, **kwargs)
43+
case "svhn":
44+
raise NotImplementedError("SVHN dataset not yet implemented.")
45+
case "mnist_4-9":
46+
raise NotImplementedError("MNIST 4-9 dataset not yet implemented.")
4347
case _:
4448
raise NotImplementedError(f"Dataset: {dataset} not implemented.")

utils/load_metric.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88

99
class MetricWrapper(nn.Module):
10-
1110
"""
1211
Wrapper class for metrics, that runs multiple metrics on the same data.
1312
@@ -46,9 +45,7 @@ class MetricWrapper(nn.Module):
4645
{'entropy': [], 'f1': [], 'precision': []}
4746
"""
4847

49-
5048
def __init__(self, *metrics, num_classes):
51-
5249
super().__init__()
5350
self.metrics = {}
5451
self.num_classes = num_classes

utils/load_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch.nn as nn
22

3-
from .models import ChristianModel, JanModel, MagnusModel, SolveigModel
3+
from .models import ChristianModel, JanModel, JohanModel, MagnusModel, SolveigModel
44

55

66
def load_model(modelname: str, *args, **kwargs) -> nn.Module:
@@ -44,6 +44,8 @@ def load_model(modelname: str, *args, **kwargs) -> nn.Module:
4444
return JanModel(*args, **kwargs)
4545
case "solveigmodel":
4646
return SolveigModel(*args, **kwargs)
47+
case "johanmodel":
48+
return JohanModel(*args, **kwargs)
4749
case _:
4850
errmsg = (
4951
f"Model: {modelname} not implemented. "

utils/metrics/EntropyPred.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,23 @@
1-
import numpy as np
21
import torch.nn as nn
32
from scipy.stats import entropy
43

54

65
class EntropyPrediction(nn.Module):
7-
def __init__(self, averages: str = 'average'):
6+
def __init__(self, averages: str = "average"):
87
"""
98
Initializes the EntropyPrediction module.
109
Args:
11-
averages (str): Specifies the method of aggregation for entropy values.
10+
averages (str): Specifies the method of aggregation for entropy values.
1211
Must be either 'average' or 'sum'.
1312
Raises:
1413
AssertionError: If the averages parameter is not 'average' or 'sum'.
1514
"""
1615
super().__init__()
17-
18-
assert averages == 'average' or averages == 'sum'
16+
17+
assert averages == "average" or averages == "sum"
1918
self.averages = averages
2019
self.stored_entropy_values = []
21-
20+
2221
def __call__(self, y_true, y_false_logits):
2322
"""
2423
Computes the entropy between true labels and predicted logits, storing the results.
@@ -29,4 +28,4 @@ def __call__(self, y_true, y_false_logits):
2928
Appends the computed entropy values to the stored_entropy_values list.
3029
"""
3130
entropy_values = entropy(y_true, qk=y_false_logits)
32-
return entropy_values
31+
return entropy_values

utils/models/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
__all__ = ["MagnusModel", "ChristianModel", "JanModel", "SolveigModel"]
1+
__all__ = ["MagnusModel", "ChristianModel", "JanModel", "SolveigModel", "JohanModel"]
22

33
from .christian_model import ChristianModel
44
from .jan_model import JanModel
5+
from .johan_model import JohanModel
56
from .magnus_model import MagnusModel
67
from .solveig_model import SolveigModel

utils/models/magnus_model.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -2,43 +2,38 @@
22

33

44
class MagnusModel(nn.Module):
5-
def __init__(self,
6-
imagesize: int,
7-
imagechannels: int,
8-
n_classes:int=10):
9-
5+
def __init__(self, imagesize: int, imagechannels: int, n_classes: int = 10):
106
"""
11-
Magnus model contains the model for Magnus' part of the homeexam.
7+
Magnus model contains the model for Magnus' part of the homeexam.
128
This class contains a neural network consisting of three linear layers of 133 neurons each,
139
with ReLU activation between each layer.
1410
1511
Args
1612
----
1713
imagesize (int): Expected size of input image. This is needed to scale first layer input
1814
imagechannels (int): Expected number of image channels. This is needed to scale first layer input
19-
n_classes (int): Number of classes we are to provide.
15+
n_classes (int): Number of classes we are to provide.
2016
2117
Returns
2218
-------
2319
MagnusModel (nn.Module): Neural network as described above in this docstring.
2420
"""
25-
2621
super().__init__()
27-
self.imagesize = imagesize
22+
self.imagesize = imagesize
2823
self.imagechannels = imagechannels
29-
30-
self.layer1 = nn.Sequential(*([
31-
nn.Linear(self.imagechannels*self.imagesize*self.imagesize, 133),
32-
nn.ReLU()
33-
]))
34-
self.layer2 = nn.Sequential(*([
35-
nn.Linear(133, 133),
36-
nn.ReLU()
37-
]))
38-
self.layer3 = nn.Sequential(*([
39-
nn.Linear(133, n_classes),
40-
nn.ReLU()
41-
]))
24+
25+
self.layer1 = nn.Sequential(
26+
*(
27+
[
28+
nn.Linear(
29+
self.imagechannels * self.imagesize * self.imagesize, 133
30+
),
31+
nn.ReLU(),
32+
]
33+
)
34+
)
35+
self.layer2 = nn.Sequential(*([nn.Linear(133, 133), nn.ReLU()]))
36+
self.layer3 = nn.Sequential(*([nn.Linear(133, n_classes), nn.ReLU()]))
4237

4338
def forward(self, x):
4439
"""
@@ -47,17 +42,17 @@ def forward(self, x):
4742
Args
4843
----
4944
x (th.Tensor): Four-dimensional tensor in the form (Batch Size x Channels x Image Height x Image Width)
50-
45+
5146
Returns
5247
-------
5348
out (th.Tensor): Class-logits of network given input x
5449
"""
5550
assert len(x.size) == 4
56-
51+
5752
x = x.view(x.size(0), -1)
58-
53+
5954
x = self.layer1(x)
6055
x = self.layer2(x)
6156
out = self.layer3(x)
62-
57+
6358
return out

0 commit comments

Comments
 (0)