Skip to content
This repository was archived by the owner on Jan 2, 2021. It is now read-only.

Commit 637398c

Browse files
committed
Added -seeds option to provide seeds from disk
1 parent 5ef872b commit 637398c

File tree

1 file changed

+72
-10
lines changed

1 file changed

+72
-10
lines changed

enhance.py

Lines changed: 72 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
add_arg('--scales', default=2, type=int, help='How many times to perform 2x upsampling.')
3939
add_arg('--model', default='small', type=str, help='Name of the neural network to load/save.')
4040
add_arg('--train', default=False, type=str, help='File pattern to load for training.')
41+
add_arg('--seeds', default=False, type=str, help='File pattern to load for training seeds.')
4142
add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.')
4243
add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
4344
add_arg('--save-every', default=10, type=int, help='Save generator after every training epoch.')
@@ -118,6 +119,28 @@ def extend(lst): return itertools.chain(lst, itertools.repeat(lst[-1]))
118119

119120
print('{} - Using the device `{}` for neural computation.{}\n'.format(ansi.CYAN, theano.config.device, ansi.ENDC))
120121

122+
def confirm_pairs(list1, list2):
123+
new_list1 = []
124+
new_list2 = []
125+
cur1 = 0
126+
cur2 = 0
127+
len1 = len(list1)
128+
len2 = len(list2)
129+
while(cur1 < len1 and cur2 < len2):
130+
base1 = os.path.basename(list1[cur1])
131+
base2 = os.path.basename(list2[cur2])
132+
if base1 == base2:
133+
new_list1.append(list1[cur1])
134+
new_list2.append(list2[cur2])
135+
cur1 = cur1 + 1
136+
cur2 = cur2 + 1
137+
elif base1 < base2:
138+
# continue to look on list1, don't iterate list2
139+
cur1 = cur1 + 1
140+
else:
141+
cur2 = cur2 + 1
142+
print("List sizes went from {}, {} to {}, {}".format(len1, len2, len(new_list1), len(new_list2)))
143+
return new_list1, new_list2
121144

122145
#======================================================================================================================
123146
# Image Processing
@@ -133,7 +156,13 @@ def __init__(self):
133156

134157
self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32)
135158
self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32)
136-
self.files = glob.glob(args.train)
159+
self.files = sorted(glob.glob(args.train))
160+
if args.seeds:
161+
self.seeds = sorted(glob.glob(args.seeds))
162+
self.files, self.seeds = confirm_pairs(self.files, self.seeds)
163+
else:
164+
self.seeds = False
165+
137166
if len(self.files) == 0:
138167
error("There were no files found to train from searching for `{}`".format(args.train),
139168
" - Try putting all your images in one folder and using `--train=data/*.jpg`")
@@ -146,22 +175,56 @@ def __init__(self):
146175

147176
def run(self):
148177
while True:
149-
random.shuffle(self.files)
178+
indices = list(range(0, len(self.files)))
179+
random.shuffle(indices)
150180

151-
for f in self.files:
181+
for i in indices:
182+
f = self.files[i]
152183
filename = os.path.join(self.cwd, f)
153184
try:
154185
img = scipy.ndimage.imread(filename, mode='RGB')
155186
except Exception as e:
156187
warn('Could not load `{}` as image.'.format(filename),
157188
' - Try fixing or removing the file before next run.')
158-
files.remove(f)
189+
del self.files[i]
190+
if self.seeds:
191+
del self.seeds[i]
159192
continue
160-
193+
194+
# determine seed
195+
if self.seeds:
196+
f = self.seeds[i]
197+
filename = os.path.join(self.cwd, f)
198+
try:
199+
seed_img = scipy.ndimage.imread(filename, mode='RGB')
200+
except Exception as e:
201+
warn('Could not load `{}` as seed image.'.format(filename),
202+
' - Try fixing or removing the file before next run.')
203+
del self.files[i]
204+
del self.seeds[i]
205+
continue
206+
else:
207+
# synthetic seed
208+
seed_img = scipy.misc.imresize(img,
209+
float(self.seed_shape) / float(self.orig_shape),
210+
interp='bilinear')
211+
161212
for _ in range(args.buffer_similar):
162-
copy = img[:,::-1] if random.choice([True, False]) else img
163-
h = random.randint(0, copy.shape[0] - self.orig_shape)
164-
w = random.randint(0, copy.shape[1] - self.orig_shape)
213+
if random.choice([True, False]):
214+
copy = img[:,::-1]
215+
copy_seed = seed_img[:,::-1]
216+
else:
217+
copy = img
218+
copy_seed = seed_img
219+
220+
# compute seed displacement
221+
h = random.randint(0, copy_seed.shape[0] - self.seed_shape)
222+
w = random.randint(0, copy_seed.shape[1] - self.seed_shape)
223+
copy_seed = copy_seed[h:h+self.seed_shape, w:w+self.seed_shape]
224+
225+
# and matching image displacement
226+
h = int(h * self.orig_shape / self.seed_shape);
227+
w = int(w * self.orig_shape / self.seed_shape);
165228
copy = copy[h:h+self.orig_shape, w:w+self.orig_shape]
166229

167230
while len(self.available) == 0:
@@ -170,8 +233,7 @@ def run(self):
170233

171234
i = self.available.pop()
172235
self.orig_buffer[i] = np.transpose(copy / 255.0 - 0.5, (2, 0, 1))
173-
seed = scipy.misc.imresize(copy, size=(self.seed_shape, self.seed_shape), interp='bilinear')
174-
self.seed_buffer[i] = np.transpose(seed / 255.0 - 0.5, (2, 0, 1))
236+
self.seed_buffer[i] = np.transpose(copy_seed / 255.0 - 0.5, (2, 0, 1))
175237
self.ready.add(i)
176238

177239
if len(self.ready) >= args.batch_size:

0 commit comments

Comments
 (0)