This repository was archived by the owner on Jan 2, 2021. It is now read-only.

Commit 17fcad8

Refactor of changes related to training.
1 parent 2b67dae

enhance.py (34 additions & 39 deletions)

@@ -38,13 +38,13 @@
 add_arg('--scales', default=2, type=int, help='How many times to perform 2x upsampling.')
 add_arg('--model', default='small', type=str, help='Name of the neural network to load/save.')
 add_arg('--train', default=False, type=str, help='File pattern to load for training.')
-add_arg('--save-every-epoch', default=False, action='store_true', help='Save generator after every training epoch.')
-add_arg('--batch-resolution', default=192, type=int, help='Resolution of images in training batch.')
+add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.')
+add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
+add_arg('--save-every', default=10, type=int, help='Save generator after every training epoch.')
+add_arg('--batch-shape', default=192, type=int, help='Resolution of images in training batch.')
 add_arg('--batch-size', default=15, type=int, help='Number of images per training batch.')
 add_arg('--buffer-size', default=1500, type=int, help='Total image fragments kept in cache.')
 add_arg('--buffer-similar', default=5, type=int, help='Fragments cached for each image loaded.')
-add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.')
-add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
 add_arg('--learning-rate', default=1E-4, type=float, help='Parameter for the ADAM optimizer.')
 add_arg('--learning-period', default=50, type=int, help='How often to decay the learning rate.')
 add_arg('--learning-decay', default=0.5, type=float, help='How much to decay the learning rate.')
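
The flags above are what the rest of this diff consumes. As a rough sketch of how they behave, assuming add_arg is a thin wrapper around argparse's add_argument (the parser object below is illustrative, not the project's actual setup):

    import argparse

    parser = argparse.ArgumentParser(description='Training flags from the diff above (illustrative).')
    add_arg = parser.add_argument
    add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.')
    add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
    add_arg('--save-every', default=10, type=int, help='Save generator after this many training epochs.')
    add_arg('--batch-shape', default=192, type=int, help='Resolution of images in training batch.')

    args = parser.parse_args(['--epochs', '50', '--save-every', '5'])
    print(args.epochs, args.epoch_size, args.save_every, args.batch_shape)   # 50 72 5 192
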
@@ -129,11 +129,10 @@ def __init__(self):
         self.data_ready = threading.Event()
         self.data_copied = threading.Event()

-        self.resolution = args.batch_resolution
-        self.seed_resolution = int(args.batch_resolution / 2**args.scales)
+        self.orig_shape, self.seed_shape = args.batch_shape, int(args.batch_shape / 2**args.scales)

-        self.buffer = np.zeros((args.buffer_size, 3, self.resolution, self.resolution), dtype=np.float32)
-        self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_resolution, self.seed_resolution), dtype=np.float32)
+        self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32)
+        self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32)
         self.files = glob.glob(args.train)
         if len(self.files) == 0:
             error("There were no files found to train from searching for `{}`".format(args.train),
@@ -161,31 +160,29 @@ def run(self):

            for _ in range(args.buffer_similar):
                copy = img[:,::-1] if random.choice([True, False]) else img
-                h = random.randint(0, copy.shape[0] - self.resolution)
-                w = random.randint(0, copy.shape[1] - self.resolution)
-                copy = copy[h:h+self.resolution, w:w+self.resolution]
+                h = random.randint(0, copy.shape[0] - self.orig_shape)
+                w = random.randint(0, copy.shape[1] - self.orig_shape)
+                copy = copy[h:h+self.orig_shape, w:w+self.orig_shape]

                while len(self.available) == 0:
                    self.data_copied.wait()
                    self.data_copied.clear()

                i = self.available.pop()
-                self.buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
-                seed_copy = scipy.misc.imresize(copy,
-                                                size=(self.seed_resolution, self.seed_resolution),
-                                                interp='bilinear')
-                self.seed_buffer[i] = np.transpose(seed_copy / 255.0 - 0.5, (2, 0, 1))
+                self.orig_buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
+                seed = scipy.misc.imresize(copy, size=(self.seed_shape, self.seed_shape), interp='bilinear')
+                self.seed_buffer[i] = np.transpose(seed / 255.0 - 0.5, (2, 0, 1))
                self.ready.add(i)

                if len(self.ready) >= args.batch_size:
                    self.data_ready.set()

-    def copy(self, images_out, seeds_out):
+    def copy(self, origs_out, seeds_out):
        self.data_ready.wait()
        self.data_ready.clear()

        for i, j in enumerate(random.sample(self.ready, args.batch_size)):
-            images_out[i] = self.buffer[j]
+            origs_out[i] = self.orig_buffer[j]
            seeds_out[i] = self.seed_buffer[j]

            self.available.add(j)
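
A self-contained sketch of the crop/downscale pair that run() produces for each fragment; make_pair is a hypothetical helper, and Pillow's bilinear resize stands in for scipy.misc.imresize (since removed from SciPy), so the seed pixels may differ slightly from the original code:

    import random
    import numpy as np
    from PIL import Image

    def make_pair(img, orig_shape=192, scales=2):
        """Return an (orig, seed) pair of CHW float32 arrays scaled to [-0.5, 0.5]."""
        copy = img[:,::-1] if random.choice([True, False]) else img    # random horizontal flip
        h = random.randint(0, copy.shape[0] - orig_shape)
        w = random.randint(0, copy.shape[1] - orig_shape)
        copy = copy[h:h+orig_shape, w:w+orig_shape]                    # random crop

        seed_shape = int(orig_shape / 2**scales)
        seed = np.asarray(Image.fromarray(np.ascontiguousarray(copy))
                          .resize((seed_shape, seed_shape), Image.BILINEAR))

        orig = np.transpose(copy / 255.0 - 0.5, (2, 0, 1)).astype(np.float32)
        seed = np.transpose(seed / 255.0 - 0.5, (2, 0, 1)).astype(np.float32)
        return orig, seed

    example = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)     # stand-in for a loaded photo
    orig, seed = make_pair(example)
    print(orig.shape, seed.shape)                                      # (3, 192, 192) (3, 48, 48)
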
@@ -384,11 +381,10 @@ def loss_discriminator(self, d):

    def compile(self):
        # Helper function for rendering test images during training, or standalone non-training mode.
-        input_tensor = T.tensor4()
-        seed_tensor = T.tensor4()
+        input_tensor, seed_tensor = T.tensor4(), T.tensor4()
        input_layers = {self.network['img']: input_tensor, self.network['seed']: seed_tensor}
-        output = lasagne.layers.get_output([self.network[k] for k in ['img', 'seed', 'out']], input_layers, deterministic=True)
-        self.predict = theano.function([input_tensor, seed_tensor], output)
+        output = lasagne.layers.get_output([self.network[k] for k in ['seed', 'out']], input_layers, deterministic=True)
+        self.predict = theano.function([seed_tensor], output)

        if not args.train: return

@@ -455,12 +451,12 @@ def decay_learning_rate(self):
            if t_cur % args.learning_period == 0: l_r *= args.learning_decay

    def train(self):
-        seed_size = int(args.batch_resolution / 2**args.scales)
-        images = np.zeros((args.batch_size, 3, args.batch_resolution, args.batch_resolution), dtype=np.float32)
+        seed_size = int(args.batch_shape / 2**args.scales)
+        images = np.zeros((args.batch_size, 3, args.batch_shape, args.batch_shape), dtype=np.float32)
        seeds = np.zeros((args.batch_size, 3, seed_size, seed_size), dtype=np.float32)
        learning_rate = self.decay_learning_rate()
        try:
-            running, start = None, time.time()
+            average, start = None, time.time()
            for epoch in range(args.epochs):
                total, stats = None, None
                l_r = next(learning_rate)
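
The decay_learning_rate() generator itself is mostly outside this hunk; the following is a guess at its shape, reconstructed only from the visible context line and the next(learning_rate) call, so treat it as illustrative:

    def decay_learning_rate(learning_rate=1e-4, learning_period=50, learning_decay=0.5):
        """Yield the current learning rate each epoch, decaying it every `learning_period` epochs."""
        l_r, t_cur = learning_rate, 0
        while True:
            yield l_r
            t_cur += 1
            if t_cur % learning_period == 0: l_r *= learning_decay

    rates = decay_learning_rate(learning_rate=1e-4, learning_period=2)
    print([next(rates) for _ in range(6)])   # 1e-4, 1e-4, 5e-5, 5e-5, 2.5e-5, 2.5e-5
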
@@ -475,11 +471,11 @@ def train(self):
                    total = total + losses if total is not None else losses
                    l = np.sum(losses)
                    assert not np.isnan(losses).any()
-                    running = l if running is None else running * 0.95 + 0.05 * l
-                    print('↑' if l > running else '↓', end='', flush=True)
+                    average = l if average is None else average * 0.95 + 0.05 * l
+                    print('↑' if l > average else '↓', end='', flush=True)

-                orign, scald, repro = self.model.predict(images, seeds)
-                self.show_progress(orign, scald, repro)
+                scald, repro = self.model.predict(seeds)
+                self.show_progress(images, scald, repro)
                total /= args.epoch_size
                stats /= args.epoch_size
                totals, labels = [sum(total)] + list(total), ['total', 'prcpt', 'smthn', 'advrs']
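
The running average is renamed to average, but the behaviour is unchanged: each batch loss is compared against an exponential moving average with a 0.95 decay to decide which arrow to print. A tiny illustration with made-up loss values:

    losses = [10.0, 12.0, 9.0]                   # invented numbers, for illustration only
    average = None
    for l in losses:
        average = l if average is None else average * 0.95 + 0.05 * l
        print('↑' if l > average else '↓', round(average, 3))   # ↓ 10.0, ↑ 10.1, ↓ 10.045
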
@@ -490,10 +486,11 @@ def train(self):
                real, fake = stats[:args.batch_size], stats[args.batch_size:]
                print(' - discriminator', real.mean(), len(np.where(real > 0.5)[0]), fake.mean(), len(np.where(fake < -0.5)[0]))
                if epoch == args.adversarial_start-1:
-                    print(' - adversary mode: generator engaging discriminator.')
+                    print(' - generator now optimizing against discriminator.')
                    self.model.adversary_weight.set_value(args.adversary_weight)
                    running = None
-                if args.save_every_epoch:
+                if (epoch+1) % args.save_every == 0:
+                    print(' - saving current generator layers to disk...')
                    self.model.save_generator()

        except KeyboardInterrupt:
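
Saving is now periodic rather than per-epoch: with the default --save-every of 10, the new condition fires after epochs 10, 20, 30, and so on (epoch is zero-based in the loop, hence the +1). A minimal check:

    save_every, epochs = 10, 30
    for epoch in range(epochs):
        if (epoch+1) % save_every == 0:
            print('saving after epoch', epoch+1)   # prints 10, 20, 30
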
@@ -506,7 +503,7 @@ def train(self):

    def process(self, image):
        img = np.transpose(image / 255.0 - 0.5, (2, 0, 1))[np.newaxis].astype(np.float32)
-        *_, repro = self.model.predict(img, img)
+        *_, repro = self.model.predict(img)
        repro = np.transpose(repro[0] + 0.5, (1, 2, 0)).clip(0.0, 1.0)
        return scipy.misc.toimage(repro * 255.0, cmin=0, cmax=255)

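A sketch of the post-processing in process() above; to_image is a hypothetical helper, and PIL.Image.fromarray stands in for the deprecated scipy.misc.toimage:

    import numpy as np
    from PIL import Image

    def to_image(repro):
        """Convert one CHW network output in [-0.5, 0.5] back to an 8-bit RGB image."""
        hwc = np.transpose(repro + 0.5, (1, 2, 0)).clip(0.0, 1.0)
        return Image.fromarray((hwc * 255.0).astype(np.uint8))

    fake_output = np.random.uniform(-0.5, 0.5, (3, 96, 96)).astype(np.float32)
    print(to_image(fake_output).size)   # (96, 96)
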
@@ -516,11 +513,9 @@ def process(self, image):

if args.train:
    enhancer.train()
-
-for filename in args.files:
-    print(filename)
-    out = enhancer.process(scipy.ndimage.imread(filename, mode='RGB'))
-    out.save(os.path.splitext(filename)[0]+'_ne%ix.png'%(2**args.scales))
-
-if args.files:
+else:
+    for filename in args.files:
+        print(filename)
+        out = enhancer.process(scipy.ndimage.imread(filename, mode='RGB'))
+        out.save(os.path.splitext(filename)[0]+'_ne%ix.png'%(2**args.scales))
    print(ansi.ENDC)
