Skip to content
This repository was archived by the owner on Jan 2, 2021. It is now read-only.

Commit da354f3

Browse files
committed
Remove reflection layer, important but not the bottleneck right now and quite slower.
1 parent 9bba985 commit da354f3

File tree

1 file changed

+1
-30
lines changed

1 file changed

+1
-30
lines changed

enhance.py

Lines changed: 1 addition & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -242,34 +242,6 @@ def get_output_for(self, input, deterministic=False, **kwargs):
242242
return out
243243

244244

245-
class ReflectLayer(lasagne.layers.Layer):
246-
"""Based on more code by ajbrock: https://gist.github.com/ajbrock/a3858c26282d9731191901b397b3ce9f
247-
"""
248-
249-
def __init__(self, incoming, pad, batch_ndim=2, **kwargs):
250-
super(ReflectLayer, self).__init__(incoming, **kwargs)
251-
self.pad = (pad, pad)
252-
self.batch_ndim = batch_ndim
253-
254-
def get_output_shape_for(self, input_shape):
255-
output_shape = list(input_shape)
256-
for k, p in enumerate(self.pad):
257-
if output_shape[k + self.batch_ndim] is None: continue
258-
l, r = p, p
259-
output_shape[k + self.batch_ndim] += l + r
260-
return tuple(output_shape)
261-
262-
def get_output_for(self, x, **kwargs):
263-
out = T.zeros(self.get_output_shape_for(x.shape))
264-
p0, p1 = self.pad
265-
out = T.set_subtensor(out[:,:,:p0,p1:-p1], x[:,:,p0:0:-1,:])
266-
out = T.set_subtensor(out[:,:,-p0:,p1:-p1], x[:,:,-2:-(2+p0):-1,:])
267-
out = T.set_subtensor(out[:,:,p0:-p0,p1:-p1], x)
268-
out = T.set_subtensor(out[:,:,:,:p1], out[:,:,:,(2*p1):p1:-1])
269-
out = T.set_subtensor(out[:,:,:,-p1:], out[:,:,:,-(p1+2):-(2*p1+2):-1])
270-
return out
271-
272-
273245
class Model(object):
274246

275247
def __init__(self):
@@ -296,8 +268,7 @@ def last_layer(self):
296268
return list(self.network.values())[-1]
297269

298270
def make_layer(self, name, input, units, filter_size=(3,3), stride=(1,1), pad=(1,1), alpha=0.25):
299-
reflected = ReflectLayer(input, pad=pad[0]) if pad[0] > 0 else input
300-
conv = ConvLayer(reflected, units, filter_size, stride=stride, pad=(0,0), nonlinearity=None)
271+
conv = ConvLayer(input, units, filter_size, stride=stride, pad=(0,0), nonlinearity=None)
301272
prelu = lasagne.layers.ParametricRectifierLayer(conv, alpha=lasagne.init.Constant(alpha))
302273
self.network[name+'x'] = conv
303274
self.network[name+'>'] = prelu

0 commit comments

Comments
 (0)