1- from pathlib import Path
2-
31import numpy as np
42import torch as th
53import torch .nn as nn
6- import wandb
74from torch .utils .data import DataLoader
85from torchvision import transforms
96from tqdm import tqdm
107
8+ import wandb
119from utils import MetricWrapper , createfolders , get_args , load_data , load_model
1210
1311
@@ -27,35 +25,25 @@ def main():
2725
2826 args = get_args ()
2927
30-
3128 createfolders (args .datafolder , args .resultfolder , args .modelfolder )
3229
3330 device = args .device
3431
3532 if args .dataset .lower () in ["usps_0-6" , "uspsh5_7_9" ]:
36- augmentations = transforms .Compose (
33+ transform = transforms .Compose (
3734 [
3835 transforms .Resize ((16 , 16 )),
3936 transforms .ToTensor (),
4037 ]
4138 )
4239 else :
43- augmentations = transforms .Compose ([transforms .ToTensor ()])
40+ transform = transforms .Compose ([transforms .ToTensor ()])
4441
45- # Dataset
46- traindata = load_data (
47- args .dataset ,
48- train = True ,
49- data_path = args .datafolder ,
50- download = args .download_data ,
51- transform = augmentations ,
52- )
53- validata = load_data (
42+ traindata , validata , testdata = load_data (
5443 args .dataset ,
55- train = False ,
56- data_path = args .datafolder ,
57- download = args .download_data ,
58- transform = augmentations ,
44+ data_dir = args .datafolder ,
45+ transform = transform ,
46+ val_size = args .val_size ,
5947 )
6048
6149 metrics = MetricWrapper (* args .metric , num_classes = traindata .num_classes , macro_averaging = args .macro_averaging )
@@ -83,6 +71,9 @@ def main():
8371 valiloader = DataLoader (
8472 validata , batch_size = args .batchsize , shuffle = False , pin_memory = True
8573 )
74+ testloader = DataLoader (
75+ testdata , batch_size = args .batchsize , shuffle = False , pin_memory = True
76+ )
8677
8778 criterion = nn .CrossEntropyLoss ()
8879 optimizer = th .optim .Adam (model .parameters (), lr = args .learning_rate )
@@ -107,18 +98,22 @@ def main():
10798 optimizer .step ()
10899 optimizer .zero_grad (set_to_none = True )
109100
110- preds = th .argmax (logits , dim = 1 )
111- metrics (y , preds )
101+ metrics (y , logits )
112102
113103 break
114104 print (metrics .accumulate ())
115105 print ("Dry run completed successfully." )
116- exit (0 )
117-
118- wandb .login (key = WANDB_API )
119- wandb .init (entity = "ColabCode" , project = "Jan" , tags = [args .modelname , args .dataset ])
106+ exit ()
107+
108+ # wandb.login(key=WANDB_API)
109+ wandb .init (
110+ entity = "ColabCode-org" ,
111+ # entity="FYS-8805 Exam",
112+ project = "Test" ,
113+ tags = [args .modelname , args .dataset ]
114+ )
120115 wandb .watch (model )
121-
116+ exit ()
122117 for epoch in range (args .epoch ):
123118 # Training loop start
124119 trainingloss = []
@@ -134,36 +129,50 @@ def main():
134129 optimizer .zero_grad (set_to_none = True )
135130 trainingloss .append (loss .item ())
136131
137- preds = th .argmax (logits , dim = 1 )
138- metrics (y , preds )
132+ metrics (y , logits )
139133
140134 wandb .log (metrics .accumulate (str_prefix = "Train " ))
141135 metrics .reset ()
142136
143- evalloss = []
144- # Eval loop start
137+ valloss = []
138+ # Validation loop start
145139 model .eval ()
146140 with th .no_grad ():
147141 for x , y in tqdm (valiloader , desc = "Validation" ):
148142 x , y = x .to (device ), y .to (device )
149143 logits = model .forward (x )
150144 loss = criterion (logits , y )
151- evalloss .append (loss .item ())
145+ valloss .append (loss .item ())
152146
153- preds = th .argmax (logits , dim = 1 )
154- metrics (y , preds )
147+ metrics (y , logits )
155148
156- wandb .log (metrics .accumulate (str_prefix = "Evaluation " ))
149+ wandb .log (metrics .accumulate (str_prefix = "Validation " ))
157150 metrics .reset ()
158151
159152 wandb .log (
160153 {
161154 "Epoch" : epoch ,
162155 "Train loss" : np .mean (trainingloss ),
163- "Evaluation Loss " : np .mean (evalloss ),
156+ "Validation loss " : np .mean (valloss ),
164157 }
165158 )
166159
160+ testloss = []
161+ model .eval ()
162+ with th .no_grad ():
163+ for x , y in tqdm (testloader , desc = "Testing" ):
164+ x , y = x .to (device ), y .to (device )
165+ logits = model .forward (x )
166+ loss = criterion (logits , y )
167+ testloss .append (loss .item ())
168+
169+ preds = th .argmax (logits , dim = 1 )
170+ metrics (y , preds )
171+
172+ wandb .log (metrics .accumulate (str_prefix = "Test " ))
173+ metrics .reset ()
174+ wandb .log ({"Test loss" : np .mean (testloss )})
175+
167176
168177if __name__ == "__main__" :
169178 main ()
0 commit comments