1- import torch as th
2- import torch .nn as nn
3- from torch .utils .data import DataLoader
41import argparse
5- import wandb
2+ from pathlib import Path
3+
64import numpy as np
7- from utils import MetricWrapper , load_model , load_data , createfolders
5+ import torch as th
6+ import torch .nn as nn
7+ import wandb
8+ from torch .utils .data import DataLoader
9+
10+ from utils import MetricWrapper , createfolders , load_data , load_model
811
912
1013def main ():
11- '''
12-
14+ """
15+
1316 Parameters
1417 ----------
15-
18+
1619 Returns
1720 -------
18-
21+
1922 Raises
2023 ------
21-
22- '''
24+
25+ """
2326 parser = argparse .ArgumentParser (
24- prog = '' ,
25- description = '' ,
26- epilog = '' ,
27- )
28- #Structuture related values
29- parser .add_argument ('--datafolder' , type = str , default = 'Data/' , help = 'Path to where data will be saved during training.' )
30- parser .add_argument ('--resultfolder' , type = str , default = 'Results/' , help = 'Path to where results will be saved during evaluation.' )
31- parser .add_argument ('--modelfolder' , type = str , default = 'Experiments/' , help = 'Path to where model weights will be saved at the end of training.' )
32- parser .add_argument ('--savemodel' , type = bool , default = False , help = 'Whether model should be saved or not.' )
33-
34- parser .add_argument ('--download-data' , type = bool , default = False , help = 'Whether the data should be downloaded or not. Might cause code to start a bit slowly.' )
35-
36- #Data/Model specific values
37- parser .add_argument ('--modelname' , type = str , default = 'MagnusModel' ,
38- choices = ['MagnusModel' ], help = "Model which to be trained on" )
39- parser .add_argument ('--dataset' , type = str , default = 'svhn' ,
40- choices = ['svhn' ], help = 'Which dataset to train the model on.' )
41-
42- parser .add_argument ('--EntropyPrediction' , type = bool , default = True , help = 'Include the Entropy Prediction metric in evaluation' )
43- parser .add_argument ('--F1Score' , type = bool , default = True , help = 'Include the F1Score metric in evaluation' )
44- parser .add_argument ('--Recall' , type = bool , default = True , help = 'Include the Recall metric in evaluation' )
45- parser .add_argument ('--Precision' , type = bool , default = True , help = 'Include the Precision metric in evaluation' )
46- parser .add_argument ('--Accuracy' , type = bool , default = True , help = 'Include the Accuracy metric in evaluation' )
47-
48- #Training specific values
49- parser .add_argument ('--epoch' , type = int , default = 20 , help = 'Amount of training epochs the model will do.' )
50- parser .add_argument ('--learning_rate' , type = float , default = 0.001 , help = 'Learning rate parameter for model training.' )
51- parser .add_argument ('--batchsize' , type = int , default = 64 , help = 'Amount of training images loaded in one go' )
52-
27+ prog = "" ,
28+ description = "" ,
29+ epilog = "" ,
30+ )
31+ # Structuture related values
32+ parser .add_argument (
33+ "--datafolder" ,
34+ type = Path ,
35+ default = "Data" ,
36+ help = "Path to where data will be saved during training." ,
37+ )
38+ parser .add_argument (
39+ "--resultfolder" ,
40+ type = Path ,
41+ default = "Results" ,
42+ help = "Path to where results will be saved during evaluation." ,
43+ )
44+ parser .add_argument (
45+ "--modelfolder" ,
46+ type = Path ,
47+ default = "Experiments" ,
48+ help = "Path to where model weights will be saved at the end of training." ,
49+ )
50+ parser .add_argument (
51+ "--savemodel" ,
52+ type = bool ,
53+ default = False ,
54+ help = "Whether model should be saved or not." ,
55+ )
56+
57+ parser .add_argument (
58+ "--download-data" ,
59+ type = bool ,
60+ default = False ,
61+ help = "Whether the data should be downloaded or not. Might cause code to start a bit slowly." ,
62+ )
63+
64+ # Data/Model specific values
65+ parser .add_argument (
66+ "--modelname" ,
67+ type = str ,
68+ default = "MagnusModel" ,
69+ choices = ["MagnusModel" , "ChristianModel" ],
70+ help = "Model which to be trained on" ,
71+ )
72+ parser .add_argument (
73+ "--dataset" ,
74+ type = str ,
75+ default = "svhn" ,
76+ choices = ["svhn" , "usps_0-6" ],
77+ help = "Which dataset to train the model on." ,
78+ )
79+
80+ parser .add_argument (
81+ "--metric" ,
82+ type = str ,
83+ default = ["entropy" ],
84+ choices = ["entropy" , "f1" , "recall" , "precision" , "accuracy" ],
85+ nargs = "+" ,
86+ help = "Which metric to use for evaluation" ,
87+ )
88+
89+ # Training specific values
90+ parser .add_argument (
91+ "--epoch" ,
92+ type = int ,
93+ default = 20 ,
94+ help = "Amount of training epochs the model will do." ,
95+ )
96+ parser .add_argument (
97+ "--learning_rate" ,
98+ type = float ,
99+ default = 0.001 ,
100+ help = "Learning rate parameter for model training." ,
101+ )
102+ parser .add_argument (
103+ "--batchsize" ,
104+ type = int ,
105+ default = 64 ,
106+ help = "Amount of training images loaded in one go" ,
107+ )
108+ parser .add_argument (
109+ "--device" ,
110+ type = str ,
111+ default = "cpu" ,
112+ choices = ["cuda" , "cpu" , "mps" ],
113+ help = "Which device to run the training on." ,
114+ )
115+ parser .add_argument (
116+ "--dry_run" ,
117+ action = "store_true" ,
118+ help = "If true, the code will not run the training loop." ,
119+ )
120+
53121 args = parser .parse_args ()
54-
55-
56- createfolders (args )
57-
58- device = 'cuda' if th .cuda .is_available () else 'cpu'
59-
60- #load model
61- model = load_model ()
122+
123+ createfolders (args .datafolder , args .resultfolder , args .modelfolder )
124+
125+ device = args .device
126+
127+ metrics = MetricWrapper (* args .metric )
128+
129+ # Dataset
130+ traindata = load_data (
131+ args .dataset ,
132+ train = True ,
133+ data_path = args .datafolder ,
134+ download = args .download_data ,
135+ )
136+ validata = load_data (
137+ args .dataset ,
138+ train = False ,
139+ data_path = args .datafolder ,
140+ )
141+
142+ # Find number of channels in the dataset
143+ if len (traindata [0 ][0 ].shape ) == 2 :
144+ channels = 1
145+ else :
146+ channels = traindata [0 ][0 ].shape [0 ]
147+
148+ # load model
149+ model = load_model (
150+ args .modelname ,
151+ in_channels = channels ,
152+ num_classes = traindata .num_classes ,
153+ )
62154 model .to (device )
63-
64- metrics = MetricWrapper (
65- EntropyPred = args .EntropyPrediction ,
66- F1Score = args .F1Score ,
67- Recall = args .Recall ,
68- Precision = args .Precision ,
69- Accuracy = args .Accuracy
70- )
71-
72- #Dataset
73- traindata = load_data (args .dataset )
74- validata = load_data (args .dataset )
75-
155+
76156 trainloader = DataLoader (traindata ,
77157 batch_size = args .batchsize ,
78158 shuffle = True ,
@@ -82,48 +162,51 @@ def main():
82162 batch_size = args .batchsize ,
83163 shuffle = False ,
84164 pin_memory = True )
85-
165+
86166 criterion = nn .CrossEntropyLoss ()
87- optimizer = th .optim .Adam (model .parameters (), lr = args .learning_rate )
88-
89-
167+ optimizer = th .optim .Adam (model .parameters (), lr = args .learning_rate )
168+
169+ # This allows us to load all the components without running the training loop
170+ if args .dry_run :
171+ print ("Dry run completed" )
172+ exit (0 )
173+
90174 wandb .init (project = '' ,
91175 tags = [])
92176 wandb .watch (model )
93-
177+
94178 for epoch in range (args .epoch ):
95-
96- #Training loop start
179+
180+ # Training loop start
97181 trainingloss = []
98182 model .train ()
99- for x , y in traindata :
183+ for x , y in trainloader :
100184 x , y = x .to (device ), y .to (device )
101185 pred = model .forward (x )
102-
186+
103187 loss = criterion (y , pred )
104188 loss .backward ()
105-
189+
106190 optimizer .step ()
107191 optimizer .zero_grad (set_to_none = True )
108192 trainingloss .append (loss .item ())
109-
193+
110194 evalloss = []
111- #Eval loop start
195+ # Eval loop start
112196 model .eval ()
113197 with th .no_grad ():
114198 for x , y in valiloader :
115- x = x .to (device )
199+ x , y = x . to ( device ), y .to (device )
116200 pred = model .forward (x )
117201 loss = criterion (y , pred )
118202 evalloss .append (loss .item ())
119-
203+
120204 wandb .log ({
121205 'Epoch' : epoch ,
122206 'Train loss' : np .mean (trainingloss ),
123207 'Evaluation Loss' : np .mean (evalloss )
124208 })
125-
126209
127210
128211if __name__ == '__main__' :
129- main ()
212+ main ()
0 commit comments