Skip to content

Commit f840af4

Browse files
committed
ruffed and isorted
1 parent 9baa17e commit f840af4

File tree

5 files changed

+41
-43
lines changed

5 files changed

+41
-43
lines changed

main.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import numpy as np
44
import torch as th
55
import torch.nn as nn
6-
import wandb
76
from torch.utils.data import DataLoader
87
from torchvision import transforms
98
from tqdm import tqdm
109

10+
import wandb
1111
from utils import MetricWrapper, createfolders, get_args, load_data, load_model
1212

1313

@@ -27,7 +27,6 @@ def main():
2727

2828
args = get_args()
2929

30-
3130
createfolders(args.datafolder, args.resultfolder, args.modelfolder)
3231

3332
device = args.device
@@ -43,7 +42,9 @@ def main():
4342
augmentations = transforms.Compose([transforms.ToTensor()])
4443

4544
# Dataset
46-
assert args.validation_split_percentage < 1.0 and args.validation_split_percentage > 0, "Validation split should be in interval (0,1)"
45+
assert (
46+
args.validation_split_percentage < 1.0 and args.validation_split_percentage > 0
47+
), "Validation split should be in interval (0,1)"
4748
traindata = load_data(
4849
args.dataset,
4950
split="train",
@@ -177,7 +178,7 @@ def main():
177178
"Validation loss": np.mean(valloss),
178179
}
179180
)
180-
181+
181182
testloss = []
182183
model.eval()
183184
with th.no_grad():
@@ -186,13 +187,14 @@ def main():
186187
logits = model.forward(x)
187188
loss = criterion(logits, y)
188189
testloss.append(loss.item())
189-
190+
190191
preds = th.argmax(logits, dim=1)
191192
metrics(y, preds)
192193

193194
wandb.log(metrics.accumulate(str_prefix="Test "))
194195
metrics.reset()
195196
wandb.log({"Test loss": np.mean(testloss)})
196197

198+
197199
if __name__ == "__main__":
198200
main()

utils/dataloaders/mnist_0_3.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import os
33
import urllib.request
44
from pathlib import Path
5-
import torch
65

76
import numpy as np
7+
import torch
88
from torch.utils.data import Dataset, random_split
99

1010

@@ -77,8 +77,8 @@ def __init__(
7777
self.num_classes = 4
7878

7979
if self.split == "train" or self.split == "validation":
80-
train = True # used to decide whether to load training or test dataset
81-
80+
train = True # used to decide whether to load training or test dataset
81+
8282
if not self.download and not self._chech_is_downloaded():
8383
raise ValueError(
8484
"Data not found. Set --download-data=True to download the data."
@@ -94,14 +94,18 @@ def __init__(
9494
)
9595

9696
labels = self._parse_labels()
97-
97+
9898
self.idx = np.where(labels < 4)[0]
99-
99+
100100
if self.split != "test":
101101
generator1 = torch.Generator().manual_seed(42)
102-
tr, val = random_split(self.idx, [1-self.split_percentage, self.split_percentage], generator=generator1)
102+
tr, val = random_split(
103+
self.idx,
104+
[1 - self.split_percentage, self.split_percentage],
105+
generator=generator1,
106+
)
103107
self.idx = tr if self.split == "train" else val
104-
108+
105109
self.length = len(self.idx)
106110

107111
def _parse_labels(self):

utils/load_metric.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88

99
class MetricWrapper(nn.Module):
10-
1110
"""
1211
Wrapper class for metrics, that runs multiple metrics on the same data.
1312
@@ -46,9 +45,7 @@ class MetricWrapper(nn.Module):
4645
{'entropy': [], 'f1': [], 'precision': []}
4746
"""
4847

49-
5048
def __init__(self, *metrics, num_classes):
51-
5249
super().__init__()
5350
self.metrics = {}
5451
self.num_classes = num_classes

utils/metrics/EntropyPred.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,4 @@ def __call__(self, y_true, y_false_logits):
99
return
1010

1111
def __reset__(self):
12-
pass
12+
pass

utils/models/magnus_model.py

Lines changed: 22 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,44 +2,39 @@
22

33

44
class MagnusModel(nn.Module):
5-
def __init__(self,
6-
imagesize: int,
7-
imagechannels: int,
8-
n_classes:int=10):
9-
5+
def __init__(self, imagesize: int, imagechannels: int, n_classes: int = 10):
106
"""
11-
Magnus model contains the model for Magnus' part of the homeexam.
7+
Magnus model contains the model for Magnus' part of the homeexam.
128
This class contains a neural network consisting of three linear layers of 133 neurons each,
139
with ReLU activation between each layer.
1410
1511
Args
1612
----
1713
imagesize (int): Expected size of input image. This is needed to scale first layer input
1814
imagechannels (int): Expected number of image channels. This is needed to scale first layer input
19-
n_classes (int): Number of classes we are to provide.
15+
n_classes (int): Number of classes we are to provide.
2016
2117
Returns
2218
-------
2319
MagnusModel (nn.Module): Neural network as described above in this docstring.
2420
"""
25-
26-
21+
2722
super().__init__()
28-
self.imagesize = imagesize
23+
self.imagesize = imagesize
2924
self.imagechannels = imagechannels
30-
31-
self.layer1 = nn.Sequential(*([
32-
nn.Linear(self.imagechannels*self.imagesize*self.imagesize, 133),
33-
nn.ReLU()
34-
]))
35-
self.layer2 = nn.Sequential(*([
36-
nn.Linear(133, 133),
37-
nn.ReLU()
38-
]))
39-
self.layer3 = nn.Sequential(*([
40-
nn.Linear(133, n_classes),
41-
nn.ReLU()
42-
]))
25+
26+
self.layer1 = nn.Sequential(
27+
*(
28+
[
29+
nn.Linear(
30+
self.imagechannels * self.imagesize * self.imagesize, 133
31+
),
32+
nn.ReLU(),
33+
]
34+
)
35+
)
36+
self.layer2 = nn.Sequential(*([nn.Linear(133, 133), nn.ReLU()]))
37+
self.layer3 = nn.Sequential(*([nn.Linear(133, n_classes), nn.ReLU()]))
4338

4439
def forward(self, x):
4540
"""
@@ -48,17 +43,17 @@ def forward(self, x):
4843
Args
4944
----
5045
x (th.Tensor): Four-dimensional tensor in the form (Batch Size x Channels x Image Height x Image Width)
51-
46+
5247
Returns
5348
-------
5449
out (th.Tensor): Class-logits of network given input x
5550
"""
5651
assert len(x.size) == 4
57-
52+
5853
x = x.view(x.size(0), -1)
59-
54+
6055
x = self.layer1(x)
6156
x = self.layer2(x)
6257
out = self.layer3(x)
63-
58+
6459
return out

0 commit comments

Comments
 (0)