1- import argparse
21from pathlib import Path
32
43import numpy as np
54import torch as th
65import torch .nn as nn
76import wandb
87from torch .utils .data import DataLoader
8+ from torchvision import transforms
9+ from tqdm import tqdm
910
10- from utils import MetricWrapper , createfolders , load_data , load_model
11+ from utils import MetricWrapper , createfolders , get_args , load_data , load_model
1112
1213
1314def main ():
@@ -23,202 +24,146 @@ def main():
2324 ------
2425
2526 """
26- parser = argparse .ArgumentParser (
27- prog = "" ,
28- description = "" ,
29- epilog = "" ,
30- )
31- # Structuture related values
32- parser .add_argument (
33- "--datafolder" ,
34- type = Path ,
35- default = "Data" ,
36- help = "Path to where data will be saved during training." ,
37- )
38- parser .add_argument (
39- "--resultfolder" ,
40- type = Path ,
41- default = "Results" ,
42- help = "Path to where results will be saved during evaluation." ,
43- )
44- parser .add_argument (
45- "--modelfolder" ,
46- type = Path ,
47- default = "Experiments" ,
48- help = "Path to where model weights will be saved at the end of training." ,
49- )
50- parser .add_argument (
51- "--savemodel" ,
52- type = bool ,
53- default = False ,
54- help = "Whether model should be saved or not." ,
55- )
56-
57- parser .add_argument (
58- "--download-data" ,
59- type = bool ,
60- default = False ,
61- help = "Whether the data should be downloaded or not. Might cause code to start a bit slowly." ,
62- )
63-
64- # Data/Model specific values
65- parser .add_argument (
66- "--modelname" ,
67- type = str ,
68- default = "MagnusModel" ,
69- choices = ["MagnusModel" , "ChristianModel" ],
70- help = "Model which to be trained on" ,
71- )
72- parser .add_argument (
73- "--dataset" ,
74- type = str ,
75- default = "svhn" ,
76- choices = ["svhn" , "usps_0-6" ],
77- help = "Which dataset to train the model on." ,
78- )
7927
80- parser .add_argument (
81- "--metric" ,
82- type = str ,
83- default = "entropy" ,
84- choices = ["entropy" , "f1" , "recall" , "precision" , "accuracy" ],
85- nargs = "+" ,
86- help = "Which metric to use for evaluation" ,
87- )
88- parser .add_argument ('--imagesize' ,
89- type = int ,
90- default = 28 ,
91- help = 'Size of images' )
92- parser .add_argument ('--imagechannels' ,
93- type = int ,
94- default = 1 ,
95- choices = [1 ,3 ],
96- help = 'Number of color channels in the image.' )
97-
98-
99-
100-
101- # Training specific values
102- parser .add_argument (
103- "--epoch" ,
104- type = int ,
105- default = 20 ,
106- help = "Amount of training epochs the model will do." ,
107- )
108- parser .add_argument (
109- "--learning_rate" ,
110- type = float ,
111- default = 0.001 ,
112- help = "Learning rate parameter for model training." ,
113- )
114- parser .add_argument (
115- "--batchsize" ,
116- type = int ,
117- default = 64 ,
118- help = "Amount of training images loaded in one go" ,
119- )
120- parser .add_argument (
121- "--device" ,
122- type = str ,
123- default = "cpu" ,
124- choices = ["cuda" , "cpu" , "mps" ],
125- help = "Which device to run the training on." ,
126- )
127- parser .add_argument (
128- "--dry_run" ,
129- action = "store_true" ,
130- help = "If true, the code will not run the training loop." ,
131- )
28+ args = get_args ()
13229
133- args = parser .parse_args ()
13430
13531 createfolders (args .datafolder , args .resultfolder , args .modelfolder )
13632
13733 device = args .device
13834
139- metrics = MetricWrapper (* args .metric )
35+ if args .dataset .lower () in ["usps_0-6" , "uspsh5_7_9" ]:
36+ augmentations = transforms .Compose (
37+ [
38+ transforms .Resize ((16 , 16 )),
39+ transforms .ToTensor (),
40+ ]
41+ )
42+ else :
43+ augmentations = transforms .Compose ([transforms .ToTensor ()])
14044
14145 # Dataset
14246 traindata = load_data (
14347 args .dataset ,
14448 train = True ,
14549 data_path = args .datafolder ,
14650 download = args .download_data ,
51+ transform = augmentations ,
14752 )
14853 validata = load_data (
14954 args .dataset ,
15055 train = False ,
15156 data_path = args .datafolder ,
57+ download = args .download_data ,
58+ transform = augmentations ,
15259 )
15360
154- # Find number of channels in the dataset
155- if len (traindata [0 ][0 ].shape ) == 2 :
156- channels = 1
157- else :
158- channels = traindata [0 ][0 ].shape [0 ]
61+ metrics = MetricWrapper (* args .metric , num_classes = traindata .num_classes )
62+
63+ # Find the shape of the data, if is 2D, add a channel dimension
64+ data_shape = traindata [0 ][0 ].shape
65+ if len (data_shape ) == 2 :
66+ data_shape = (1 , * data_shape )
15967
16068 # load model
16169 model = load_model (
16270 args .modelname ,
163- in_channels = channels ,
71+ image_shape = data_shape ,
16472 num_classes = traindata .num_classes ,
16573 )
16674 model .to (device )
16775
168- trainloader = DataLoader (traindata ,
169- batch_size = args .batchsize ,
170- shuffle = True ,
171- pin_memory = True ,
172- drop_last = True )
173- valiloader = DataLoader (validata ,
174- batch_size = args .batchsize ,
175- shuffle = False ,
176- pin_memory = True )
76+ trainloader = DataLoader (
77+ traindata ,
78+ batch_size = args .batchsize ,
79+ shuffle = True ,
80+ pin_memory = True ,
81+ drop_last = True ,
82+ )
83+ valiloader = DataLoader (
84+ validata , batch_size = args .batchsize , shuffle = False , pin_memory = True
85+ )
17786
17887 criterion = nn .CrossEntropyLoss ()
17988 optimizer = th .optim .Adam (model .parameters (), lr = args .learning_rate )
18089
18190 # This allows us to load all the components without running the training loop
18291 if args .dry_run :
183- print ("Dry run completed" )
92+ dry_run_loader = DataLoader (
93+ traindata ,
94+ batch_size = 20 ,
95+ shuffle = True ,
96+ pin_memory = True ,
97+ drop_last = True ,
98+ )
99+
100+ for x , y in tqdm (dry_run_loader , desc = "Dry run" , total = 1 ):
101+ x , y = x .to (device ), y .to (device )
102+ logits = model .forward (x )
103+
104+ loss = criterion (logits , y )
105+ loss .backward ()
106+
107+ optimizer .step ()
108+ optimizer .zero_grad (set_to_none = True )
109+
110+ preds = th .argmax (logits , dim = 1 )
111+ metrics (y , preds )
112+
113+ break
114+ print (metrics .accumulate ())
115+ print ("Dry run completed successfully." )
184116 exit (0 )
185117
186- wandb .init ( project = '' ,
187- tags = [])
118+ wandb .login ( key = WANDB_API )
119+ wandb . init ( entity = "ColabCode" , project = "Jan" , tags = [args . modelname , args . dataset ])
188120 wandb .watch (model )
189121
190122 for epoch in range (args .epoch ):
191-
192123 # Training loop start
193124 trainingloss = []
194125 model .train ()
195- for x , y in trainloader :
126+ for x , y in tqdm ( trainloader , desc = "Training" ) :
196127 x , y = x .to (device ), y .to (device )
197- pred = model .forward (x )
128+ logits = model .forward (x )
198129
199- loss = criterion (y , pred )
130+ loss = criterion (logits , y )
200131 loss .backward ()
201132
202133 optimizer .step ()
203134 optimizer .zero_grad (set_to_none = True )
204135 trainingloss .append (loss .item ())
205136
137+ preds = th .argmax (logits , dim = 1 )
138+ metrics (y , preds )
139+
140+ wandb .log (metrics .accumulate (str_prefix = "Train " ))
141+ metrics .reset ()
142+
206143 evalloss = []
207144 # Eval loop start
208145 model .eval ()
209146 with th .no_grad ():
210- for x , y in valiloader :
147+ for x , y in tqdm ( valiloader , desc = "Validation" ) :
211148 x , y = x .to (device ), y .to (device )
212- pred = model .forward (x )
213- loss = criterion (y , pred )
149+ logits = model .forward (x )
150+ loss = criterion (logits , y )
214151 evalloss .append (loss .item ())
215152
216- wandb .log ({
217- 'Epoch' : epoch ,
218- 'Train loss' : np .mean (trainingloss ),
219- 'Evaluation Loss' : np .mean (evalloss )
220- })
153+ preds = th .argmax (logits , dim = 1 )
154+ metrics (y , preds )
155+
156+ wandb .log (metrics .accumulate (str_prefix = "Evaluation " ))
157+ metrics .reset ()
158+
159+ wandb .log (
160+ {
161+ "Epoch" : epoch ,
162+ "Train loss" : np .mean (trainingloss ),
163+ "Evaluation Loss" : np .mean (evalloss ),
164+ }
165+ )
221166
222167
223- if __name__ == ' __main__' :
168+ if __name__ == " __main__" :
224169 main ()
0 commit comments