1- import argparse
2- from pathlib import Path
3-
41import numpy as np
52import torch as th
63import torch .nn as nn
7- import wandb
84from torch .utils .data import DataLoader
5+ from torchvision import transforms
6+ from tqdm import tqdm
97
10- from utils import MetricWrapper , createfolders , load_data , load_model
8+ import wandb
9+ from utils import MetricWrapper , createfolders , get_args , load_data , load_model
1110
1211
1312def main ():
@@ -23,122 +22,41 @@ def main():
2322 ------
2423
2524 """
26- parser = argparse .ArgumentParser (
27- prog = "" ,
28- description = "" ,
29- epilog = "" ,
30- )
31- # Structuture related values
32- parser .add_argument (
33- "--datafolder" ,
34- type = Path ,
35- default = "Data" ,
36- help = "Path to where data will be saved during training." ,
37- )
38- parser .add_argument (
39- "--resultfolder" ,
40- type = Path ,
41- default = "Results" ,
42- help = "Path to where results will be saved during evaluation." ,
43- )
44- parser .add_argument (
45- "--modelfolder" ,
46- type = Path ,
47- default = "Experiments" ,
48- help = "Path to where model weights will be saved at the end of training." ,
49- )
50- parser .add_argument (
51- "--savemodel" ,
52- type = bool ,
53- default = False ,
54- help = "Whether model should be saved or not." ,
55- )
56-
57- parser .add_argument (
58- "--download-data" ,
59- type = bool ,
60- default = False ,
61- help = "Whether the data should be downloaded or not. Might cause code to start a bit slowly." ,
62- )
6325
64- # Data/Model specific values
65- parser .add_argument (
66- "--modelname" ,
67- type = str ,
68- default = "MagnusModel" ,
69- choices = ["MagnusModel" , "ChristianModel" , "SolveigModel" ],
70- help = "Model which to be trained on" ,
71- )
72- parser .add_argument (
73- "--dataset" ,
74- type = str ,
75- default = "svhn" ,
76- choices = ["svhn" , "usps_0-6" , "uspsh5_7_9" , "mnist_0-3" ],
77- help = "Which dataset to train the model on." ,
78- )
79-
80- parser .add_argument (
81- "--metric" ,
82- type = str ,
83- default = ["entropy" ],
84- choices = ["entropy" , "f1" , "recall" , "precision" , "accuracy" ],
85- nargs = "+" ,
86- help = "Which metric to use for evaluation" ,
87- )
88-
89- # Training specific values
90- parser .add_argument (
91- "--epoch" ,
92- type = int ,
93- default = 20 ,
94- help = "Amount of training epochs the model will do." ,
95- )
96- parser .add_argument (
97- "--learning_rate" ,
98- type = float ,
99- default = 0.001 ,
100- help = "Learning rate parameter for model training." ,
101- )
102- parser .add_argument (
103- "--batchsize" ,
104- type = int ,
105- default = 64 ,
106- help = "Amount of training images loaded in one go" ,
107- )
108- parser .add_argument (
109- "--device" ,
110- type = str ,
111- default = "cpu" ,
112- choices = ["cuda" , "cpu" , "mps" ],
113- help = "Which device to run the training on." ,
114- )
115- parser .add_argument (
116- "--dry_run" ,
117- action = "store_true" ,
118- help = "If true, the code will not run the training loop." ,
119- )
120-
121- args = parser .parse_args ()
26+ args = get_args ()
12227
12328 createfolders (args .datafolder , args .resultfolder , args .modelfolder )
12429
12530 device = args .device
12631
127- metrics = MetricWrapper (* args .metric )
32+ if args .dataset .lower () in ["usps_0-6" , "uspsh5_7_9" ]:
33+ augmentations = transforms .Compose (
34+ [
35+ transforms .Resize ((16 , 16 )),
36+ transforms .ToTensor (),
37+ ]
38+ )
39+ else :
40+ augmentations = transforms .Compose ([transforms .ToTensor ()])
12841
12942 # Dataset
13043 traindata = load_data (
13144 args .dataset ,
13245 train = True ,
13346 data_path = args .datafolder ,
13447 download = args .download_data ,
48+ transform = augmentations ,
13549 )
13650 validata = load_data (
13751 args .dataset ,
13852 train = False ,
13953 data_path = args .datafolder ,
54+ download = args .download_data ,
55+ transform = augmentations ,
14056 )
14157
58+ metrics = MetricWrapper (* args .metric , num_classes = traindata .num_classes )
59+
14260 # Find the shape of the data, if is 2D, add a channel dimension
14361 data_shape = traindata [0 ][0 ].shape
14462 if len (data_shape ) == 2 :
@@ -168,37 +86,75 @@ def main():
16886
16987 # This allows us to load all the components without running the training loop
17088 if args .dry_run :
171- print ("Dry run completed" )
172- exit (0 )
89+ dry_run_loader = DataLoader (
90+ traindata ,
91+ batch_size = 20 ,
92+ shuffle = True ,
93+ pin_memory = True ,
94+ drop_last = True ,
95+ )
17396
174- wandb .init (project = "" , tags = [])
175- wandb .watch (model )
97+ for x , y in tqdm (dry_run_loader , desc = "Dry run" , total = 1 ):
98+ x , y = x .to (device ), y .to (device )
99+ logits = model .forward (x )
176100
101+ loss = criterion (logits , y )
102+ loss .backward ()
103+
104+ optimizer .step ()
105+ optimizer .zero_grad (set_to_none = True )
106+
107+ metrics (y , logits )
108+
109+ break
110+ print (metrics .accumulate ())
111+ print ("Dry run completed successfully." )
112+ exit ()
113+
114+ # wandb.login(key=WANDB_API)
115+ wandb .init (
116+ entity = "ColabCode-org" ,
117+ # entity="FYS-8805 Exam",
118+ project = "Test" ,
119+ tags = [args .modelname , args .dataset ]
120+ )
121+ wandb .watch (model )
122+ exit ()
177123 for epoch in range (args .epoch ):
178124 # Training loop start
179125 trainingloss = []
180126 model .train ()
181- for x , y in trainloader :
127+ for x , y in tqdm ( trainloader , desc = "Training" ) :
182128 x , y = x .to (device ), y .to (device )
183- pred = model .forward (x )
129+ logits = model .forward (x )
184130
185- loss = criterion (y , pred )
131+ loss = criterion (logits , y )
186132 loss .backward ()
187133
188134 optimizer .step ()
189135 optimizer .zero_grad (set_to_none = True )
190136 trainingloss .append (loss .item ())
191137
138+ metrics (y , logits )
139+
140+ wandb .log (metrics .accumulate (str_prefix = "Train " ))
141+ metrics .reset ()
142+
192143 evalloss = []
193144 # Eval loop start
194145 model .eval ()
195146 with th .no_grad ():
196- for x , y in valiloader :
147+ for x , y in tqdm ( valiloader , desc = "Validation" ) :
197148 x , y = x .to (device ), y .to (device )
198- pred = model .forward (x )
199- loss = criterion (y , pred )
149+ logits = model .forward (x )
150+ loss = criterion (logits , y )
200151 evalloss .append (loss .item ())
201152
153+ metrics (y , logits )
154+
155+ wandb .log (metrics .accumulate (str_prefix = "Evaluation " ))
156+ metrics .reset ()
157+
202158 wandb .log (
203159 {
204160 "Epoch" : epoch ,
0 commit comments