55from torch .utils .data import DataLoader
66from torchvision import transforms
77from tqdm import tqdm
8+ from wandb_api import WANDB_API
89
910from utils import MetricWrapper , createfolders , get_args , load_data , load_model
1011
@@ -29,33 +30,38 @@ def main():
2930
3031 device = args .device
3132
32- if args .dataset .lower () in ["usps_0-6" , "uspsh5_7_9 " ]:
33- augmentations = transforms .Compose (
33+ if args .dataset .lower () in ["usps_0-6" , "usps_7-9 " ]:
34+ transform = transforms .Compose (
3435 [
3536 transforms .Resize ((16 , 16 )),
3637 transforms .ToTensor (),
3738 ]
3839 )
3940 else :
40- augmentations = transforms .Compose ([transforms .ToTensor ()])
41+ transform = transforms .Compose ([transforms .ToTensor ()])
4142
42- # Dataset
43- traindata = load_data (
43+ traindata , validata , testdata = load_data (
4444 args .dataset ,
45- train = True ,
46- data_path = args .datafolder ,
47- download = args .download_data ,
48- transform = augmentations ,
49- )
50- validata = load_data (
51- args .dataset ,
52- train = False ,
53- data_path = args .datafolder ,
54- download = args .download_data ,
55- transform = augmentations ,
45+ data_dir = args .datafolder ,
46+ transform = transform ,
47+ val_size = args .val_size ,
5648 )
5749
58- metrics = MetricWrapper (traindata .num_classes , * args .metric )
50+ train_metrics = MetricWrapper (
51+ * args .metric ,
52+ num_classes = traindata .num_classes ,
53+ macro_averaging = args .macro_averaging ,
54+ )
55+ val_metrics = MetricWrapper (
56+ * args .metric ,
57+ num_classes = traindata .num_classes ,
58+ macro_averaging = args .macro_averaging ,
59+ )
60+ test_metrics = MetricWrapper (
61+ * args .metric ,
62+ num_classes = traindata .num_classes ,
63+ macro_averaging = args .macro_averaging ,
64+ )
5965
6066 # Find the shape of the data, if is 2D, add a channel dimension
6167 data_shape = traindata [0 ][0 ].shape
@@ -80,6 +86,9 @@ def main():
8086 valiloader = DataLoader (
8187 validata , batch_size = args .batchsize , shuffle = False , pin_memory = True
8288 )
89+ testloader = DataLoader (
90+ testdata , batch_size = args .batchsize , shuffle = False , pin_memory = True
91+ )
8392
8493 criterion = nn .CrossEntropyLoss ()
8594 optimizer = th .optim .Adam (model .parameters (), lr = args .learning_rate )
@@ -104,17 +113,17 @@ def main():
104113 optimizer .step ()
105114 optimizer .zero_grad (set_to_none = True )
106115
107- metrics (y , logits )
116+ train_metrics (y , logits )
108117
109118 break
110- print (metrics .accumulate ())
119+ print (train_metrics .accumulate ())
111120 print ("Dry run completed successfully." )
112121 exit ()
113122
114123 # wandb.login(key=WANDB_API)
115124 wandb .init (
116125 entity = "ColabCode" ,
117- project = "Magnus Runs" ,
126+ project = args . run_name ,
118127 tags = [args .modelname , args .dataset ],
119128 config = args ,
120129 )
@@ -135,33 +144,49 @@ def main():
135144 optimizer .zero_grad (set_to_none = True )
136145 trainingloss .append (loss .item ())
137146
138- metrics (y , logits )
147+ train_metrics (y , logits )
139148
140- wandb .log (metrics .accumulate (str_prefix = "Train " ))
141- metrics .reset ()
142-
143- evalloss = []
144- # Eval loop start
149+ valloss = []
150+ # Validation loop start
145151 model .eval ()
146152 with th .no_grad ():
147153 for x , y in tqdm (valiloader , desc = "Validation" ):
148154 x , y = x .to (device ), y .to (device )
149155 logits = model .forward (x )
150156 loss = criterion (logits , y )
151- evalloss .append (loss .item ())
152-
153- metrics (y , logits )
157+ valloss .append (loss .item ())
154158
155- wandb .log (metrics .accumulate (str_prefix = "Evaluation " ))
156- metrics .reset ()
159+ val_metrics (y , logits )
157160
158161 wandb .log (
159162 {
160163 "Epoch" : epoch ,
161164 "Train loss" : np .mean (trainingloss ),
162- "Evaluation Loss " : np .mean (evalloss ),
165+ "Validation loss " : np .mean (valloss ),
163166 }
167+ | train_metrics .__getmetrics__ (str_prefix = "Train " )
168+ | val_metrics .__getmetrics__ (str_prefix = "Validation " )
164169 )
170+ train_metrics .__resetmetrics__ ()
171+ val_metrics .__resetmetrics__ ()
172+
173+ testloss = []
174+ model .eval ()
175+ with th .no_grad ():
176+ for x , y in tqdm (testloader , desc = "Testing" ):
177+ x , y = x .to (device ), y .to (device )
178+ logits = model .forward (x )
179+ loss = criterion (logits , y )
180+ testloss .append (loss .item ())
181+
182+ preds = th .argmax (logits , dim = 1 )
183+ test_metrics (y , preds )
184+
185+ wandb .log (
186+ {"Epoch" : 1 , "Test loss" : np .mean (testloss )}
187+ | test_metrics .__getmetrics__ (str_prefix = "Test " )
188+ )
189+ test_metrics .__resetmetrics__ ()
165190
166191
167192if __name__ == "__main__" :
0 commit comments