|
14 | 14 | # without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. |
15 | 15 | # |
16 | 16 |
|
17 | | -__version__ = '0.3' |
| 17 | +__version__ = '0.4' |
18 | 18 |
|
19 | 19 | import io |
20 | 20 | import os |
|
47 | 47 | add_arg('--train-blur', default=[], nargs='+', type=int, help='Sigma value for gaussian blur, min/max.') |
48 | 48 | add_arg('--train-noise', default=None, type=float, help='Distribution for gaussian noise preprocess.') |
49 | 49 | add_arg('--train-jpeg', default=[], nargs='+', type=int, help='JPEG compression level, specify min/max.') |
    | 50 | +add_arg('--train-plugin', default=None, type=str, help='Name of a script in the `plugins/` folder (without .py) that provides pre-processing functions.')
50 | 51 | add_arg('--epochs', default=10, type=int, help='Total number of iterations in training.') |
51 | 52 | add_arg('--epoch-size', default=72, type=int, help='Number of batches trained in an epoch.') |
52 | 53 | add_arg('--save-every', default=10, type=int, help='Save generator after every training epoch.') |
@@ -139,58 +140,83 @@ def __init__(self): |
139 | 140 | self.data_ready = threading.Event() |
140 | 141 | self.data_copied = threading.Event() |
141 | 142 |
|
| 143 | + if args.train_plugin is not None: |
| 144 | + import importlib.util |
| 145 | + spec = importlib.util.spec_from_file_location('enhance.plugin', 'plugins/{}.py'.format(args.train_plugin)) |
| 146 | + plugin = importlib.util.module_from_spec(spec) |
| 147 | + spec.loader.exec_module(plugin) |
| 148 | + |
| 149 | + self.iterate_files = plugin.iterate_files |
| 150 | + self.load_original = plugin.load_original |
| 151 | + self.load_seed = plugin.load_seed |
| 152 | + |
142 | 153 | self.orig_shape, self.seed_shape = args.batch_shape, args.batch_shape // args.zoom |
143 | 154 |
|
144 | 155 | self.orig_buffer = np.zeros((args.buffer_size, 3, self.orig_shape, self.orig_shape), dtype=np.float32) |
145 | 156 | self.seed_buffer = np.zeros((args.buffer_size, 3, self.seed_shape, self.seed_shape), dtype=np.float32) |
146 | 157 | self.files = glob.glob(args.train) |
147 | 158 | if len(self.files) == 0: |
148 | | - error("There were no files found to train from searching for `{}`".format(args.train), |
149 | | - " - Try putting all your images in one folder and using `--train=data/*.jpg`") |
| 159 | + error('There were no files found to train from searching for `{}`'.format(args.train), |
| 160 | + ' - Try putting all your images in one folder and using `--train="data/*.jpg"`') |
150 | 161 |
|
151 | 162 | self.available = set(range(args.buffer_size)) |
152 | 163 | self.ready = set() |
153 | 164 |
|
154 | 165 | self.cwd = os.getcwd() |
155 | 166 | self.start() |
156 | 167 |
|
157 | | - def run(self): |
| 168 | + def iterate_files(self): |
158 |     | -        while True:
    | 169 | +        while self.files:
159 | 170 | random.shuffle(self.files) |
160 | 171 | for f in self.files: |
161 | | - self.add_to_buffer(f) |
| 172 | + yield f |
162 | 173 |
|
163 | | - def add_to_buffer(self, f): |
164 | | - filename = os.path.join(self.cwd, f) |
| 174 | + def load_original(self, filename): |
165 | 175 | try: |
166 | 176 | orig = PIL.Image.open(filename).convert('RGB') |
167 | 177 | scale = 2 ** random.randint(args.train_scales[0], args.train_scales[-1]) |
168 | 178 | if scale > 1 and all(s//scale >= args.batch_shape for s in orig.size): |
169 | 179 | orig = orig.resize((orig.size[0]//scale, orig.size[1]//scale), resample=random.randint(0,3)) |
170 | 180 | if any(s < args.batch_shape for s in orig.size): |
171 | 181 | raise ValueError('Image is too small for training with size {}'.format(orig.size)) |
| 182 | + return scipy.misc.fromimage(orig).astype(np.float32) |
172 | 183 | except Exception as e: |
173 | 184 | warn('Could not load `{}` as image.'.format(filename), |
174 | 185 | ' - Try fixing or removing the file before next run.') |
175 |     | -            self.files.remove(f)
    | 186 | +            self.files.remove(filename)
176 | | - return |
| 187 | + return None |
177 | 188 |
|
178 | | - seed = orig |
| 189 | + def load_seed(self, filename, original, zoom): |
| 190 | + seed = scipy.misc.toimage(original) |
179 | 191 | if len(args.train_blur): |
180 | 192 | seed = seed.filter(PIL.ImageFilter.GaussianBlur(radius=random.randint(args.train_blur[0], args.train_blur[-1]))) |
181 |     | -        if args.zoom > 1:
    | 193 | +        if zoom > 1:
182 | | - seed = seed.resize((orig.size[0]//args.zoom, orig.size[1]//args.zoom), resample=random.randint(0,3)) |
| 194 | + seed = seed.resize((seed.size[0]//zoom, seed.size[1]//zoom), resample=random.randint(0,3)) |
| 195 | + |
183 | 196 | if len(args.train_jpeg) > 0: |
184 | 197 | buffer = io.BytesIO() |
185 |     | -            seed.save(buffer, format='jpeg', quality=random.randrange(args.train_jpeg[0], args.train_jpeg[-1]))
    | 198 | +            seed.save(buffer, format='jpeg', quality=random.randint(args.train_jpeg[0], args.train_jpeg[-1]))
186 | 199 | seed = PIL.Image.open(buffer) |
187 | 200 |
|
188 | | - orig = scipy.misc.fromimage(orig).astype(np.float32) |
189 | 201 | seed = scipy.misc.fromimage(seed).astype(np.float32) |
190 | | - |
191 | 202 | if args.train_noise is not None: |
192 | 203 | seed += scipy.random.normal(scale=args.train_noise, size=(seed.shape[0], seed.shape[1], 1)) |
| 204 | + return seed |
| 205 | + |
| 206 | + def run(self): |
| 207 | + for filename in self.iterate_files(): |
| 208 | + f = os.path.join(self.cwd, filename) |
| 209 | + orig = self.load_original(f) |
| 210 | + if orig is None: continue |
| 211 | + |
| 212 | + seed = self.load_seed(f, orig, args.zoom) |
| 213 | + if seed is None: continue |
| 214 | + |
| 215 | + self.enqueue(orig, seed) |
| 216 | + |
| 217 | + raise ValueError('Insufficient number of files found for training.') |
193 | 218 |
|
| 219 | + def enqueue(self, orig, seed): |
194 | 220 | for _ in range(seed.shape[0] * seed.shape[1] // (args.buffer_fraction * self.seed_shape ** 2)): |
195 | 221 | h = random.randint(0, seed.shape[0] - self.seed_shape) |
196 | 222 | w = random.randint(0, seed.shape[1] - self.seed_shape) |
|
0 commit comments