Skip to content

Commit 4099326

Browse files
committed
PicelShuffle1d added
1 parent 051772c commit 4099326

File tree

6 files changed

+808
-70
lines changed

6 files changed

+808
-70
lines changed

src/INN/INN.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import INN.INNAbstract as INNAbstract
1010
import INN.cnn as cnn
1111
import torch.nn.functional as F
12+
import INN.pixel_shuffle_1d as ps
1213

1314
iResNetModule = INNAbstract.iResNetModule
1415

@@ -134,6 +135,24 @@ def PixelUnshuffle(self, x):
134135
return self.unshuffle(x)
135136

136137

138+
class PixelShuffle1d(INNAbstract.PixelShuffleModule):
139+
'''
140+
2d invertible pixel shuffle, using the built-in method
141+
from pytorch. (nn.PixelShuffle, and nn.PixelUnshuffle)
142+
'''
143+
def __init__(self, r):
144+
super(PixelShuffle1d, self).__init__()
145+
self.r = r
146+
self.shuffle = ps.PixelShuffle1D(r)
147+
self.unshuffle = ps.PixelUnshuffle1D(r)
148+
149+
def PixelShuffle(self, x):
150+
return self.shuffle(x)
151+
152+
def PixelUnshuffle(self, x):
153+
return self.unshuffle(x)
154+
155+
137156
class BatchNorm1d(nn.BatchNorm1d, INNAbstract.INNModule):
138157
def __init__(self, dim, requires_grad=True):
139158
INNAbstract.INNModule.__init__(self)

src/INN/INNAbstract.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,14 +131,14 @@ def PixelShuffle(self, x):
131131
def PixelUnshuffle(self, x):
132132
pass
133133

134-
def forward(self, x, log_p0, log_det_J):
134+
def forward(self, x, log_p0=0, log_det_J=0):
135135
# The log(p_0) and log|det J| will not change under this transformation
136136
if self.compute_p:
137137
return self.PixelUnshuffle(x), log_p0, log_det_J
138138
else:
139139
return self.PixelUnshuffle(x)
140140

141-
def inverse(self, y, num_iter=100):
141+
def inverse(self, y, **args):
142142
return self.PixelShuffle(y)
143143

144144

src/INN/pixel_shuffle_1d.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
import torch
2+
3+
# source: https://github.com/serkansulun/pytorch-pixelshuffle1d/blob/master/pixelshuffle1d.py
4+
# "long" and "short" denote longer and shorter samples
5+
6+
class PixelShuffle1D(torch.nn.Module):
7+
"""
8+
1D pixel shuffler. https://arxiv.org/pdf/1609.05158.pdf
9+
Upscales sample length, downscales channel length
10+
"short" is input, "long" is output
11+
"""
12+
def __init__(self, upscale_factor):
13+
super(PixelShuffle1D, self).__init__()
14+
self.upscale_factor = upscale_factor
15+
16+
def forward(self, x):
17+
batch_size = x.shape[0]
18+
short_channel_len = x.shape[1]
19+
short_width = x.shape[2]
20+
21+
long_channel_len = short_channel_len // self.upscale_factor
22+
long_width = self.upscale_factor * short_width
23+
24+
x = x.contiguous().view([batch_size, self.upscale_factor, long_channel_len, short_width])
25+
x = x.permute(0, 2, 3, 1).contiguous()
26+
x = x.view(batch_size, long_channel_len, long_width)
27+
28+
return x
29+
30+
class PixelUnshuffle1D(torch.nn.Module):
31+
"""
32+
Inverse of 1D pixel shuffler
33+
Upscales channel length, downscales sample length
34+
"long" is input, "short" is output
35+
"""
36+
def __init__(self, downscale_factor):
37+
super(PixelUnshuffle1D, self).__init__()
38+
self.downscale_factor = downscale_factor
39+
40+
def forward(self, x):
41+
batch_size = x.shape[0]
42+
long_channel_len = x.shape[1]
43+
long_width = x.shape[2]
44+
45+
short_channel_len = long_channel_len * self.downscale_factor
46+
short_width = long_width // self.downscale_factor
47+
48+
x = x.contiguous().view([batch_size, long_channel_len, short_width, self.downscale_factor])
49+
x = x.permute(0, 3, 1, 2).contiguous()
50+
x = x.view([batch_size, short_channel_len, short_width])
51+
return x

0 commit comments

Comments
 (0)