3838add_arg ('--scales' , default = 2 , type = int , help = 'How many times to perform 2x upsampling.' )
3939add_arg ('--model' , default = 'small' , type = str , help = 'Name of the neural network to load/save.' )
4040add_arg ('--train' , default = False , type = str , help = 'File pattern to load for training.' )
41+ add_arg ('--seeds' , default = False , type = str , help = 'File pattern to load for training seeds.' )
4142add_arg ('--epochs' , default = 10 , type = int , help = 'Total number of iterations in training.' )
4243add_arg ('--epoch-size' , default = 72 , type = int , help = 'Number of batches trained in an epoch.' )
4344add_arg ('--save-every' , default = 10 , type = int , help = 'Save generator after every training epoch.' )
@@ -118,6 +119,28 @@ def extend(lst): return itertools.chain(lst, itertools.repeat(lst[-1]))
118119
119120print ('{} - Using the device `{}` for neural computation.{}\n ' .format (ansi .CYAN , theano .config .device , ansi .ENDC ))
120121
122+ def confirm_pairs (list1 , list2 ):
123+ new_list1 = []
124+ new_list2 = []
125+ cur1 = 0
126+ cur2 = 0
127+ len1 = len (list1 )
128+ len2 = len (list2 )
129+ while (cur1 < len1 and cur2 < len2 ):
130+ base1 = os .path .basename (list1 [cur1 ])
131+ base2 = os .path .basename (list2 [cur2 ])
132+ if base1 == base2 :
133+ new_list1 .append (list1 [cur1 ])
134+ new_list2 .append (list2 [cur2 ])
135+ cur1 = cur1 + 1
136+ cur2 = cur2 + 1
137+ elif base1 < base2 :
138+ # continue to look on list1, don't iterate list2
139+ cur1 = cur1 + 1
140+ else :
141+ cur2 = cur2 + 1
142+ print ("List sizes went from {}, {} to {}, {}" .format (len1 , len2 , len (new_list1 ), len (new_list2 )))
143+ return new_list1 , new_list2
121144
122145#======================================================================================================================
123146# Image Processing
@@ -133,7 +156,13 @@ def __init__(self):
133156
134157 self .orig_buffer = np .zeros ((args .buffer_size , 3 , self .orig_shape , self .orig_shape ), dtype = np .float32 )
135158 self .seed_buffer = np .zeros ((args .buffer_size , 3 , self .seed_shape , self .seed_shape ), dtype = np .float32 )
136- self .files = glob .glob (args .train )
159+ self .files = sorted (glob .glob (args .train ))
160+ if args .seeds :
161+ self .seeds = sorted (glob .glob (args .seeds ))
162+ self .files , self .seeds = confirm_pairs (self .files , self .seeds )
163+ else :
164+ self .seeds = False
165+
137166 if len (self .files ) == 0 :
138167 error ("There were no files found to train from searching for `{}`" .format (args .train ),
139168 " - Try putting all your images in one folder and using `--train=data/*.jpg`" )
@@ -146,22 +175,56 @@ def __init__(self):
146175
147176 def run (self ):
148177 while True :
149- random .shuffle (self .files )
178+ indices = list (range (0 , len (self .files )))
179+ random .shuffle (indices )
150180
151- for f in self .files :
181+ for i in indices :
182+ f = self .files [i ]
152183 filename = os .path .join (self .cwd , f )
153184 try :
154185 img = scipy .ndimage .imread (filename , mode = 'RGB' )
155186 except Exception as e :
156187 warn ('Could not load `{}` as image.' .format (filename ),
157188 ' - Try fixing or removing the file before next run.' )
158- files .remove (f )
189+ del self .files [i ]
190+ if self .seeds :
191+ del self .seeds [i ]
159192 continue
160-
193+
194+ # determine seed
195+ if self .seeds :
196+ f = self .seeds [i ]
197+ filename = os .path .join (self .cwd , f )
198+ try :
199+ seed_img = scipy .ndimage .imread (filename , mode = 'RGB' )
200+ except Exception as e :
201+ warn ('Could not load `{}` as seed image.' .format (filename ),
202+ ' - Try fixing or removing the file before next run.' )
203+ del self .files [i ]
204+ del self .seeds [i ]
205+ continue
206+ else :
207+ # synthetic seed
208+ seed_img = scipy .misc .imresize (img ,
209+ float (self .seed_shape ) / float (self .orig_shape ),
210+ interp = 'bilinear' )
211+
161212 for _ in range (args .buffer_similar ):
162- copy = img [:,::- 1 ] if random .choice ([True , False ]) else img
163- h = random .randint (0 , copy .shape [0 ] - self .orig_shape )
164- w = random .randint (0 , copy .shape [1 ] - self .orig_shape )
213+ if random .choice ([True , False ]):
214+ copy = img [:,::- 1 ]
215+ copy_seed = seed_img [:,::- 1 ]
216+ else :
217+ copy = img
218+ copy_seed = seed_img
219+
220+ # compute seed displacement
221+ h = random .randint (0 , copy_seed .shape [0 ] - self .seed_shape )
222+ w = random .randint (0 , copy_seed .shape [1 ] - self .seed_shape )
223+ copy_seed = copy_seed [h :h + self .seed_shape , w :w + self .seed_shape ]
224+
225+ # and matching image displacement
226+ h = int (h * self .orig_shape / self .seed_shape );
227+ w = int (w * self .orig_shape / self .seed_shape );
165228 copy = copy [h :h + self .orig_shape , w :w + self .orig_shape ]
166229
167230 while len (self .available ) == 0 :
@@ -170,8 +233,7 @@ def run(self):
170233
171234 i = self .available .pop ()
172235 self .orig_buffer [i ] = np .transpose (copy / 255.0 - 0.5 , (2 , 0 , 1 ))
173- seed = scipy .misc .imresize (copy , size = (self .seed_shape , self .seed_shape ), interp = 'bilinear' )
174- self .seed_buffer [i ] = np .transpose (seed / 255.0 - 0.5 , (2 , 0 , 1 ))
236+ self .seed_buffer [i ] = np .transpose (copy_seed / 255.0 - 0.5 , (2 , 0 , 1 ))
175237 self .ready .add (i )
176238
177239 if len (self .ready ) >= args .batch_size :
0 commit comments