1+ import torch
2+
3+ # source: https://github.com/serkansulun/pytorch-pixelshuffle1d/blob/master/pixelshuffle1d.py
4+ # "long" and "short" denote longer and shorter samples
5+
6+ class PixelShuffle1D (torch .nn .Module ):
7+ """
8+ 1D pixel shuffler. https://arxiv.org/pdf/1609.05158.pdf
9+ Upscales sample length, downscales channel length
10+ "short" is input, "long" is output
11+ """
12+ def __init__ (self , upscale_factor ):
13+ super (PixelShuffle1D , self ).__init__ ()
14+ self .upscale_factor = upscale_factor
15+
16+ def forward (self , x ):
17+ batch_size = x .shape [0 ]
18+ short_channel_len = x .shape [1 ]
19+ short_width = x .shape [2 ]
20+
21+ long_channel_len = short_channel_len // self .upscale_factor
22+ long_width = self .upscale_factor * short_width
23+
24+ x = x .contiguous ().view ([batch_size , self .upscale_factor , long_channel_len , short_width ])
25+ x = x .permute (0 , 2 , 3 , 1 ).contiguous ()
26+ x = x .view (batch_size , long_channel_len , long_width )
27+
28+ return x
29+
30+ class PixelUnshuffle1D (torch .nn .Module ):
31+ """
32+ Inverse of 1D pixel shuffler
33+ Upscales channel length, downscales sample length
34+ "long" is input, "short" is output
35+ """
36+ def __init__ (self , downscale_factor ):
37+ super (PixelUnshuffle1D , self ).__init__ ()
38+ self .downscale_factor = downscale_factor
39+
40+ def forward (self , x ):
41+ batch_size = x .shape [0 ]
42+ long_channel_len = x .shape [1 ]
43+ long_width = x .shape [2 ]
44+
45+ short_channel_len = long_channel_len * self .downscale_factor
46+ short_width = long_width // self .downscale_factor
47+
48+ x = x .contiguous ().view ([batch_size , long_channel_len , short_width , self .downscale_factor ])
49+ x = x .permute (0 , 3 , 1 , 2 ).contiguous ()
50+ x = x .view ([batch_size , short_channel_len , short_width ])
51+ return x
0 commit comments