Skip to content

Commit c5b6061

Browse files
committed
set maximum number of checkpoints to keep with --max_to_keep
1 parent ec8d0e5 commit c5b6061

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

main.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
2828
flags.DEFINE_boolean("export", False, "True for exporting with new batch size")
2929
flags.DEFINE_boolean("freeze", False, "True for exporting with new batch size")
30+
flags.DEFINE_integer("max_to_keep", 1, "maximum number of checkpoints to keep")
3031
flags.DEFINE_integer("z_dim", 100, "dimensions of z")
3132
flags.DEFINE_string("z_dist", "uniform_signed", "'normal01' or 'uniform_unsigned' or uniform_signed")
3233
#flags.DEFINE_integer("generate_test_images", 100, "Number of images to generate during test. [100]")
@@ -66,7 +67,8 @@ def main(_):
6667
crop=FLAGS.crop,
6768
checkpoint_dir=FLAGS.checkpoint_dir,
6869
sample_dir=FLAGS.sample_dir,
69-
data_dir=FLAGS.data_dir)
70+
data_dir=FLAGS.data_dir,
71+
max_to_keep=FLAGS.max_to_keep)
7072
else:
7173
dcgan = DCGAN(
7274
sess,
@@ -82,7 +84,8 @@ def main(_):
8284
crop=FLAGS.crop,
8385
checkpoint_dir=FLAGS.checkpoint_dir,
8486
sample_dir=FLAGS.sample_dir,
85-
data_dir=FLAGS.data_dir)
87+
data_dir=FLAGS.data_dir,
88+
max_to_keep=FLAGS.max_to_keep)
8689

8790
show_all_variables()
8891

model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
2525
batch_size=64, sample_num = 64, output_height=64, output_width=64,
2626
y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
2727
gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
28+
max_to_keep=1,
2829
input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None, data_dir='./data'):
2930
"""
3031
@@ -77,6 +78,7 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
7778
self.input_fname_pattern = input_fname_pattern
7879
self.checkpoint_dir = checkpoint_dir
7980
self.data_dir = data_dir
81+
self.max_to_keep = max_to_keep
8082

8183
if self.dataset_name == 'mnist':
8284
self.data_X, self.data_y = self.load_mnist()
@@ -155,7 +157,7 @@ def sigmoid_cross_entropy_with_logits(x, y):
155157
self.d_vars = [var for var in t_vars if 'd_' in var.name]
156158
self.g_vars = [var for var in t_vars if 'g_' in var.name]
157159

158-
self.saver = tf.train.Saver()
160+
self.saver = tf.train.Saver(max_to_keep=self.max_to_keep)
159161

160162
def train(self, config):
161163
d_optim = tf.train.AdamOptimizer(config.learning_rate, beta1=config.beta1) \

0 commit comments

Comments
 (0)