3838add_arg ('--scales' , default = 2 , type = int , help = 'How many times to perform 2x upsampling.' )
3939add_arg ('--model' , default = 'small' , type = str , help = 'Name of the neural network to load/save.' )
4040add_arg ('--train' , default = False , type = str , help = 'File pattern to load for training.' )
41- add_arg ('--save-every-epoch' , default = False , action = 'store_true' , help = 'Save generator after every training epoch.' )
42- add_arg ('--batch-resolution' , default = 192 , type = int , help = 'Resolution of images in training batch.' )
41+ add_arg ('--epochs' , default = 10 , type = int , help = 'Total number of iterations in training.' )
42+ add_arg ('--epoch-size' , default = 72 , type = int , help = 'Number of batches trained in an epoch.' )
43+ add_arg ('--save-every' , default = 10 , type = int , help = 'Save generator after every training epoch.' )
44+ add_arg ('--batch-shape' , default = 192 , type = int , help = 'Resolution of images in training batch.' )
4345add_arg ('--batch-size' , default = 15 , type = int , help = 'Number of images per training batch.' )
4446add_arg ('--buffer-size' , default = 1500 , type = int , help = 'Total image fragments kept in cache.' )
4547add_arg ('--buffer-similar' , default = 5 , type = int , help = 'Fragments cached for each image loaded.' )
46- add_arg ('--epochs' , default = 10 , type = int , help = 'Total number of iterations in training.' )
47- add_arg ('--epoch-size' , default = 72 , type = int , help = 'Number of batches trained in an epoch.' )
4848add_arg ('--learning-rate' , default = 1E-4 , type = float , help = 'Parameter for the ADAM optimizer.' )
4949add_arg ('--learning-period' , default = 50 , type = int , help = 'How often to decay the learning rate.' )
5050add_arg ('--learning-decay' , default = 0.5 , type = float , help = 'How much to decay the learning rate.' )
@@ -129,11 +129,10 @@ def __init__(self):
129129 self .data_ready = threading .Event ()
130130 self .data_copied = threading .Event ()
131131
132- self .resolution = args .batch_resolution
133- self .seed_resolution = int (args .batch_resolution / 2 ** args .scales )
132+ self .orig_shape , self .seed_shape = args .batch_shape , int (args .batch_shape / 2 ** args .scales )
134133
135- self .buffer = np .zeros ((args .buffer_size , 3 , self .resolution , self .resolution ), dtype = np .float32 )
136- self .seed_buffer = np .zeros ((args .buffer_size , 3 , self .seed_resolution , self .seed_resolution ), dtype = np .float32 )
134+ self .orig_buffer = np .zeros ((args .buffer_size , 3 , self .orig_shape , self .orig_shape ), dtype = np .float32 )
135+ self .seed_buffer = np .zeros ((args .buffer_size , 3 , self .seed_shape , self .seed_shape ), dtype = np .float32 )
137136 self .files = glob .glob (args .train )
138137 if len (self .files ) == 0 :
139138 error ("There were no files found to train from searching for `{}`" .format (args .train ),
@@ -161,31 +160,29 @@ def run(self):
161160
162161 for _ in range (args .buffer_similar ):
163162 copy = img [:,::- 1 ] if random .choice ([True , False ]) else img
164- h = random .randint (0 , copy .shape [0 ] - self .resolution )
165- w = random .randint (0 , copy .shape [1 ] - self .resolution )
166- copy = copy [h :h + self .resolution , w :w + self .resolution ]
163+ h = random .randint (0 , copy .shape [0 ] - self .orig_shape )
164+ w = random .randint (0 , copy .shape [1 ] - self .orig_shape )
165+ copy = copy [h :h + self .orig_shape , w :w + self .orig_shape ]
167166
168167 while len (self .available ) == 0 :
169168 self .data_copied .wait ()
170169 self .data_copied .clear ()
171170
172171 i = self .available .pop ()
173- self .buffer [i ] = np .transpose (copy / 255.0 - 0.5 , (2 , 0 , 1 ))
174- seed_copy = scipy .misc .imresize (copy ,
175- size = (self .seed_resolution , self .seed_resolution ),
176- interp = 'bilinear' )
177- self .seed_buffer [i ] = np .transpose (seed_copy / 255.0 - 0.5 , (2 , 0 , 1 ))
172+ self .orig_buffer [i ] = np .transpose (copy / 255.0 - 0.5 , (2 , 0 , 1 ))
173+ seed = scipy .misc .imresize (copy , size = (self .seed_shape , self .seed_shape ), interp = 'bilinear' )
174+ self .seed_buffer [i ] = np .transpose (seed / 255.0 - 0.5 , (2 , 0 , 1 ))
178175 self .ready .add (i )
179176
180177 if len (self .ready ) >= args .batch_size :
181178 self .data_ready .set ()
182179
183- def copy (self , images_out , seeds_out ):
180+ def copy (self , origs_out , seeds_out ):
184181 self .data_ready .wait ()
185182 self .data_ready .clear ()
186183
187184 for i , j in enumerate (random .sample (self .ready , args .batch_size )):
188- images_out [i ] = self .buffer [j ]
185+ origs_out [i ] = self .orig_buffer [j ]
189186 seeds_out [i ] = self .seed_buffer [j ]
190187
191188 self .available .add (j )
@@ -384,11 +381,10 @@ def loss_discriminator(self, d):
384381
385382 def compile (self ):
386383 # Helper function for rendering test images during training, or standalone non-training mode.
387- input_tensor = T .tensor4 ()
388- seed_tensor = T .tensor4 ()
384+ input_tensor , seed_tensor = T .tensor4 (), T .tensor4 ()
389385 input_layers = {self .network ['img' ]: input_tensor , self .network ['seed' ]: seed_tensor }
390- output = lasagne .layers .get_output ([self .network [k ] for k in ['img' , ' seed' , 'out' ]], input_layers , deterministic = True )
391- self .predict = theano .function ([input_tensor , seed_tensor ], output )
386+ output = lasagne .layers .get_output ([self .network [k ] for k in ['seed' , 'out' ]], input_layers , deterministic = True )
387+ self .predict = theano .function ([seed_tensor ], output )
392388
393389 if not args .train : return
394390
@@ -455,12 +451,12 @@ def decay_learning_rate(self):
455451 if t_cur % args .learning_period == 0 : l_r *= args .learning_decay
456452
457453 def train (self ):
458- seed_size = int (args .batch_resolution / 2 ** args .scales )
459- images = np .zeros ((args .batch_size , 3 , args .batch_resolution , args .batch_resolution ), dtype = np .float32 )
454+ seed_size = int (args .batch_shape / 2 ** args .scales )
455+ images = np .zeros ((args .batch_size , 3 , args .batch_shape , args .batch_shape ), dtype = np .float32 )
460456 seeds = np .zeros ((args .batch_size , 3 , seed_size , seed_size ), dtype = np .float32 )
461457 learning_rate = self .decay_learning_rate ()
462458 try :
463- running , start = None , time .time ()
459+ average , start = None , time .time ()
464460 for epoch in range (args .epochs ):
465461 total , stats = None , None
466462 l_r = next (learning_rate )
@@ -475,11 +471,11 @@ def train(self):
475471 total = total + losses if total is not None else losses
476472 l = np .sum (losses )
477473 assert not np .isnan (losses ).any ()
478- running = l if running is None else running * 0.95 + 0.05 * l
479- print ('↑' if l > running else '↓' , end = '' , flush = True )
474+ average = l if average is None else average * 0.95 + 0.05 * l
475+ print ('↑' if l > average else '↓' , end = '' , flush = True )
480476
481- orign , scald , repro = self .model .predict (images , seeds )
482- self .show_progress (orign , scald , repro )
477+ scald , repro = self .model .predict (seeds )
478+ self .show_progress (images , scald , repro )
483479 total /= args .epoch_size
484480 stats /= args .epoch_size
485481 totals , labels = [sum (total )] + list (total ), ['total' , 'prcpt' , 'smthn' , 'advrs' ]
@@ -490,10 +486,11 @@ def train(self):
490486 real , fake = stats [:args .batch_size ], stats [args .batch_size :]
491487 print (' - discriminator' , real .mean (), len (np .where (real > 0.5 )[0 ]), fake .mean (), len (np .where (fake < - 0.5 )[0 ]))
492488 if epoch == args .adversarial_start - 1 :
493- print (' - adversary mode: generator engaging discriminator.' )
489+ print (' - generator now optimizing against discriminator.' )
494490 self .model .adversary_weight .set_value (args .adversary_weight )
495491 running = None
496- if args .save_every_epoch :
492+ if (epoch + 1 ) % args .save_every == 0 :
493+ print (' - saving current generator layers to disk...' )
497494 self .model .save_generator ()
498495
499496 except KeyboardInterrupt :
@@ -506,7 +503,7 @@ def train(self):
506503
507504 def process (self , image ):
508505 img = np .transpose (image / 255.0 - 0.5 , (2 , 0 , 1 ))[np .newaxis ].astype (np .float32 )
509- * _ , repro = self .model .predict (img , img )
506+ * _ , repro = self .model .predict (img )
510507 repro = np .transpose (repro [0 ] + 0.5 , (1 , 2 , 0 )).clip (0.0 , 1.0 )
511508 return scipy .misc .toimage (repro * 255.0 , cmin = 0 , cmax = 255 )
512509
@@ -516,11 +513,9 @@ def process(self, image):
516513
517514 if args .train :
518515 enhancer .train ()
519-
520- for filename in args .files :
521- print (filename )
522- out = enhancer .process (scipy .ndimage .imread (filename , mode = 'RGB' ))
523- out .save (os .path .splitext (filename )[0 ]+ '_ne%ix.png' % (2 ** args .scales ))
524-
525- if args .files :
516+ else :
517+ for filename in args .files :
518+ print (filename )
519+ out = enhancer .process (scipy .ndimage .imread (filename , mode = 'RGB' ))
520+ out .save (os .path .splitext (filename )[0 ]+ '_ne%ix.png' % (2 ** args .scales ))
526521 print (ansi .ENDC )
0 commit comments