Skip to content
This repository was archived by the owner on Jan 2, 2021. It is now read-only.

Commit cb55929

Browse files
committed
Add plugin support, refactor loader code accordingly. Use example with --train-plugin=simple and put your images in data/.
1 parent 01a914e commit cb55929

File tree

2 files changed

+54
-12
lines changed

2 files changed

+54
-12
lines changed

enhance.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
1515
#
1616

17-
__version__ = '0.3'
17+
__version__ = '0.4'
1818

1919
import io
2020
import os
@@ -47,6 +47,7 @@
4747
add_arg('--train-blur', default=[], nargs='+', type=int, help='Sigma value for gaussian blur, min/max.')
4848
add_arg('--train-noise', default=None, type=float, help='Distribution for gaussian noise preprocess.')
4949
add_arg('--train-jpeg', default=[], nargs='+', type=int, help='JPEG compression level, specify min/max.')
50+
add_arg('--train-plugin', default=None, type=str, help='Filename for python pre-processing script.')
5051
add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.')
5152
add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.')
5253
add_arg('--save-every', default=10, type=int, help='Save generator after every training epoch.')
@@ -139,58 +140,83 @@ def __init__(self):
139140
self.data_ready = threading.Event()
140141
self.data_copied = threading.Event()
141142

143+
if args.train_plugin is not None:
144+
import importlib.util
145+
spec = importlib.util.spec_from_file_location('enhance.plugin', 'plugins/{}.py'.format(args.train_plugin))
146+
plugin = importlib.util.module_from_spec(spec)
147+
spec.loader.exec_module(plugin)
148+
149+
self.iterate_files = plugin.iterate_files
150+
self.load_original = plugin.load_original
151+
self.load_seed = plugin.load_seed
152+
142153
self.orig_shape, self.seed_shape = args.batch_shape, args.batch_shape // args.zoom
143154

144155
self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32)
145156
self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32)
146157
self.files = glob.glob(args.train)
147158
if len(self.files) == 0:
148-
error("There were no files found to train from searching for `{}`".format(args.train),
149-
" - Try putting all your images in one folder and using `--train=data/*.jpg`")
159+
error('There were no files found to train from searching for `{}`'.format(args.train),
160+
' - Try putting all your images in one folder and using `--train="data/*.jpg"`')
150161

151162
self.available = set(range(args.buffer_size))
152163
self.ready = set()
153164

154165
self.cwd = os.getcwd()
155166
self.start()
156167

157-
def run(self):
168+
def iterate_files(self):
158169
while True:
159170
random.shuffle(self.files)
160171
for f in self.files:
161-
self.add_to_buffer(f)
172+
yield f
162173

163-
def add_to_buffer(self, f):
164-
filename = os.path.join(self.cwd, f)
174+
def load_original(self, filename):
165175
try:
166176
orig = PIL.Image.open(filename).convert('RGB')
167177
scale = 2 ** random.randint(args.train_scales[0], args.train_scales[-1])
168178
if scale > 1 and all(s//scale >= args.batch_shape for s in orig.size):
169179
orig = orig.resize((orig.size[0]//scale, orig.size[1]//scale), resample=random.randint(0,3))
170180
if any(s < args.batch_shape for s in orig.size):
171181
raise ValueError('Image is too small for training with size {}'.format(orig.size))
182+
return scipy.misc.fromimage(orig).astype(np.float32)
172183
except Exception as e:
173184
warn('Could not load `{}` as image.'.format(filename),
174185
' - Try fixing or removing the file before next run.')
175186
self.files.remove(f)
176-
return
187+
return None
177188

178-
seed = orig
189+
def load_seed(self, filename, original, zoom):
190+
seed = scipy.misc.toimage(original)
179191
if len(args.train_blur):
180192
seed = seed.filter(PIL.ImageFilter.GaussianBlur(radius=random.randint(args.train_blur[0], args.train_blur[-1])))
181193
if args.zoom > 1:
182-
seed = seed.resize((orig.size[0]//args.zoom, orig.size[1]//args.zoom), resample=random.randint(0,3))
194+
seed = seed.resize((seed.size[0]//zoom, seed.size[1]//zoom), resample=random.randint(0,3))
195+
183196
if len(args.train_jpeg) > 0:
184197
buffer = io.BytesIO()
185198
seed.save(buffer, format='jpeg', quality=random.randrange(args.train_jpeg[0], args.train_jpeg[-1]))
186199
seed = PIL.Image.open(buffer)
187200

188-
orig = scipy.misc.fromimage(orig).astype(np.float32)
189201
seed = scipy.misc.fromimage(seed).astype(np.float32)
190-
191202
if args.train_noise is not None:
192203
seed += scipy.random.normal(scale=args.train_noise, size=(seed.shape[0], seed.shape[1], 1))
204+
return seed
205+
206+
def run(self):
207+
for filename in self.iterate_files():
208+
f = os.path.join(self.cwd, filename)
209+
orig = self.load_original(f)
210+
if orig is None: continue
211+
212+
seed = self.load_seed(f, orig, args.zoom)
213+
if seed is None: continue
214+
215+
self.enqueue(orig, seed)
216+
217+
raise ValueError('Insufficient number of files found for training.')
193218

219+
def enqueue(self, orig, seed):
194220
for _ in range(seed.shape[0] * seed.shape[1] // (args.buffer_fraction * self.seed_shape ** 2)):
195221
h = random.randint(0, seed.shape[0] - self.seed_shape)
196222
w = random.randint(0, seed.shape[1] - self.seed_shape)

plugins/simple.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import glob
2+
import itertools
3+
4+
import scipy.misc
5+
import scipy.ndimage
6+
7+
8+
def iterate_files():
9+
return itertools.cycle(glob.glob('data/*.jpg'))
10+
11+
def load_original(filename):
12+
return scipy.ndimage.imread(filename, mode='RGB')
13+
14+
def load_seed(filename, original, zoom):
15+
target_shape = (original.shape[0]//zoom, original.shape[1]//zoom)
16+
return scipy.misc.imresize(original, target_shape, interp='bilinear')

0 commit comments

Comments
 (0)