Skip to content
This repository was archived by the owner on Jan 2, 2021. It is now read-only.

Commit a5ad2c2

Browse files
authored
Merge pull request #18 from alexjc/training
Training Improvements
2 parents a9b0cd9 + 17fcad8 commit a5ad2c2

File tree

1 file changed

+44
-38
lines changed

1 file changed

+44
-38
lines changed

enhance.py

Lines changed: 44 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,13 @@
3838
add_arg('--scales', default=2, type=int, help='How many times to perform 2x upsampling.')
3939
add_arg('--model', default='small', type=str, help='Name of the neural network to load/save.')
4040
add_arg('--train', default=False, type=str, help='File pattern to load for training.')
41-
add_arg('--batch-resolution', default=192, type=int, help='Resolution of images in training batch.')
41+
add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.')
42+
add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
43+
add_arg('--save-every', default=10, type=int, help='Save generator after every training epoch.')
44+
add_arg('--batch-shape', default=192, type=int, help='Resolution of images in training batch.')
4245
add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.')
4346
add_arg('--buffer-size', default=1500, type=int, help='Total image fragments kept in cache.')
4447
add_arg('--buffer-similar', default=5, type=int, help='Fragments cached for each image loaded.')
45-
add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.')
46-
add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
4748
add_arg('--learning-rate', default=1E-4, type=float, help='Parameter for the ADAM optimizer.')
4849
add_arg('--learning-period', default=50, type=int, help='How often to decay the learning rate.')
4950
add_arg('--learning-decay', default=0.5, type=float, help='How much to decay the learning rate.')
@@ -128,8 +129,10 @@ def __init__(self):
128129
self.data_ready = threading.Event()
129130
self.data_copied = threading.Event()
130131

131-
self.resolution = args.batch_resolution
132-
self.buffer = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32)
132+
self.orig_shape, self.seed_shape = args.batch_shape, int(args.batch_shape / 2**args.scales)
133+
134+
self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32)
135+
self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32)
133136
self.files = glob.glob(args.train)
134137
if len(self.files) == 0:
135138
error("There were no files found to train from searching for `{}`".format(args.train),
@@ -157,27 +160,31 @@ def run(self):
157160

158161
for _ in range(args.buffer_similar):
159162
copy = img[:,::-1] if random.choice([True, False]) else img
160-
h = random.randint(0, copy.shape[0] - self.resolution)
161-
w = random.randint(0, copy.shape[1] - self.resolution)
162-
copy = copy[h:h+self.resolution, w:w+self.resolution]
163+
h = random.randint(0, copy.shape[0] - self.orig_shape)
164+
w = random.randint(0, copy.shape[1] - self.orig_shape)
165+
copy = copy[h:h+self.orig_shape, w:w+self.orig_shape]
163166

164167
while len(self.available) == 0:
165168
self.data_copied.wait()
166169
self.data_copied.clear()
167170

168171
i = self.available.pop()
169-
self.buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
172+
self.orig_buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
173+
seed = scipy.misc.imresize(copy, size=(self.seed_shape, self.seed_shape), interp='bilinear')
174+
self.seed_buffer[i] = np.transpose(seed / 255.0 - 0.5, (2, 0, 1))
170175
self.ready.add(i)
171176

172177
if len(self.ready) >= args.batch_size:
173178
self.data_ready.set()
174179

175-
def copy(self, output):
180+
def copy(self, origs_out, seeds_out):
176181
self.data_ready.wait()
177182
self.data_ready.clear()
178183

179184
for i, j in enumerate(random.sample(self.ready, args.batch_size)):
180-
output[i] = self.buffer[j]
185+
origs_out[i] = self.orig_buffer[j]
186+
seeds_out[i] = self.seed_buffer[j]
187+
181188
self.available.add(j)
182189

183190
self.data_copied.set()
@@ -211,12 +218,8 @@ class Model(object):
211218

212219
def __init__(self):
213220
self.network = collections.OrderedDict()
214-
if args.train:
215-
self.network['img'] = InputLayer((None, 3, None, None))
216-
self.network['seed'] = PoolLayer(self.network['img'], pool_size=2**args.scales, mode='average_exc_pad')
217-
else:
218-
self.network['img'] = InputLayer((None, 3, None, None))
219-
self.network['seed'] = self.network['img']
221+
self.network['img'] = InputLayer((None, 3, None, None))
222+
self.network['seed'] = InputLayer((None, 3, None, None))
220223

221224
config, params = self.load_model()
222225
self.setup_generator(self.last_layer(), config)
@@ -378,10 +381,10 @@ def loss_discriminator(self, d):
378381

379382
def compile(self):
380383
# Helper function for rendering test images during training, or standalone non-training mode.
381-
input_tensor = T.tensor4()
382-
input_layers = {self.network['img']: input_tensor}
383-
output = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']], input_layers, deterministic=True)
384-
self.predict = theano.function([input_tensor], output)
384+
input_tensor, seed_tensor = T.tensor4(), T.tensor4()
385+
input_layers = {self.network['img']: input_tensor, self.network['seed']: seed_tensor}
386+
output = lasagne.layers.get_output([self.network[k] for k in ['seed', 'out']], input_layers, deterministic=True)
387+
self.predict = theano.function([seed_tensor], output)
385388

386389
if not args.train: return
387390

@@ -407,7 +410,7 @@ def compile(self):
407410

408411
# Combined Theano function for updating both generator and discriminator at the same time.
409412
updates = collections.OrderedDict(list(gen_updates.items()) + list(disc_updates.items()))
410-
self.fit = theano.function([input_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates)
413+
self.fit = theano.function([input_tensor, seed_tensor], gen_losses + [disc_out.mean(axis=(1,2,3))], updates=updates)
411414

412415

413416

@@ -448,29 +451,31 @@ def decay_learning_rate(self):
448451
if t_cur % args.learning_period == 0: l_r *= args.learning_decay
449452

450453
def train(self):
451-
images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
454+
seed_size = int(args.batch_shape / 2**args.scales)
455+
images = np.zeros((args.batch_size, 3, args.batch_shape, args.batch_shape), dtype=np.float32)
456+
seeds = np.zeros((args.batch_size, 3, seed_size, seed_size), dtype=np.float32)
452457
learning_rate = self.decay_learning_rate()
453458
try:
454-
running, start = None, time.time()
459+
average, start = None, time.time()
455460
for epoch in range(args.epochs):
456461
total, stats = None, None
457462
l_r = next(learning_rate)
458463
if epoch >= args.generator_start: self.model.gen_lr.set_value(l_r)
459464
if epoch >= args.discriminator_start: self.model.disc_lr.set_value(l_r)
460465

461466
for _ in range(args.epoch_size):
462-
self.thread.copy(images)
463-
output = self.model.fit(images)
467+
self.thread.copy(images, seeds)
468+
output = self.model.fit(images, seeds)
464469
losses = np.array(output[:3], dtype=np.float32)
465470
stats = (stats + output[3]) if stats is not None else output[3]
466471
total = total + losses if total is not None else losses
467472
l = np.sum(losses)
468473
assert not np.isnan(losses).any()
469-
running = l if running is None else running * 0.95 + 0.05 * l
470-
print('↑' if l > running else '↓', end='', flush=True)
474+
average = l if average is None else average * 0.95 + 0.05 * l
475+
print('↑' if l > average else '↓', end='', flush=True)
471476

472-
orign, scald, repro = self.model.predict(images)
473-
self.show_progress(orign, scald, repro)
477+
scald, repro = self.model.predict(seeds)
478+
self.show_progress(images, scald, repro)
474479
total /= args.epoch_size
475480
stats /= args.epoch_size
476481
totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
@@ -481,9 +486,12 @@ def train(self):
481486
real, fake = stats[:args.batch_size], stats[args.batch_size:]
482487
print(' - discriminator', real.mean(), len(np.where(real > 0.5)[0]), fake.mean(), len(np.where(fake < -0.5)[0]))
483488
if epoch == args.adversarial_start-1:
484-
print(' - adversary mode: generator engaging discriminator.')
489+
print(' - generator now optimizing against discriminator.')
485490
self.model.adversary_weight.set_value(args.adversary_weight)
486491
running = None
492+
if (epoch+1) % args.save_every == 0:
493+
print(' - saving current generator layers to disk...')
494+
self.model.save_generator()
487495

488496
except KeyboardInterrupt:
489497
pass
@@ -505,11 +513,9 @@ def process(self, image):
505513

506514
if args.train:
507515
enhancer.train()
508-
509-
for filename in args.files:
510-
print(filename)
511-
out = enhancer.process(scipy.ndimage.imread(filename, mode='RGB'))
512-
out.save(os.path.splitext(filename)[0]+'_ne%ix.png'%(2**args.scales))
513-
514-
if args.files:
516+
else:
517+
for filename in args.files:
518+
print(filename)
519+
out = enhancer.process(scipy.ndimage.imread(filename, mode='RGB'))
520+
out.save(os.path.splitext(filename)[0]+'_ne%ix.png'%(2**args.scales))
515521
print(ansi.ENDC)

0 commit comments

Comments
 (0)