Skip to content

Commit 348abec

Browse files
authored
Merge pull request #102 from lucasb-eyer/upsample-perforated
Implement "perforated" upsampling mode.
2 parents 295d18b + 9723357 commit 348abec

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

DeepFried2/layers/UpSample.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33

44
class UpSample(df.Module):
5-
def __init__(self, upsample=(2,2), axes=[-2, -1], output_shape=None):
5+
def __init__(self, upsample=(2,2), axes=[-2, -1], output_shape=None, mode='repeat'):
66
"""
77
Upsamples an input (nearest-neighbour, repeat) `upsample` times along `axes`.
88
@@ -11,19 +11,35 @@ def __init__(self, upsample=(2,2), axes=[-2, -1], output_shape=None):
1111
If all `axes` should be upsampled by the same factor, `upsample` can also be just that factor.
1212
- `output_shape` can be used to crop the upsampled result to a desired shape, it follows `axes` just like `upsample` does.
1313
(TODO: move this part into a separate `Restrict` module)
14+
- `mode` specifies how upsampling happens. Currently supported are:
15+
- `repeat`: upsample by repeating values, aka "nearest" (the default).
16+
- `perforated`: Put values in top-left corner, i.e. [1,2] becomes [1,0,2,0].
17+
This comes from http://www.brml.org/uploads/tx_sibibtex/281.pdf
1418
"""
1519
df.Module.__init__(self)
1620
self.axes = axes
1721
self.upsample = df.utils.expand(upsample, len(axes), "upsample factor")
1822
self.output_shape = df.utils.expand(output_shape, len(axes), "output shape")
23+
self.upsample_mode = mode
1924

2025
def symb_forward(self, symb_input):
2126
"""symb_input shape: 2D: (n_input, channels, height, width)
2227
3D: (n_input, channels, depth, height, width)
2328
"""
24-
res = symb_input
25-
for f, ax in zip(self.upsample, self.axes):
26-
res = df.T.repeat(res, f, axis=ax)
29+
if self.upsample_mode == 'repeat':
30+
res = symb_input
31+
for f, ax in zip(self.upsample, self.axes):
32+
res = df.T.repeat(res, f, axis=ax)
33+
elif self.upsample_mode == 'perforated':
34+
shape = list(symb_input.shape)
35+
slices = [slice(None)]*symb_input.ndim
36+
for f, ax in zip(self.upsample, self.axes):
37+
shape[ax] *= f
38+
slices[ax] = slice(None, None, f)
39+
res = df.T.zeros(shape, symb_input.dtype)
40+
res = df.T.set_subtensor(res[tuple(slices)], symb_input)
41+
else:
42+
raise ValueError("Unsupported upsampling mode '{}'".format(self.upsample_mode))
2743

2844
# TODO: move this out to its own `Restrict` module.
2945
if self.output_shape is not None:

0 commit comments

Comments
 (0)