11import os
22import scipy .misc
33import numpy as np
4+ import json
45
56from model import DCGAN
6- from utils import pp , visualize , to_json , show_all_variables
7+ from utils import pp , visualize , to_json , show_all_variables , expand_path , timestamp
78
89import tensorflow as tf
910
1920flags .DEFINE_integer ("output_width" , None , "The size of the output images to produce. If None, same value as output_height [None]" )
2021flags .DEFINE_string ("dataset" , "celebA" , "The name of dataset [celebA, mnist, lsun]" )
2122flags .DEFINE_string ("input_fname_pattern" , "*.jpg" , "Glob pattern of filename of input images [*]" )
22- flags .DEFINE_string ("checkpoint_dir" , "checkpoint" , "Directory name to save the checkpoints [checkpoint]" )
23- flags .DEFINE_string ("data_dir" , "./data" , "Root directory of dataset [data]" )
24- flags .DEFINE_string ("sample_dir" , "samples" , "Directory name to save the image samples [samples]" )
23+ flags .DEFINE_string ("data_dir" , "$HOME/data" , "path to datasets [$HOME/data]" )
24+ flags .DEFINE_string ("out_dir" , "$HOME/out" , "Root directory for outputs [$HOME/out]" )
25+ flags .DEFINE_string ("out_name" , "" , "Folder (under out_root_dir) for all outputs. Generated automatically if left blank []" )
26+ flags .DEFINE_string ("checkpoint_dir" , "checkpoint" , "Folder (under out_root_dir/out_name) to save checkpoints [checkpoint]" )
27+ flags .DEFINE_string ("sample_dir" , "samples" , "Folder (under out_root_dir/out_name) to save samples [samples]" )
2528flags .DEFINE_boolean ("train" , False , "True for training, False for testing [False]" )
2629flags .DEFINE_boolean ("crop" , False , "True for training, False for testing [False]" )
2730flags .DEFINE_boolean ("visualize" , False , "True for visualizing, False for nothing [False]" )
3740
3841def main (_ ):
3942 pp .pprint (flags .FLAGS .__flags )
40-
41- if FLAGS .input_width is None :
42- FLAGS .input_width = FLAGS .input_height
43- if FLAGS .output_width is None :
44- FLAGS .output_width = FLAGS .output_height
45-
46- if not os .path .exists (FLAGS .checkpoint_dir ):
47- os .makedirs (FLAGS .checkpoint_dir )
48- if not os .path .exists (FLAGS .sample_dir ):
49- os .makedirs (FLAGS .sample_dir )
43+
44+ # expand user name and environment variables
45+ FLAGS .data_dir = expand_path (FLAGS .data_dir )
46+ FLAGS .out_dir = expand_path (FLAGS .out_dir )
47+ FLAGS .out_name = expand_path (FLAGS .out_name )
48+ FLAGS .checkpoint_dir = expand_path (FLAGS .checkpoint_dir )
49+ FLAGS .sample_dir = expand_path (FLAGS .sample_dir )
50+
51+ if FLAGS .output_height is None : FLAGS .output_height = FLAGS .input_height
52+ if FLAGS .input_width is None : FLAGS .input_width = FLAGS .input_height
53+ if FLAGS .output_width is None : FLAGS .output_width = FLAGS .output_height
54+
55+ # output folders
56+ if FLAGS .out_name == "" :
57+ FLAGS .out_name = '{} - {} - {}' .format (timestamp (), FLAGS .data_dir .split ('/' )[- 1 ], FLAGS .dataset ) # penultimate folder of path
58+ if FLAGS .train :
59+ FLAGS .out_name += ' - x{}.z{}.{}.y{}.b{}' .format (FLAGS .input_width , FLAGS .z_dim , FLAGS .z_dist , FLAGS .output_width , FLAGS .batch_size )
60+
61+ FLAGS .out_dir = os .path .join (FLAGS .out_dir , FLAGS .out_name )
62+ FLAGS .checkpoint_dir = os .path .join (FLAGS .out_dir , FLAGS .checkpoint_dir )
63+ FLAGS .sample_dir = os .path .join (FLAGS .out_dir , FLAGS .sample_dir )
64+
65+ if not os .path .exists (FLAGS .checkpoint_dir ): os .makedirs (FLAGS .checkpoint_dir )
66+ if not os .path .exists (FLAGS .sample_dir ): os .makedirs (FLAGS .sample_dir )
67+
68+ with open (os .path .join (FLAGS .out_dir , 'FLAGS.json' ), 'w' ) as f :
69+ flags_dict = {k :FLAGS [k ].value for k in FLAGS }
70+ json .dump (flags_dict , f , indent = 4 , sort_keys = True , ensure_ascii = False )
71+
5072
5173 #gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.333)
5274 run_config = tf .ConfigProto ()
@@ -70,6 +92,7 @@ def main(_):
7092 checkpoint_dir = FLAGS .checkpoint_dir ,
7193 sample_dir = FLAGS .sample_dir ,
7294 data_dir = FLAGS .data_dir ,
95+ out_dir = FLAGS .out_dir ,
7396 max_to_keep = FLAGS .max_to_keep )
7497 else :
7598 dcgan = DCGAN (
@@ -87,6 +110,7 @@ def main(_):
87110 checkpoint_dir = FLAGS .checkpoint_dir ,
88111 sample_dir = FLAGS .sample_dir ,
89112 data_dir = FLAGS .data_dir ,
113+ out_dir = FLAGS .out_dir ,
90114 max_to_keep = FLAGS .max_to_keep )
91115
92116 show_all_variables ()
0 commit comments