Skip to content

Commit 438b2c2

Browse files
committed
update ae
1 parent 63fa0b4 commit 438b2c2

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

examples/auto_encoder/train.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ def get_argparser():
3737

3838

3939
def main():
40+
if not os.path.exists('results'):
41+
os.mkdir('results')
42+
4043
opts = get_argparser().parse_args()
4144

4245
# dataset
@@ -58,7 +61,7 @@ def main():
5861

5962
val_loader = DataLoader(
6063
ImageDataset(root='datasets/CLIC/valid', transform=val_transform),
61-
batch_size=opts.batch_size, shuffle=False, num_workers=0)
64+
batch_size=1, shuffle=False, num_workers=0)
6265

6366
print("Train set: %d, Val set: %d" %
6467
(len(train_loader.dataset), len(val_loader.dataset)))
@@ -105,7 +108,7 @@ def main():
105108
# ===== Validation =====
106109
print("Val...")
107110
best_score = 0.0
108-
cur_score = test(opts, model, val_loader)
111+
cur_score = test(opts, model, val_loader, cur_epoch)
109112
print("%s = %.6f" % (opts.loss_type, cur_score))
110113
# ===== Save Best Model =====
111114
if cur_score > best_score: # save best model
@@ -114,7 +117,10 @@ def main():
114117
print("Best model saved as best_model.pt")
115118

116119

117-
def test(opts, model, val_loader):
120+
def test(opts, model, val_loader, epoch):
121+
save_dir = os.path.join('results', 'epoch_%d' % epoch)
122+
if not os.path.exists(save_dir):
123+
os.mkdir(save_dir)
118124
model.eval()
119125
cur_score = 0.0
120126

@@ -124,11 +130,10 @@ def test(opts, model, val_loader):
124130
for i, (images, ) in enumerate(val_loader):
125131
outputs = model(images)
126132
# save the first reconstructed image
127-
if i == 20:
128-
Image.fromarray((outputs*255).squeeze(0).detach().numpy().astype(
129-
'uint8').transpose(1, 2, 0)).save('recons_%s.png' % (opts.loss_type))
130133
cur_score += metric(outputs, images, data_range=1.0)
134+
Image.fromarray((outputs*255).squeeze(0).detach().numpy().astype('uint8').transpose(1, 2, 0)).save(os.path.join(save_dir, 'recons_%s_%d.png' % (opts.loss_type, i)))
131135
cur_score /= len(val_loader.dataset)
136+
132137
return cur_score
133138

134139

0 commit comments

Comments
 (0)