Skip to content

Commit 3dd932f

Browse files
committed
select sample frequency with --sample_freq and checkpoint save frequency with --ckpt_freq
1 parent c5b6061 commit 3dd932f

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

main.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
flags.DEFINE_boolean("export", False, "True for exporting with new batch size")
2929
flags.DEFINE_boolean("freeze", False, "True for exporting with new batch size")
3030
flags.DEFINE_integer("max_to_keep", 1, "maximum number of checkpoints to keep")
31+
flags.DEFINE_integer("sample_freq", 200, "sample every this many iterations")
32+
flags.DEFINE_integer("ckpt_freq", 200, "save checkpoint every this many iterations")
3133
flags.DEFINE_integer("z_dim", 100, "dimensions of z")
3234
flags.DEFINE_string("z_dist", "uniform_signed", "'normal01' or 'uniform_unsigned' or uniform_signed")
3335
#flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]")

model.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -290,12 +290,11 @@ def train(self, config):
290290
errD_real = self.d_loss_real.eval({ self.inputs: batch_images })
291291
errG = self.g_loss.eval({self.z: batch_z})
292292

293-
counter += 1
294-
print("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
295-
% (epoch, config.epoch, idx, batch_idxs,
293+
print("[%8d Epoch:[%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
294+
% (counter, epoch, config.epoch, idx, batch_idxs,
296295
time.time() - start_time, errD_fake+errD_real, errG))
297296

298-
if np.mod(counter, 100) == 1:
297+
if np.mod(counter, config.sample_freq) == 0:
299298
if config.dataset == 'mnist':
300299
samples, d_loss, g_loss = self.sess.run(
301300
[self.sampler, self.d_loss, self.g_loss],
@@ -306,7 +305,7 @@ def train(self, config):
306305
}
307306
)
308307
save_images(samples, image_manifold_size(samples.shape[0]),
309-
'./{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
308+
'./{}/train_{:08d}.png'.format(config.sample_dir, counter))
310309
print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))
311310
else:
312311
try:
@@ -318,14 +317,16 @@ def train(self, config):
318317
},
319318
)
320319
save_images(samples, image_manifold_size(samples.shape[0]),
321-
'./{}/train_{:02d}_{:04d}.png'.format(config.sample_dir, epoch, idx))
320+
'./{}/train_{:08d}.png'.format(config.sample_dir, counter))
322321
print("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss, g_loss))
323322
except:
324323
print("one pic error!...")
325324

326-
if np.mod(counter, 500) == 2:
325+
if np.mod(counter, config.ckpt_freq) == 0:
327326
self.save(config.checkpoint_dir, counter)
328-
327+
328+
counter += 1
329+
329330
def discriminator(self, image, y=None, reuse=False):
330331
with tf.variable_scope("discriminator") as scope:
331332
if reuse:

0 commit comments

Comments
 (0)