@@ -43,16 +43,27 @@ def main():
4343 augmentations = transforms .Compose ([transforms .ToTensor ()])
4444
4545 # Dataset
46+ assert args .validation_split_percentage < 1.0 and args .validation_split_percentage > 0 , "Validation split should be in interval (0,1)"
4647 traindata = load_data (
4748 args .dataset ,
48- train = True ,
49+ split = "train" ,
50+ split_percentage = args .validation_split_percentage ,
4951 data_path = args .datafolder ,
5052 download = args .download_data ,
5153 transform = augmentations ,
5254 )
5355 validata = load_data (
5456 args .dataset ,
55- train = False ,
57+ split = "validation" ,
58+ split_percentage = args .validation_split_percentage ,
59+ data_path = args .datafolder ,
60+ download = args .download_data ,
61+ transform = augmentations ,
62+ )
63+ testdata = load_data (
64+ args .dataset ,
65+ split = "test" ,
66+ split_percentage = args .validation_split_percentage ,
5667 data_path = args .datafolder ,
5768 download = args .download_data ,
5869 transform = augmentations ,
@@ -83,6 +94,9 @@ def main():
8394 valiloader = DataLoader (
8495 validata , batch_size = args .batchsize , shuffle = False , pin_memory = True
8596 )
97+ testloader = DataLoader (
98+ testdata , batch_size = args .batchsize , shuffle = False , pin_memory = True
99+ )
86100
87101 criterion = nn .CrossEntropyLoss ()
88102 optimizer = th .optim .Adam (model .parameters (), lr = args .learning_rate )
@@ -140,30 +154,45 @@ def main():
140154 wandb .log (metrics .accumulate (str_prefix = "Train " ))
141155 metrics .reset ()
142156
143- evalloss = []
144- # Eval loop start
157+ valloss = []
158+ # Validation loop start
145159 model .eval ()
146160 with th .no_grad ():
147161 for x , y in tqdm (valiloader , desc = "Validation" ):
148162 x , y = x .to (device ), y .to (device )
149163 logits = model .forward (x )
150164 loss = criterion (logits , y )
151- evalloss .append (loss .item ())
165+ valloss .append (loss .item ())
152166
153167 preds = th .argmax (logits , dim = 1 )
154168 metrics (y , preds )
155169
156- wandb .log (metrics .accumulate (str_prefix = "Evaluation " ))
170+ wandb .log (metrics .accumulate (str_prefix = "Validation " ))
157171 metrics .reset ()
158172
159173 wandb .log (
160174 {
161175 "Epoch" : epoch ,
162176 "Train loss" : np .mean (trainingloss ),
163- "Evaluation Loss " : np .mean (evalloss ),
177+ "Validation loss " : np .mean (valloss ),
164178 }
165179 )
180+
181+ testloss = []
182+ model .eval ()
183+ with th .no_grad ():
184+ for x , y in tqdm (testloader , desc = "Testing" ):
185+ x , y = x .to (device ), y .to (device )
186+ logits = model .forward (x )
187+ loss = criterion (logits , y )
188+ testloss .append (loss .item ())
189+
190+ preds = th .argmax (logits , dim = 1 )
191+ metrics (y , preds )
166192
193+ wandb .log (metrics .accumulate (str_prefix = "Test " ))
194+ metrics .reset ()
195+ wandb .log ({"Test loss" : np .mean (testloss )})
167196
168197if __name__ == "__main__" :
169198 main ()
0 commit comments