11from __future__ import division
2+ from __future__ import print_function
23import os
34import time
45import math
def conv_out_size_same(size, stride):
    """Return the output length of a SAME-padded strided convolution.

    Equivalent to ceil(size / stride); computed through float division
    so the result is correct under both Python 2 and Python 3.
    """
    ratio = float(size) / float(stride)
    return int(math.ceil(ratio))
1516
def gen_random(mode, size):
    """Sample a random array of the given shape from a named distribution.

    Args:
      mode: distribution name — 'normal01' (standard normal),
        'uniform_signed' (uniform on [-1, 1)), or
        'uniform_unsigned' (uniform on [0, 1)).
      size: shape tuple/list forwarded to the numpy sampler.

    Returns:
      np.ndarray with shape `size`.

    Raises:
      ValueError: if `mode` is not one of the recognized names.
        (Previously an unknown mode silently returned None, which only
        failed later at the call site with an unrelated error.)
    """
    if mode == 'normal01':
        return np.random.normal(0, 1, size=size)
    if mode == 'uniform_signed':
        return np.random.uniform(-1, 1, size=size)
    if mode == 'uniform_unsigned':
        return np.random.uniform(0, 1, size=size)
    raise ValueError("unknown random mode: {!r}".format(mode))
22+
1623class DCGAN (object ):
1724 def __init__ (self , sess , input_height = 108 , input_width = 108 , crop = True ,
1825 batch_size = 64 , sample_num = 64 , output_height = 64 , output_width = 64 ,
1926 y_dim = None , z_dim = 100 , gf_dim = 64 , df_dim = 64 ,
2027 gfc_dim = 1024 , dfc_dim = 1024 , c_dim = 3 , dataset_name = 'default' ,
21- input_fname_pattern = '*.jpg' , checkpoint_dir = None , sample_dir = None , data_dir = './data' ):
28+ max_to_keep = 1 ,
29+ input_fname_pattern = '*.jpg' , checkpoint_dir = 'ckpts' , sample_dir = 'samples' , out_dir = './out' , data_dir = './data' ):
2230 """
2331
2432 Args:
@@ -70,6 +78,8 @@ def __init__(self, sess, input_height=108, input_width=108, crop=True,
7078 self .input_fname_pattern = input_fname_pattern
7179 self .checkpoint_dir = checkpoint_dir
7280 self .data_dir = data_dir
81+ self .out_dir = out_dir
82+ self .max_to_keep = max_to_keep
7383
7484 if self .dataset_name == 'mnist' :
7585 self .data_X , self .data_y = self .load_mnist ()
@@ -148,7 +158,7 @@ def sigmoid_cross_entropy_with_logits(x, y):
148158 self .d_vars = [var for var in t_vars if 'd_' in var .name ]
149159 self .g_vars = [var for var in t_vars if 'g_' in var .name ]
150160
151- self .saver = tf .train .Saver ()
161+ self .saver = tf .train .Saver (max_to_keep = self . max_to_keep )
152162
153163 def train (self , config ):
154164 d_optim = tf .train .AdamOptimizer (config .learning_rate , beta1 = config .beta1 ) \
@@ -160,13 +170,15 @@ def train(self, config):
160170 except :
161171 tf .initialize_all_variables ().run ()
162172
163- self .g_sum = merge_summary ([self .z_sum , self .d__sum ,
164- self .G_sum , self .d_loss_fake_sum , self .g_loss_sum ])
173+ if config .G_img_sum :
174+ self .g_sum = merge_summary ([self .z_sum , self .d__sum , self .G_sum , self .d_loss_fake_sum , self .g_loss_sum ])
175+ else :
176+ self .g_sum = merge_summary ([self .z_sum , self .d__sum , self .d_loss_fake_sum , self .g_loss_sum ])
165177 self .d_sum = merge_summary (
166178 [self .z_sum , self .d_sum , self .d_loss_real_sum , self .d_loss_sum ])
167- self .writer = SummaryWriter ("./ logs" , self .sess .graph )
179+ self .writer = SummaryWriter (os . path . join ( self . out_dir , " logs") , self .sess .graph )
168180
169- sample_z = np . random . uniform ( - 1 , 1 , size = (self .sample_num , self .z_dim ))
181+ sample_z = gen_random ( config . z_dist , size = (self .sample_num , self .z_dim ))
170182
171183 if config .dataset == 'mnist' :
172184 sample_inputs = self .data_X [0 :self .sample_num ]
@@ -223,7 +235,7 @@ def train(self, config):
223235 else :
224236 batch_images = np .array (batch ).astype (np .float32 )
225237
226- batch_z = np . random . uniform ( - 1 , 1 , [config .batch_size , self .z_dim ]) \
238+ batch_z = gen_random ( config . z_dist , size = [config .batch_size , self .z_dim ]) \
227239 .astype (np .float32 )
228240
229241 if config .dataset == 'mnist' :
@@ -281,12 +293,11 @@ def train(self, config):
281293 errD_real = self .d_loss_real .eval ({ self .inputs : batch_images })
282294 errG = self .g_loss .eval ({self .z : batch_z })
283295
284- counter += 1
285- print ("Epoch: [%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
286- % (epoch , config .epoch , idx , batch_idxs ,
296+ print ("[%8d Epoch:[%2d/%2d] [%4d/%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f" \
297+ % (counter , epoch , config .epoch , idx , batch_idxs ,
287298 time .time () - start_time , errD_fake + errD_real , errG ))
288299
289- if np .mod (counter , 100 ) == 1 :
300+ if np .mod (counter , config . sample_freq ) == 0 :
290301 if config .dataset == 'mnist' :
291302 samples , d_loss , g_loss = self .sess .run (
292303 [self .sampler , self .d_loss , self .g_loss ],
@@ -297,7 +308,7 @@ def train(self, config):
297308 }
298309 )
299310 save_images (samples , image_manifold_size (samples .shape [0 ]),
300- './{}/train_{:02d}_{:04d} .png' .format (config .sample_dir , epoch , idx ))
311+ './{}/train_{:08d} .png' .format (config .sample_dir , counter ))
301312 print ("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss , g_loss ))
302313 else :
303314 try :
@@ -309,14 +320,16 @@ def train(self, config):
309320 },
310321 )
311322 save_images (samples , image_manifold_size (samples .shape [0 ]),
312- './{}/train_{:02d}_{:04d} .png' .format (config .sample_dir , epoch , idx ))
323+ './{}/train_{:08d} .png' .format (config .sample_dir , counter ))
313324 print ("[Sample] d_loss: %.8f, g_loss: %.8f" % (d_loss , g_loss ))
314325 except :
315326 print ("one pic error!..." )
316327
317- if np .mod (counter , 500 ) == 2 :
328+ if np .mod (counter , config . ckpt_freq ) == 0 :
318329 self .save (config .checkpoint_dir , counter )
319-
330+
331+ counter += 1
332+
320333 def discriminator (self , image , y = None , reuse = False ):
321334 with tf .variable_scope ("discriminator" ) as scope :
322335 if reuse :
@@ -501,28 +514,39 @@ def model_dir(self):
501514 return "{}_{}_{}_{}" .format (
502515 self .dataset_name , self .batch_size ,
503516 self .output_height , self .output_width )
504-
505- def save (self , checkpoint_dir , step ):
506- model_name = "DCGAN.model"
507- checkpoint_dir = os .path .join (checkpoint_dir , self .model_dir )
508517
518+ def save (self , checkpoint_dir , step , filename = 'model' , ckpt = True , frozen = False ):
519+ # model_name = "DCGAN.model"
520+ # checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
521+
522+ filename += '.b' + str (self .batch_size )
509523 if not os .path .exists (checkpoint_dir ):
510524 os .makedirs (checkpoint_dir )
511525
512- self .saver .save (self .sess ,
513- os .path .join (checkpoint_dir , model_name ),
514- global_step = step )
526+ if ckpt :
527+ self .saver .save (self .sess ,
528+ os .path .join (checkpoint_dir , filename ),
529+ global_step = step )
530+
531+ if frozen :
532+ tf .train .write_graph (
533+ tf .graph_util .convert_variables_to_constants (self .sess , self .sess .graph_def , ["generator_1/Tanh" ]),
534+ checkpoint_dir ,
535+ '{}-{:06d}_frz.pb' .format (filename , step ),
536+ as_text = False )
515537
516538 def load (self , checkpoint_dir ):
517- import re
518- print (" [*] Reading checkpoints..." )
519- checkpoint_dir = os .path .join (checkpoint_dir , self .model_dir )
539+ #import re
540+ print (" [*] Reading checkpoints..." , checkpoint_dir )
541+ # checkpoint_dir = os.path.join(checkpoint_dir, self.model_dir)
542+ # print(" ->", checkpoint_dir)
520543
521544 ckpt = tf .train .get_checkpoint_state (checkpoint_dir )
522545 if ckpt and ckpt .model_checkpoint_path :
523546 ckpt_name = os .path .basename (ckpt .model_checkpoint_path )
524547 self .saver .restore (self .sess , os .path .join (checkpoint_dir , ckpt_name ))
525- counter = int (next (re .finditer ("(\d+)(?!.*\d)" ,ckpt_name )).group (0 ))
548+ #counter = int(next(re.finditer("(\d+)(?!.*\d)",ckpt_name)).group(0))
549+ counter = int (ckpt_name .split ('-' )[- 1 ])
526550 print (" [*] Success to read {}" .format (ckpt_name ))
527551 return True , counter
528552 else :
0 commit comments