
Commit 9427e0a

Merge pull request #16 from SFI-Visual-Intelligence/magnus-branch
Changed metric loader to a wrapper
2 parents 4499aef + 5b5da1f commit 9427e0a

File tree

4 files changed: +193 -42 lines changed

main.py
test.ipynb
utils/__init__.py
utils/load_metric.py


main.py

Lines changed: 70 additions & 17 deletions
@@ -1,18 +1,10 @@
 import torch as th
 import torch.nn as nn
-
+from torch.utils.data import DataLoader
 import argparse
-from utils import load_metric, load_model, createfolders
-
-
-
-
-
-
-
-
-
+import wandb
+import numpy as np
+from utils import MetricWrapper, load_model, load_data, createfolders
 
 
 def main():
@@ -38,16 +30,28 @@ def main():
     parser.add_argument('--resultfolder', type=str, default='Results/', help='Path to where results will be saved during evaluation.')
     parser.add_argument('--modelfolder', type=str, default='Experiments/', help='Path to where model weights will be saved at the end of training.')
     parser.add_argument('--savemodel', type=bool, default=False, help='Whether model should be saved or not.')
+
     parser.add_argument('--download-data', type=bool, default=False, help='Whether the data should be downloaded or not. Might cause code to start a bit slowly.')
 
     #Data/Model specific values
     parser.add_argument('--modelname', type=str, default='MagnusModel',
                         choices = ['MagnusModel'], help="Model which to be trained on")
+    parser.add_argument('--dataset', type=str, default='svhn',
+                        choices=['svhn'], help='Which dataset to train the model on.')
+
+    parser.add_argument('--EntropyPrediction', type=bool, default=True, help='Include the Entropy Prediction metric in evaluation')
+    parser.add_argument('--F1Score', type=bool, default=True, help='Include the F1Score metric in evaluation')
+    parser.add_argument('--Recall', type=bool, default=True, help='Include the Recall metric in evaluation')
+    parser.add_argument('--Precision', type=bool, default=True, help='Include the Precision metric in evaluation')
+    parser.add_argument('--Accuracy', type=bool, default=True, help='Include the Accuracy metric in evaluation')
 
     #Training specific values
     parser.add_argument('--epoch', type=int, default=20, help='Amount of training epochs the model will do.')
     parser.add_argument('--learning_rate', type=float, default=0.001, help='Learning rate parameter for model training.')
+    parser.add_argument('--batchsize', type=int, default=64, help='Amount of training images loaded in one go')
+
     args = parser.parse_args()
+
 
     createfolders(args)
 
@@ -57,19 +61,68 @@ def main():
     model = load_model()
     model.to(device)
 
+    metrics = MetricWrapper(
+        EntropyPred = args.EntropyPrediction,
+        F1Score = args.F1Score,
+        Recall = args.Recall,
+        Precision = args.Precision,
+        Accuracy = args.Accuracy
+    )
+
+    #Dataset
+    traindata = load_data(args.dataset)
+    validata = load_data(args.dataset)
+
+    trainloader = DataLoader(traindata,
+                             batch_size=args.batchsize,
+                             shuffle=True,
+                             pin_memory=True,
+                             drop_last=True)
+    valiloader = DataLoader(validata,
+                            batch_size=args.batchsize,
+                            shuffle=False,
+                            pin_memory=True)
 
     criterion = nn.CrossEntropyLoss()
     optimizer = th.optim.Adam(model.parameters(), lr = args.learning_rate)
-
-
-
+
+
+    wandb.init(project='',
+               tags=[])
+    wandb.watch(model)
+
     for epoch in range(args.epoch):
 
         #Training loop start
+        trainingloss = []
+        model.train()
+        for x, y in traindata:
+            x, y = x.to(device), y.to(device)
+            pred = model.forward(x)
+
+            loss = criterion(y, pred)
+            loss.backward()
+
+            optimizer.step()
+            optimizer.zero_grad(set_to_none=True)
+            trainingloss.append(loss.item())
 
+        evalloss = []
         #Eval loop start
-
-        pass
+        model.eval()
+        with th.no_grad():
+            for x, y in valiloader:
+                x = x.to(device)
+                pred = model.forward(x)
+                loss = criterion(y, pred)
+                evalloss.append(loss.item())
+
+        wandb.log({
+            'Epoch': epoch,
+            'Train loss': np.mean(trainingloss),
+            'Evaluation Loss': np.mean(evalloss)
+        })
+
 
 
 if __name__ == '__main__':
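
A note on the new flags: the metric toggles, --savemodel and --download-data are all declared with type=bool, and argparse converts command-line strings by calling bool() on them, so any non-empty value (including the literal string 'False') parses as True. A minimal illustration of the pitfall and a common alternative, not part of this commit:

import argparse

# bool('False') is True, so the flag cannot actually be switched off this way.
parser = argparse.ArgumentParser()
parser.add_argument('--savemodel', type=bool, default=False)
print(parser.parse_args(['--savemodel', 'False']).savemodel)    # prints True

# store_true gives the usual on/off behaviour: absent -> False, present -> True.
parser = argparse.ArgumentParser()
parser.add_argument('--savemodel', action='store_true')
print(parser.parse_args([]).savemodel)                          # prints False
print(parser.parse_args(['--savemodel']).savemodel)             # prints True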

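Two behavioural details in the merged loops are worth flagging: nn.CrossEntropyLoss expects (input, target), i.e. criterion(pred, y), while the diff passes criterion(y, pred); and the training loop iterates traindata (the Dataset) rather than trainloader, so the batching, shuffling and drop_last settings are never used. The eval loop also leaves y on the CPU. A sketch of corrected loops, reusing the names defined in main.py above and assuming the model returns raw logits; this is not part of the commit:

trainingloss = []
model.train()
for x, y in trainloader:                   # iterate the DataLoader, not the Dataset
    x, y = x.to(device), y.to(device)
    pred = model(x)
    loss = criterion(pred, y)              # (input, target) argument order
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    trainingloss.append(loss.item())

evalloss = []
model.eval()
with th.no_grad():
    for x, y in valiloader:
        x, y = x.to(device), y.to(device)  # targets must be on the same device
        pred = model(x)
        evalloss.append(criterion(pred, y).item())
        metrics(y, pred)                   # the MetricWrapper is built but never called in the diff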
test.ipynb

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import argparse"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "usage: [-h] [--datafolder DATAFOLDER]\n",
+      ": error: unrecognized arguments: --f=/home/magnus/.local/share/jupyter/runtime/kernel-v3fc3d3b04bd8d83becf1be5eacf19e7bf46887012.json\n"
+     ]
+    },
+    {
+     "ename": "SystemExit",
+     "evalue": "2",
+     "output_type": "error",
+     "traceback": [
+      "An exception has occurred, use %tb to see the full traceback.\n",
+      "\u001b[0;31mSystemExit\u001b[0m\u001b[0;31m:\u001b[0m 2\n"
+     ]
+    }
+   ],
+   "source": [
+    "parser = argparse.ArgumentParser(\n",
+    "    prog='',\n",
+    "    description='',\n",
+    "    epilog='',\n",
+    "    )\n",
+    "parser.add_argument('--datafolder', type=str, default='Data/', help='Path to where data will be saved during training.')\n",
+    "args = parser.parse_args()\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(args)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "env",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.5"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
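
The SystemExit: 2 captured in the cell output is the usual symptom of calling parse_args() inside Jupyter: the kernel process is launched with its own --f=<connection-file> argument, which is still in sys.argv and gets rejected by the parser. A small sketch of the standard workarounds, not part of the commit:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--datafolder', type=str, default='Data/',
                    help='Path to where data will be saved during training.')

args = parser.parse_args([])                 # parse an empty argv: defaults only
# or, to tolerate the kernel's extra arguments:
args, _unknown = parser.parse_known_args()
print(args)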

utils/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-from .load_metric import load_metric
+from .load_metric import MetricWrapper
 from .load_model import load_model
 from .load_data import load_data
 from .createfolders import createfolders
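
With the re-export updated, downstream code imports the wrapper from the package root, which is what main.py above relies on:

from utils import MetricWrapper             # package-level re-export
# equivalent to importing from the submodule directly:
from utils.load_metric import MetricWrapper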

utils/load_metric.py

Lines changed: 46 additions & 24 deletions
@@ -1,27 +1,49 @@
+import copy
+import numpy as np
 import torch.nn as nn
 from metrics import EntropyPrediction
 
-def load_metric(metricname:str) -> nn.Module:
-    '''
-    This function returns an instance of a class inhereting from nn.Module.
-    This class returns the given metric given a set of label - prediction pairs.
-
-    Parameters
-    ----------
-    metricname: string
-        string naming the metric to return.
-
-    Returns
-    -------
-    Class
-        Returns an instance of a class inhereting from nn.Module.
-
-    Raises
-    ------
-    ValueError
-        When the metricname parameter does not match any implemented metric, raise this error along with a descriptive message.
-    '''
-    if metricname == 'EntropyPrediction':
-        return EntropyPrediction()
-    else:
-        raise ValueError(f'Metric: {metricname} has not been implemented. \nCheck the documentation for implemented metrics, or check your spelling')
+
+class MetricWrapper(nn.Module):
+    def __init__(self,
+                 EntropyPred:bool = True,
+                 F1Score:bool = True,
+                 Recall:bool = True,
+                 Precision:bool = True,
+                 Accuracy:bool = True):
+        super().__init__()
+        self.metrics = {}
+
+        if EntropyPred:
+            self.metrics['Entropy of Predictions'] = EntropyPrediction()
+
+        if F1Score:
+            self.metrics['F1 Score'] = None
+
+        if Recall:
+            self.metrics['Recall'] = None
+
+        if Precision:
+            self.metrics['Precision'] = None
+
+        if Accuracy:
+            self.metrics['Accuracy'] = None
+
+        self.tmp_scores = copy.deepcopy(self.metrics)
+        for key in self.tmp_scores:
+            self.tmp_scores[key] = []
+
+    def __call__(self, y_true, y_pred):
+        for key in self.metrics:
+            self.tmp_scores[key].append(self.metrics[key](y_true, y_pred))
+
+    def __getmetrics__(self):
+        return_metrics = {}
+        for key in self.metrics:
+            return_metrics[key] = np.mean(self.tmp_scores[key])
+
+        return return_metrics
+
+    def __resetvalues__(self):
+        for key in self.tmp_scores:
+            self.tmp_scores[key] = []
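
How the wrapper is meant to be driven, judging from its methods (a sketch, not part of the commit). Note that only EntropyPrediction is actually wired up: the F1, Recall, Precision and Accuracy slots hold None, so leaving their flags True makes __call__ fail with TypeError: 'NoneType' object is not callable. The flags for the unimplemented metrics are therefore disabled here.

from utils import MetricWrapper

# Only the implemented metric is enabled; the others are None placeholders above.
metrics = MetricWrapper(EntropyPred=True, F1Score=False, Recall=False,
                        Precision=False, Accuracy=False)

for y_true, y_pred in batches:      # 'batches' is a hypothetical iterable of (target, prediction) pairs
    metrics(y_true, y_pred)         # appends each enabled metric's score to tmp_scores

epoch_scores = metrics.__getmetrics__()   # mean of the accumulated scores per metric
metrics.__resetvalues__()                 # clear the buffers before the next epoch
print(epoch_scores)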
