Skip to content
This repository was archived by the owner on Jan 2, 2021. It is now read-only.

Commit 410b0db

Browse files
committed
Now using padding based on reflection rather than zero padding. It's slower but improves quality.
1 parent 62a200f commit 410b0db

File tree

1 file changed

+27
-1
lines changed

1 file changed

+27
-1
lines changed

enhance.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,30 @@ def get_output_for(self, input, deterministic=False, **kwargs):
271271
return out
272272

273273

274+
class ReflectLayer(lasagne.layers.Layer):
275+
"""Based on more code by ajbrock: https://gist.github.com/ajbrock/a3858c26282d9731191901b397b3ce9f
276+
"""
277+
278+
def __init__(self, incoming, pad, batch_ndim=2, **kwargs):
279+
super(ReflectLayer, self).__init__(incoming, **kwargs)
280+
self.pad = pad
281+
self.batch_ndim = batch_ndim
282+
283+
def get_output_shape_for(self, input_shape):
284+
output_shape = list(input_shape)
285+
for k, p in enumerate(self.pad):
286+
if output_shape[k + self.batch_ndim] is None: continue
287+
output_shape[k + self.batch_ndim] += p * 2
288+
return tuple(output_shape)
289+
290+
def get_output_for(self, x, **kwargs):
291+
out = T.zeros(self.get_output_shape_for(x.shape))
292+
p0, p1 = self.pad
293+
out = T.set_subtensor(out[:,:,:p0,p1:-p1], x[:,:,p0:0:-1,:])
294+
out = T.set_subtensor(out[:,:,-p0:,p1:-p1], x[:,:,-2:-(2+p0):-1,:])
295+
out = T.set_subtensor(out[:,:,p0:-p0,p1:-p1], x)
296+
out = T.set_subtensor(out[:,:,:,:p1], out[:,:,:,(2*p1):p1:-1])
297+
out = T.set_subtensor(out[:,:,:,-p1:], out[:,:,:,-(p1+2):-(2*p1+2):-1])
274298
return out
275299

276300

@@ -305,7 +329,9 @@ def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1
305329
extra = {'W': self.network[clone+'x'].W, 'b': self.network[clone+'x'].b}
306330
else:
307331
extra = {}
308-
conv = ConvLayer(input, units, filter_size, stride=stride, pad=pad, nonlinearity=None, **extra)
332+
333+
padded = ReflectLayer(input, pad) if pad[0] > 0 and pad[1] > 0 else input
334+
conv = ConvLayer(padded, units, filter_size, stride=stride, pad=0, nonlinearity=None, **extra)
309335
self.network[name+'x'] = conv
310336

311337
if reuse and clone+'>' in self.network:

0 commit comments

Comments
 (0)