@@ -271,6 +271,30 @@ def get_output_for(self, input, deterministic=False, **kwargs):
271271 return out
272272
273273
class ReflectLayer(lasagne.layers.Layer):
    """Reflection-padding layer for 4D tensors laid out as (batch, channels, rows, cols).

    Pads the two trailing spatial axes by mirroring the input about its edges
    (the edge row/column itself is not duplicated), so a later convolution with
    ``pad=0`` sees reflected context instead of zeros.

    Based on more code by ajbrock:
    https://gist.github.com/ajbrock/a3858c26282d9731191901b397b3ce9f

    Parameters
    ----------
    incoming : lasagne.layers.Layer
        The layer feeding into this one.
    pad : tuple of int
        Per-spatial-axis padding amounts. ``get_output_for`` unpacks exactly
        two values (rows, cols), and its slicing only works when both are > 0 —
        the caller in this file guards ``pad[0] > 0 and pad[1] > 0`` before
        constructing this layer.
    batch_ndim : int
        Number of leading non-spatial axes (default 2: batch and channels).
        NOTE(review): only ``get_output_shape_for`` honors this;
        ``get_output_for`` hard-codes the batch_ndim == 2 / 4D layout.
    """

    def __init__(self, incoming, pad, batch_ndim=2, **kwargs):
        super(ReflectLayer, self).__init__(incoming, **kwargs)
        self.pad = pad
        self.batch_ndim = batch_ndim

    def get_output_shape_for(self, input_shape):
        """Return input_shape with each padded spatial axis grown by 2*pad[k]."""
        output_shape = list(input_shape)
        for k, p in enumerate(self.pad):
            # Unknown (symbolic) dimensions stay None rather than None + int.
            if output_shape[k + self.batch_ndim] is None: continue
            output_shape[k + self.batch_ndim] += p * 2
        return tuple(output_shape)

    def get_output_for(self, x, **kwargs):
        """Build the reflection-padded tensor via successive set_subtensor writes.

        The write order below matters: rows are reflected first (into the
        horizontally-centered band ``p1:-p1``), then the column reflections
        read back from ``out`` so the corner regions get correctly
        mirrored padded rows, not zeros.
        """
        out = T.zeros(self.get_output_shape_for(x.shape))
        p0, p1 = self.pad
        # Top pad: input rows p0..1 in reverse (edge row 0 not duplicated).
        out = T.set_subtensor(out[:,:,:p0,p1:-p1], x[:,:,p0:0:-1,:])
        # Bottom pad: input rows H-2 down to H-1-p0 in reverse.
        out = T.set_subtensor(out[:,:,-p0:,p1:-p1], x[:,:,-2:-(2+p0):-1,:])
        # Center: the original input.
        out = T.set_subtensor(out[:,:,p0:-p0,p1:-p1], x)
        # Left pad: mirror columns 2*p1..p1+1 of the (already row-padded) out.
        out = T.set_subtensor(out[:,:,:,:p1], out[:,:,:,(2*p1):p1:-1])
        # Right pad: mirror the columns just inside the right edge, reversed.
        out = T.set_subtensor(out[:,:,:,-p1:], out[:,:,:,-(p1+2):-(2*p1+2):-1])
        return out
275299
276300
@@ -305,7 +329,9 @@ def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1
305329 extra = {'W' : self .network [clone + 'x' ].W , 'b' : self .network [clone + 'x' ].b }
306330 else :
307331 extra = {}
308- conv = ConvLayer (input , units , filter_size , stride = stride , pad = pad , nonlinearity = None , ** extra )
332+
333+ padded = ReflectLayer (input , pad ) if pad [0 ] > 0 and pad [1 ] > 0 else input
334+ conv = ConvLayer (padded , units , filter_size , stride = stride , pad = 0 , nonlinearity = None , ** extra )
309335 self .network [name + 'x' ] = conv
310336
311337 if reuse and clone + '>' in self .network :
0 commit comments