1- import argparse
21from pathlib import Path
32
43import numpy as np
98from torchvision import transforms
109from tqdm import tqdm
1110
12- from utils import MetricWrapper , createfolders , load_data , load_model
11+ from utils import MetricWrapper , createfolders , get_args , load_data , load_model
1312
1413
1514def main ():
@@ -25,113 +24,21 @@ def main():
2524 ------
2625
2726 """
28- parser = argparse .ArgumentParser (
29- prog = "" ,
30- description = "" ,
31- epilog = "" ,
32- )
33- # Structuture related values
34- parser .add_argument (
35- "--datafolder" ,
36- type = Path ,
37- default = "Data" ,
38- help = "Path to where data will be saved during training." ,
39- )
40- parser .add_argument (
41- "--resultfolder" ,
42- type = Path ,
43- default = "Results" ,
44- help = "Path to where results will be saved during evaluation." ,
45- )
46- parser .add_argument (
47- "--modelfolder" ,
48- type = Path ,
49- default = "Experiments" ,
50- help = "Path to where model weights will be saved at the end of training." ,
51- )
52- parser .add_argument (
53- "--savemodel" ,
54- action = "store_true" ,
55- help = "Whether model should be saved or not." ,
56- )
57-
58- parser .add_argument (
59- "--download-data" ,
60- action = "store_true" ,
61- help = "Whether the data should be downloaded or not. Might cause code to start a bit slowly." ,
62- )
63-
64- # Data/Model specific values
65- parser .add_argument (
66- "--modelname" ,
67- type = str ,
68- default = "MagnusModel" ,
69- choices = ["MagnusModel" , "ChristianModel" , "SolveigModel" ],
70- help = "Model which to be trained on" ,
71- )
72- parser .add_argument (
73- "--dataset" ,
74- type = str ,
75- default = "svhn" ,
76- choices = ["svhn" , "usps_0-6" , "uspsh5_7_9" , "mnist_0-3" ],
77- help = "Which dataset to train the model on." ,
78- )
79-
80- parser .add_argument (
81- "--metric" ,
82- type = str ,
83- default = ["entropy" ],
84- choices = ["entropy" , "f1" , "recall" , "precision" , "accuracy" ],
85- nargs = "+" ,
86- help = "Which metric to use for evaluation" ,
87- )
88-
89- # Training specific values
90- parser .add_argument (
91- "--epoch" ,
92- type = int ,
93- default = 20 ,
94- help = "Amount of training epochs the model will do." ,
95- )
96- parser .add_argument (
97- "--learning_rate" ,
98- type = float ,
99- default = 0.001 ,
100- help = "Learning rate parameter for model training." ,
101- )
102- parser .add_argument (
103- "--batchsize" ,
104- type = int ,
105- default = 64 ,
106- help = "Amount of training images loaded in one go" ,
107- )
108- parser .add_argument (
109- "--device" ,
110- type = str ,
111- default = "cpu" ,
112- choices = ["cuda" , "cpu" , "mps" ],
113- help = "Which device to run the training on." ,
114- )
115- parser .add_argument (
116- "--dry_run" ,
117- action = "store_true" ,
118- help = "If true, the code will not run the training loop." ,
119- )
120-
121- args = parser .parse_args ()
27+ args = get_args ()
12228
12329 createfolders (args .datafolder , args .resultfolder , args .modelfolder )
12430
12531 device = args .device
12632
127- metrics = MetricWrapper (* args .metric )
128-
129- augmentations = transforms .Compose (
130- [
131- transforms .Resize ((16 , 16 )), # At least for USPS
132- transforms .ToTensor (),
133- ]
134- )
33+ if args .dataset .lower () in ["usps_0-6" , "uspsh5_7_9" ]:
34+ augmentations = transforms .Compose (
35+ [
36+ transforms .Resize ((16 , 16 )),
37+ transforms .ToTensor (),
38+ ]
39+ )
40+ else :
41+ augmentations = transforms .Compose ([transforms .ToTensor ()])
13542
13643 # Dataset
13744 traindata = load_data (
@@ -149,6 +56,8 @@ def main():
14956 transform = augmentations ,
15057 )
15158
59+ metrics = MetricWrapper (* args .metric , num_classes = traindata .num_classes )
60+
15261 # Find the shape of the data, if is 2D, add a channel dimension
15362 data_shape = traindata [0 ][0 ].shape
15463 if len (data_shape ) == 2 :
@@ -180,28 +89,32 @@ def main():
18089 if args .dry_run :
18190 dry_run_loader = DataLoader (
18291 traindata ,
183- batch_size = 1 ,
92+ batch_size = 20 ,
18493 shuffle = True ,
18594 pin_memory = True ,
18695 drop_last = True ,
18796 )
18897
18998 for x , y in tqdm (dry_run_loader , desc = "Dry run" , total = 1 ):
19099 x , y = x .to (device ), y .to (device )
191- pred = model .forward (x )
100+ logits = model .forward (x )
192101
193- loss = criterion (y , pred )
102+ loss = criterion (logits , y )
194103 loss .backward ()
195104
196105 optimizer .step ()
197106 optimizer .zero_grad (set_to_none = True )
198107
199- break
108+ preds = th .argmax (logits , dim = 1 )
109+ metrics (y , preds )
200110
111+ break
112+ print (metrics .__getmetrics__ ())
201113 print ("Dry run completed successfully." )
202114 exit (0 )
203115
204- wandb .init (project = "" , tags = [])
116+ wandb .login (key = WANDB_API )
117+ wandb .init (entity = "ColabCode" , project = "Jan" , tags = [args .modelname , args .dataset ])
205118 wandb .watch (model )
206119
207120 for epoch in range (args .epoch ):
@@ -210,25 +123,37 @@ def main():
210123 model .train ()
211124 for x , y in tqdm (trainloader , desc = "Training" ):
212125 x , y = x .to (device ), y .to (device )
213- pred = model .forward (x )
126+ logits = model .forward (x )
214127
215- loss = criterion (y , pred )
128+ loss = criterion (logits , y )
216129 loss .backward ()
217130
218131 optimizer .step ()
219132 optimizer .zero_grad (set_to_none = True )
220133 trainingloss .append (loss .item ())
221134
135+ preds = th .argmax (logits , dim = 1 )
136+ metrics (y , preds )
137+
138+ wandb .log (metrics .__getmetrics__ (str_prefix = "Train " ))
139+ metrics .__resetvalues__ ()
140+
222141 evalloss = []
223142 # Eval loop start
224143 model .eval ()
225144 with th .no_grad ():
226145 for x , y in tqdm (valiloader , desc = "Validation" ):
227146 x , y = x .to (device ), y .to (device )
228- pred = model .forward (x )
229- loss = criterion (y , pred )
147+ logits = model .forward (x )
148+ loss = criterion (logits , y )
230149 evalloss .append (loss .item ())
231150
151+ preds = th .argmax (logits , dim = 1 )
152+ metrics (y , preds )
153+
154+ wandb .log (metrics .__getmetrics__ (str_prefix = "Evaluation " ))
155+ metrics .__resetvalues__ ()
156+
232157 wandb .log (
233158 {
234159 "Epoch" : epoch ,
0 commit comments