11import torch as th
22import torch .nn as nn
3-
3+ from torch . utils . data import DataLoader
44import argparse
5- from utils import load_metric , load_model , createfolders
6-
7-
8-
9-
10-
11-
12-
13-
14-
15-
5+ import wandb
6+ import numpy as np
7+ from utils import MetricWrapper , load_model , load_data , createfolders
168
179
1810def main ():
@@ -38,16 +30,28 @@ def main():
3830 parser .add_argument ('--resultfolder' , type = str , default = 'Results/' , help = 'Path to where results will be saved during evaluation.' )
3931 parser .add_argument ('--modelfolder' , type = str , default = 'Experiments/' , help = 'Path to where model weights will be saved at the end of training.' )
4032 parser .add_argument ('--savemodel' , type = bool , default = False , help = 'Whether model should be saved or not.' )
33+
4134 parser .add_argument ('--download-data' , type = bool , default = False , help = 'Whether the data should be downloaded or not. Might cause code to start a bit slowly.' )
4235
4336 #Data/Model specific values
4437 parser .add_argument ('--modelname' , type = str , default = 'MagnusModel' ,
4538 choices = ['MagnusModel' ], help = "Model which to be trained on" )
39+ parser .add_argument ('--dataset' , type = str , default = 'svhn' ,
40+ choices = ['svhn' ], help = 'Which dataset to train the model on.' )
41+
42+ parser .add_argument ('--EntropyPrediction' , type = bool , default = True , help = 'Include the Entropy Prediction metric in evaluation' )
43+ parser .add_argument ('--F1Score' , type = bool , default = True , help = 'Include the F1Score metric in evaluation' )
44+ parser .add_argument ('--Recall' , type = bool , default = True , help = 'Include the Recall metric in evaluation' )
45+ parser .add_argument ('--Precision' , type = bool , default = True , help = 'Include the Precision metric in evaluation' )
46+ parser .add_argument ('--Accuracy' , type = bool , default = True , help = 'Include the Accuracy metric in evaluation' )
4647
4748 #Training specific values
4849 parser .add_argument ('--epoch' , type = int , default = 20 , help = 'Amount of training epochs the model will do.' )
4950 parser .add_argument ('--learning_rate' , type = float , default = 0.001 , help = 'Learning rate parameter for model training.' )
51+ parser .add_argument ('--batchsize' , type = int , default = 64 , help = 'Amount of training images loaded in one go' )
52+
5053 args = parser .parse_args ()
54+
5155
5256 createfolders (args )
5357
@@ -57,19 +61,68 @@ def main():
5761 model = load_model ()
5862 model .to (device )
5963
64+ metrics = MetricWrapper (
65+ EntropyPred = args .EntropyPrediction ,
66+ F1Score = args .F1Score ,
67+ Recall = args .Recall ,
68+ Precision = args .Precision ,
69+ Accuracy = args .Accuracy
70+ )
71+
72+ #Dataset
73+ traindata = load_data (args .dataset )
74+ validata = load_data (args .dataset )
75+
76+ trainloader = DataLoader (traindata ,
77+ batch_size = args .batchsize ,
78+ shuffle = True ,
79+ pin_memory = True ,
80+ drop_last = True )
81+ valiloader = DataLoader (validata ,
82+ batch_size = args .batchsize ,
83+ shuffle = False ,
84+ pin_memory = True )
6085
6186 criterion = nn .CrossEntropyLoss ()
6287 optimizer = th .optim .Adam (model .parameters (), lr = args .learning_rate )
63-
64-
65-
88+
89+
90+ wandb .init (project = '' ,
91+ tags = [])
92+ wandb .watch (model )
93+
6694 for epoch in range (args .epoch ):
6795
6896 #Training loop start
97+ trainingloss = []
98+ model .train ()
99+ for x , y in traindata :
100+ x , y = x .to (device ), y .to (device )
101+ pred = model .forward (x )
102+
103+ loss = criterion (y , pred )
104+ loss .backward ()
105+
106+ optimizer .step ()
107+ optimizer .zero_grad (set_to_none = True )
108+ trainingloss .append (loss .item ())
69109
110+ evalloss = []
70111 #Eval loop start
71-
72- pass
112+ model .eval ()
113+ with th .no_grad ():
114+ for x , y in valiloader :
115+ x = x .to (device )
116+ pred = model .forward (x )
117+ loss = criterion (y , pred )
118+ evalloss .append (loss .item ())
119+
120+ wandb .log ({
121+ 'Epoch' : epoch ,
122+ 'Train loss' : np .mean (trainingloss ),
123+ 'Evaluation Loss' : np .mean (evalloss )
124+ })
125+
73126
74127
75128if __name__ == '__main__' :
0 commit comments