# Command-line options (post-image of the diff; `add_arg` wraps the module's
# argparse parser, defined earlier in the file).  Groups: model selection,
# training schedule, data-loader cache sizing, and ADAM optimizer settings.
add_arg('--scales',          default=2,       type=int,   help='How many times to perform 2x upsampling.')
add_arg('--model',           default='small', type=str,   help='Name of the neural network to load/save.')
add_arg('--train',           default=False,   type=str,   help='File pattern to load for training.')
add_arg('--epochs',          default=10,      type=int,   help='Total number of iterations in training.')
add_arg('--epoch-size',      default=72,      type=int,   help='Number of batches trained in an epoch.')
add_arg('--save-every',      default=10,      type=int,   help='Save generator after every training epoch.')
add_arg('--batch-shape',     default=192,     type=int,   help='Resolution of images in training batch.')
add_arg('--batch-size',      default=15,      type=int,   help='Number of images per training batch.')
add_arg('--buffer-size',     default=1500,    type=int,   help='Total image fragments kept in cache.')
add_arg('--buffer-similar',  default=5,       type=int,   help='Fragments cached for each image loaded.')
add_arg('--learning-rate',   default=1E-4,    type=float, help='Parameter for the ADAM optimizer.')
add_arg('--learning-period', default=50,      type=int,   help='How often to decay the learning rate.')
add_arg('--learning-decay',  default=0.5,     type=float, help='How much to decay the learning rate.')
@@ -128,8 +129,10 @@ def __init__(self):
128129 self .data_ready = threading .Event ()
129130 self .data_copied = threading .Event ()
130131
131- self .resolution = args .batch_resolution
132- self .buffer = np .zeros ((args .buffer_size , 3 , self .resolution , self .resolution ), dtype = np .float32 )
132+ self .orig_shape , self .seed_shape = args .batch_shape , int (args .batch_shape / 2 ** args .scales )
133+
134+ self .orig_buffer = np .zeros ((args .buffer_size , 3 , self .orig_shape , self .orig_shape ), dtype = np .float32 )
135+ self .seed_buffer = np .zeros ((args .buffer_size , 3 , self .seed_shape , self .seed_shape ), dtype = np .float32 )
133136 self .files = glob .glob (args .train )
134137 if len (self .files ) == 0 :
135138 error ("There were no files found to train from searching for `{}`" .format (args .train ),
@@ -157,27 +160,31 @@ def run(self):
157160
158161 for _ in range (args .buffer_similar ):
159162 copy = img [:,::- 1 ] if random .choice ([True , False ]) else img
160- h = random .randint (0 , copy .shape [0 ] - self .resolution )
161- w = random .randint (0 , copy .shape [1 ] - self .resolution )
162- copy = copy [h :h + self .resolution , w :w + self .resolution ]
163+ h = random .randint (0 , copy .shape [0 ] - self .orig_shape )
164+ w = random .randint (0 , copy .shape [1 ] - self .orig_shape )
165+ copy = copy [h :h + self .orig_shape , w :w + self .orig_shape ]
163166
164167 while len (self .available ) == 0 :
165168 self .data_copied .wait ()
166169 self .data_copied .clear ()
167170
168171 i = self .available .pop ()
169- self .buffer [i ] = np .transpose (copy / 255.0 - 0.5 , (2 , 0 , 1 ))
172+ self .orig_buffer [i ] = np .transpose (copy / 255.0 - 0.5 , (2 , 0 , 1 ))
173+ seed = scipy .misc .imresize (copy , size = (self .seed_shape , self .seed_shape ), interp = 'bilinear' )
174+ self .seed_buffer [i ] = np .transpose (seed / 255.0 - 0.5 , (2 , 0 , 1 ))
170175 self .ready .add (i )
171176
172177 if len (self .ready ) >= args .batch_size :
173178 self .data_ready .set ()
174179
def copy(self, origs_out, seeds_out):
    """Copy one randomly-sampled training batch out of the shared cache.

    Blocks on `data_ready` until the loader thread has at least
    `args.batch_size` fragments cached, then fills `origs_out` with
    high-resolution fragments and `seeds_out` with their matching
    downscaled seeds (same index order), returns the sampled slots to
    the `available` pool, and signals the loader via `data_copied`.

    NOTE(review): this is a method of the loader-thread class whose header
    is outside this diff hunk; it mutates the caller-supplied arrays in
    place and returns None.
    """
    self.data_ready.wait()
    self.data_ready.clear()

    # NOTE(review): sampled indices are returned to `available` but never
    # removed from `ready`, so the loader may overwrite a slot that can
    # still be re-sampled — confirm this reuse is intentional.  Also,
    # random.sample() on a set is deprecated in Python 3.9 and removed in
    # 3.11; `self.ready` would need to be a sequence there.
    for i, j in enumerate(random.sample(self.ready, args.batch_size)):
        origs_out[i] = self.orig_buffer[j]
        seeds_out[i] = self.seed_buffer[j]

        self.available.add(j)

    self.data_copied.set()
@@ -211,12 +218,8 @@ class Model(object):
211218
212219 def __init__ (self ):
213220 self .network = collections .OrderedDict ()
214- if args .train :
215- self .network ['img' ] = InputLayer ((None , 3 , None , None ))
216- self .network ['seed' ] = PoolLayer (self .network ['img' ], pool_size = 2 ** args .scales , mode = 'average_exc_pad' )
217- else :
218- self .network ['img' ] = InputLayer ((None , 3 , None , None ))
219- self .network ['seed' ] = self .network ['img' ]
221+ self .network ['img' ] = InputLayer ((None , 3 , None , None ))
222+ self .network ['seed' ] = InputLayer ((None , 3 , None , None ))
220223
221224 config , params = self .load_model ()
222225 self .setup_generator (self .last_layer (), config )
@@ -378,10 +381,10 @@ def loss_discriminator(self, d):
378381
379382 def compile (self ):
380383 # Helper function for rendering test images during training, or standalone non-training mode.
381- input_tensor = T .tensor4 ()
382- input_layers = {self .network ['img' ]: input_tensor }
383- output = lasagne .layers .get_output ([self .network [k ] for k in ['img' , ' seed' , 'out' ]], input_layers , deterministic = True )
384- self .predict = theano .function ([input_tensor ], output )
384+ input_tensor , seed_tensor = T . tensor4 (), T .tensor4 ()
385+ input_layers = {self .network ['img' ]: input_tensor , self . network [ 'seed' ]: seed_tensor }
386+ output = lasagne .layers .get_output ([self .network [k ] for k in ['seed' , 'out' ]], input_layers , deterministic = True )
387+ self .predict = theano .function ([seed_tensor ], output )
385388
386389 if not args .train : return
387390
@@ -407,7 +410,7 @@ def compile(self):
407410
408411 # Combined Theano function for updating both generator and discriminator at the same time.
409412 updates = collections .OrderedDict (list (gen_updates .items ()) + list (disc_updates .items ()))
410- self .fit = theano .function ([input_tensor ], gen_losses + [disc_out .mean (axis = (1 ,2 ,3 ))], updates = updates )
413+ self .fit = theano .function ([input_tensor , seed_tensor ], gen_losses + [disc_out .mean (axis = (1 ,2 ,3 ))], updates = updates )
411414
412415
413416
@@ -448,29 +451,31 @@ def decay_learning_rate(self):
448451 if t_cur % args .learning_period == 0 : l_r *= args .learning_decay
449452
450453 def train (self ):
451- images = np .zeros ((args .batch_size , 3 , args .batch_resolution , args .batch_resolution ), dtype = np .float32 )
454+ seed_size = int (args .batch_shape / 2 ** args .scales )
455+ images = np .zeros ((args .batch_size , 3 , args .batch_shape , args .batch_shape ), dtype = np .float32 )
456+ seeds = np .zeros ((args .batch_size , 3 , seed_size , seed_size ), dtype = np .float32 )
452457 learning_rate = self .decay_learning_rate ()
453458 try :
454- running , start = None , time .time ()
459+ average , start = None , time .time ()
455460 for epoch in range (args .epochs ):
456461 total , stats = None , None
457462 l_r = next (learning_rate )
458463 if epoch >= args .generator_start : self .model .gen_lr .set_value (l_r )
459464 if epoch >= args .discriminator_start : self .model .disc_lr .set_value (l_r )
460465
461466 for _ in range (args .epoch_size ):
462- self .thread .copy (images )
463- output = self .model .fit (images )
467+ self .thread .copy (images , seeds )
468+ output = self .model .fit (images , seeds )
464469 losses = np .array (output [:3 ], dtype = np .float32 )
465470 stats = (stats + output [3 ]) if stats is not None else output [3 ]
466471 total = total + losses if total is not None else losses
467472 l = np .sum (losses )
468473 assert not np .isnan (losses ).any ()
469- running = l if running is None else running * 0.95 + 0.05 * l
470- print ('↑' if l > running else '↓' , end = '' , flush = True )
474+ average = l if average is None else average * 0.95 + 0.05 * l
475+ print ('↑' if l > average else '↓' , end = '' , flush = True )
471476
472- orign , scald , repro = self .model .predict (images )
473- self .show_progress (orign , scald , repro )
477+ scald , repro = self .model .predict (seeds )
478+ self .show_progress (images , scald , repro )
474479 total /= args .epoch_size
475480 stats /= args .epoch_size
476481 totals , labels = [sum (total )] + list (total ), ['total' , 'prcpt' , 'smthn' , 'advrs' ]
@@ -481,9 +486,12 @@ def train(self):
481486 real , fake = stats [:args .batch_size ], stats [args .batch_size :]
482487 print (' - discriminator' , real .mean (), len (np .where (real > 0.5 )[0 ]), fake .mean (), len (np .where (fake < - 0.5 )[0 ]))
483488 if epoch == args .adversarial_start - 1 :
484- print (' - adversary mode: generator engaging discriminator.' )
489+ print (' - generator now optimizing against discriminator.' )
485490 self .model .adversary_weight .set_value (args .adversary_weight )
486491 running = None
492+ if (epoch + 1 ) % args .save_every == 0 :
493+ print (' - saving current generator layers to disk...' )
494+ self .model .save_generator ()
487495
488496 except KeyboardInterrupt :
489497 pass
@@ -505,11 +513,9 @@ def process(self, image):
505513
# Entry point dispatch: --train and file-enhancement mode are mutually
# exclusive in this revision (the pre-image ran both; the patch makes it
# an if/else).
if args.train:
    enhancer.train()
else:
    for filename in args.files:
        print(filename)
        # Load as RGB, upscale through the network, and save alongside the
        # input with a `_ne<factor>x.png` suffix (factor = 2**scales).
        out = enhancer.process(scipy.ndimage.imread(filename, mode='RGB'))
        out.save(os.path.splitext(filename)[0] + '_ne%ix.png' % (2 ** args.scales))
    # Shared context line in the diff sits at one indent level; presumably
    # the ANSI reset prints only in enhancement mode here — TODO confirm
    # against the full file.
    print(ansi.ENDC)
# (page footer "0 commit comments" from the scraped GitHub commit view — not part of the source)