22
33
44class UpSample (df .Module ):
5- def __init__ (self , upsample = (2 ,2 ), axes = [- 2 , - 1 ], output_shape = None ):
5+ def __init__ (self , upsample = (2 ,2 ), axes = [- 2 , - 1 ], output_shape = None , mode = 'repeat' ):
66 """
77 Upsamples an input (nearest-neighbour, repeat) `upsample` times along `axes`.
88
@@ -11,19 +11,35 @@ def __init__(self, upsample=(2,2), axes=[-2, -1], output_shape=None):
1111 If all `axes` should be upsampled by the same factor, `upsample` can also be just that factor.
1212 - `output_shape` can be used to crop the upsampled result to a desired shape, it follows `axes` just like `upsample` does.
1313 (TODO: move this part into a separate `Restrict` module)
14+ - `mode` specifies how upsampling happens. Currently supported are:
15+ - `repeat`: upsample by repeating values, aka "nearest" (the default).
16+ - `perforated`: Put values in top-left corner, i.e. [1,2] becomes [1,0,2,0].
17+ This comes from http://www.brml.org/uploads/tx_sibibtex/281.pdf
1418 """
1519 df .Module .__init__ (self )
1620 self .axes = axes
1721 self .upsample = df .utils .expand (upsample , len (axes ), "upsample factor" )
1822 self .output_shape = df .utils .expand (output_shape , len (axes ), "output shape" )
23+ self .upsample_mode = mode
1924
2025 def symb_forward (self , symb_input ):
2126 """symb_input shape: 2D: (n_input, channels, height, width)
2227 3D: (n_input, channels, depth, height, width)
2328 """
24- res = symb_input
25- for f , ax in zip (self .upsample , self .axes ):
26- res = df .T .repeat (res , f , axis = ax )
29+ if self .upsample_mode == 'repeat' :
30+ res = symb_input
31+ for f , ax in zip (self .upsample , self .axes ):
32+ res = df .T .repeat (res , f , axis = ax )
33+ elif self .upsample_mode == 'perforated' :
34+ shape = list (symb_input .shape )
35+ slices = [slice (None )]* symb_input .ndim
36+ for f , ax in zip (self .upsample , self .axes ):
37+ shape [ax ] *= f
38+ slices [ax ] = slice (None , None , f )
39+ res = df .T .zeros (shape , symb_input .dtype )
40+ res = df .T .set_subtensor (res [tuple (slices )], symb_input )
41+ else :
42+ raise ValueError ("Unsupported upsampling mode '{}'" .format (self .upsample_mode ))
2743
2844 # TODO: move this out to its own `Restrict` module.
2945 if self .output_shape is not None :
0 commit comments