Skip to content

Commit c3b734b

Browse files
committed
output management:
- set root output directory with --out_dir. - set output name with --out_name (or leave blank for it to be generated automatically from hyperparams) - checkpoints, samples and log saved under out_dir/out_name - environment variables expanded - save json for FLAGS for reference - auto set output size to input size if None
1 parent 3dd932f commit c3b734b

File tree

3 files changed

+59
-22
lines changed

3 files changed

+59
-22
lines changed

main.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import os
22
import scipy.misc
33
import numpy as np
4+
import json
45

56
from model import DCGAN
6-
from utils import pp, visualize, to_json, show_all_variables
7+
from utils import pp, visualize, to_json, show_all_variables, expand_path, timestamp
78

89
import tensorflow as tf
910

@@ -19,9 +20,11 @@
1920
flags.DEFINE_integer("output_width", None, "The size of the output images to produce. If None, same value as output_height [None]")
2021
flags.DEFINE_string("dataset", "celebA", "The name of dataset [celebA, mnist, lsun]")
2122
flags.DEFINE_string("input_fname_pattern", "*.jpg", "Glob pattern of filename of input images [*]")
22-
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Directory name to save the checkpoints [checkpoint]")
23-
flags.DEFINE_string("data_dir", "./data", "Root directory of dataset [data]")
24-
flags.DEFINE_string("sample_dir", "samples", "Directory name to save the image samples [samples]")
23+
flags.DEFINE_string("data_dir", "$HOME/data", "path to datasets [$HOME/data]")
24+
flags.DEFINE_string("out_dir", "$HOME/out", "Root directory for outputs [$HOME/out]")
25+
flags.DEFINE_string("out_name", "", "Folder (under out_root_dir) for all outputs. Generated automatically if left blank []")
26+
flags.DEFINE_string("checkpoint_dir", "checkpoint", "Folder (under out_root_dir/out_name) to save checkpoints [checkpoint]")
27+
flags.DEFINE_string("sample_dir", "samples", "Folder (under out_root_dir/out_name) to save samples [samples]")
2528
flags.DEFINE_boolean("train", False, "True for training, False for testing [False]")
2629
flags.DEFINE_boolean("crop", False, "True for training, False for testing [False]")
2730
flags.DEFINE_boolean("visualize", False, "True for visualizing, False for nothing [False]")
@@ -37,16 +40,35 @@
3740

3841
def main(_):
3942
pp.pprint(flags.FLAGS.__flags)
40-
41-
if FLAGS.input_width is None:
42-
FLAGS.input_width = FLAGS.input_height
43-
if FLAGS.output_width is None:
44-
FLAGS.output_width = FLAGS.output_height
45-
46-
if not os.path.exists(FLAGS.checkpoint_dir):
47-
os.makedirs(FLAGS.checkpoint_dir)
48-
if not os.path.exists(FLAGS.sample_dir):
49-
os.makedirs(FLAGS.sample_dir)
43+
44+
# expand user name and environment variables
45+
FLAGS.data_dir = expand_path(FLAGS.data_dir)
46+
FLAGS.out_dir = expand_path(FLAGS.out_dir)
47+
FLAGS.out_name = expand_path(FLAGS.out_name)
48+
FLAGS.checkpoint_dir = expand_path(FLAGS.checkpoint_dir)
49+
FLAGS.sample_dir = expand_path(FLAGS.sample_dir)
50+
51+
if FLAGS.output_height is None: FLAGS.output_height = FLAGS.input_height
52+
if FLAGS.input_width is None: FLAGS.input_width = FLAGS.input_height
53+
if FLAGS.output_width is None: FLAGS.output_width = FLAGS.output_height
54+
55+
# output folders
56+
if FLAGS.out_name == "":
57+
FLAGS.out_name = '{} - {} - {}'.format(timestamp(), FLAGS.data_dir.split('/')[-1], FLAGS.dataset) # penultimate folder of path
58+
if FLAGS.train:
59+
FLAGS.out_name += ' - x{}.z{}.{}.y{}.b{}'.format(FLAGS.input_width, FLAGS.z_dim, FLAGS.z_dist, FLAGS.output_width, FLAGS.batch_size)
60+
61+
FLAGS.out_dir = os.path.join(FLAGS.out_dir, FLAGS.out_name)
62+
FLAGS.checkpoint_dir = os.path.join(FLAGS.out_dir, FLAGS.checkpoint_dir)
63+
FLAGS.sample_dir = os.path.join(FLAGS.out_dir, FLAGS.sample_dir)
64+
65+
if not os.path.exists(FLAGS.checkpoint_dir): os.makedirs(FLAGS.checkpoint_dir)
66+
if not os.path.exists(FLAGS.sample_dir): os.makedirs(FLAGS.sample_dir)
67+
68+
with open(os.path.join(FLAGS.out_dir, 'FLAGS.json'), 'w') as f:
69+
flags_dict = {k:FLAGS[k].value for k in FLAGS}
70+
json.dump(flags_dict, f, indent=4, sort_keys=True, ensure_ascii=False)
71+
5072

5173
#gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
5274
run_config = tf.ConfigProto()
@@ -70,6 +92,7 @@ def main(_):
7092
checkpoint_dir=FLAGS.checkpoint_dir,
7193
sample_dir=FLAGS.sample_dir,
7294
data_dir=FLAGS.data_dir,
95+
out_dir=FLAGS.out_dir,
7396
max_to_keep=FLAGS.max_to_keep)
7497
else:
7598
dcgan = DCGAN(
@@ -87,6 +110,7 @@ def main(_):
87110
checkpoint_dir=FLAGS.checkpoint_dir,
88111
sample_dir=FLAGS.sample_dir,
89112
data_dir=FLAGS.data_dir,
113+
out_dir=FLAGS.out_dir,
90114
max_to_keep=FLAGS.max_to_keep)
91115

92116
show_all_variables()

model.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
2626
y_dim=None, z_dim=100, gf_dim=64, df_dim=64,
2727
gfc_dim=1024, dfc_dim=1024, c_dim=3, dataset_name='default',
2828
max_to_keep=1,
29-
input_fname_pattern='*.jpg', checkpoint_dir=None, sample_dir=None, data_dir='./data'):
29+
input_fname_pattern='*.jpg', checkpoint_dir='ckpts', sample_dir='samples', out_dir='./out', data_dir='./data'):
3030
"""
3131
3232
Args:
@@ -78,6 +78,7 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
7878
self.input_fname_pattern = input_fname_pattern
7979
self.checkpoint_dir = checkpoint_dir
8080
self.data_dir = data_dir
81+
self.out_dir = out_dir
8182
self.max_to_keep = max_to_keep
8283

8384
if self.dataset_name == 'mnist':
@@ -173,7 +174,7 @@ def train(self, config):
173174
self.G_sum, self.d_loss_fake_sum, self.g_loss_sum])
174175
self.d_sum = merge_summary(
175176
[self.z_sum, self.d_sum, self.d_loss_real_sum, self.d_loss_sum])
176-
self.writer = SummaryWriter("./logs", self.sess.graph)
177+
self.writer = SummaryWriter(os.path.join(self.out_dir, "logs"), self.sess.graph)
177178

178179
sample_z = gen_random(config.z_dist, size=(self.sample_num , self.z_dim))
179180

utils.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
import pprint
99
import scipy.misc
1010
import numpy as np
11+
import os
12+
import time
13+
import datetime
1114
from time import gmtime, strftime
1215
from six.moves import xrange
1316

@@ -18,6 +21,15 @@
1821

1922
get_stddev = lambda x, k_h, k_w: 1/math.sqrt(k_w*k_h*x.get_shape()[-1])
2023

24+
25+
def expand_path(path):
26+
return os.path.expanduser(os.path.expandvars(path))
27+
28+
def timestamp(s='%Y%m%d.%H%M%S', ts=None):
29+
if not ts: ts = time.time()
30+
st = datetime.datetime.fromtimestamp(ts).strftime(s)
31+
return st
32+
2133
def show_all_variables():
2234
model_vars = tf.trainable_variables()
2335
slim.model_analyzer.analyze_vars(model_vars, print_info=True)
@@ -169,12 +181,12 @@ def make_frame(t):
169181
clip = mpy.VideoClip(make_frame, duration=duration)
170182
clip.write_gif(fname, fps = len(images) / duration)
171183

172-
def visualize(sess, dcgan, config, option):
184+
def visualize(sess, dcgan, config, option, sample_dir='samples'):
173185
image_frame_dim = int(math.ceil(config.batch_size**.5))
174186
if option == 0:
175187
z_sample = np.random.uniform(-0.5, 0.5, size=(config.batch_size, dcgan.z_dim))
176188
samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
177-
save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime()))
189+
save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime() )))
178190
elif option == 1:
179191
values = np.arange(0, 1, 1./config.batch_size)
180192
for idx in xrange(dcgan.z_dim):
@@ -192,7 +204,7 @@ def visualize(sess, dcgan, config, option):
192204
else:
193205
samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
194206

195-
save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_arange_%s.png' % (idx))
207+
save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_arange_%s.png' % (idx)))
196208
elif option == 2:
197209
values = np.arange(0, 1, 1./config.batch_size)
198210
for idx in [random.randint(0, dcgan.z_dim - 1) for _ in xrange(dcgan.z_dim)]:
@@ -215,7 +227,7 @@ def visualize(sess, dcgan, config, option):
215227
try:
216228
make_gif(samples, './samples/test_gif_%s.gif' % (idx))
217229
except:
218-
save_images(samples, [image_frame_dim, image_frame_dim], './samples/test_%s.png' % strftime("%Y-%m-%d-%H-%M-%S", gmtime()))
230+
save_images(samples, [image_frame_dim, image_frame_dim], os.path.join(sample_dir, 'test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime() )))
219231
elif option == 3:
220232
values = np.arange(0, 1, 1./config.batch_size)
221233
for idx in xrange(dcgan.z_dim):
@@ -225,7 +237,7 @@ def visualize(sess, dcgan, config, option):
225237
z[idx] = values[kdx]
226238

227239
samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
228-
make_gif(samples, './samples/test_gif_%s.gif' % (idx))
240+
make_gif(samples, os.path.join(sample_dir, 'test_gif_%s.gif' % (idx)))
229241
elif option == 4:
230242
image_set = []
231243
values = np.arange(0, 1, 1./config.batch_size)
@@ -236,7 +248,7 @@ def visualize(sess, dcgan, config, option):
236248
for kdx, z in enumerate(z_sample): z[idx] = values[kdx]
237249

238250
image_set.append(sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample}))
239-
make_gif(image_set[-1], './samples/test_gif_%s.gif' % (idx))
251+
make_gif(image_set[-1], os.path.join(sample_dir, 'test_gif_%s.gif' % (idx)))
240252

241253
new_image_set = [merge(np.array([images[idx] for images in image_set]), [10, 10]) \
242254
for idx in range(64) + range(63, -1, -1)]

0 commit comments

Comments
 (0)