@@ -37,6 +37,9 @@ def get_argparser():
3737
3838
3939def main ():
40+ if not os .path .exists ('results' ):
41+ os .mkdir ('results' )
42+
4043 opts = get_argparser ().parse_args ()
4144
4245 # dataset
@@ -58,7 +61,7 @@ def main():
5861
5962 val_loader = DataLoader (
6063 ImageDataset (root = 'datasets/CLIC/valid' , transform = val_transform ),
61- batch_size = opts . batch_size , shuffle = False , num_workers = 0 )
64+ batch_size = 1 , shuffle = False , num_workers = 0 )
6265
6366 print ("Train set: %d, Val set: %d" %
6467 (len (train_loader .dataset ), len (val_loader .dataset )))
@@ -105,7 +108,7 @@ def main():
105108 # ===== Validation =====
106109 print ("Val..." )
107110 best_score = 0.0
108- cur_score = test (opts , model , val_loader )
111+ cur_score = test (opts , model , val_loader , cur_epoch )
109112 print ("%s = %.6f" % (opts .loss_type , cur_score ))
110113 # ===== Save Best Model =====
111114 if cur_score > best_score : # save best model
@@ -114,7 +117,10 @@ def main():
114117 print ("Best model saved as best_model.pt" )
115118
116119
117- def test (opts , model , val_loader ):
120+ def test (opts , model , val_loader , epoch ):
121+ save_dir = os .path .join ('results' , 'epoch_%d' % epoch )
122+ if not os .path .exists (save_dir ):
123+ os .mkdir (save_dir )
118124 model .eval ()
119125 cur_score = 0.0
120126
@@ -124,11 +130,10 @@ def test(opts, model, val_loader):
124130 for i , (images , ) in enumerate (val_loader ):
125131 outputs = model (images )
126132 # save the first reconstructed image
127- if i == 20 :
128- Image .fromarray ((outputs * 255 ).squeeze (0 ).detach ().numpy ().astype (
129- 'uint8' ).transpose (1 , 2 , 0 )).save ('recons_%s.png' % (opts .loss_type ))
130133 cur_score += metric (outputs , images , data_range = 1.0 )
134+ Image .fromarray ((outputs * 255 ).squeeze (0 ).detach ().numpy ().astype ('uint8' ).transpose (1 , 2 , 0 )).save (os .path .join (save_dir , 'recons_%s_%d.png' % (opts .loss_type , i )))
131135 cur_score /= len (val_loader .dataset )
136+
132137 return cur_score
133138
134139
0 commit comments