Skip to content

Commit 025ca4b

Browse files
committed
add ResizeFeatures function
1 parent 94fba5f commit 025ca4b

File tree

6 files changed

+231
-37
lines changed

6 files changed

+231
-37
lines changed

setup.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
long_description = fh.read()
55

66
setuptools.setup(
7-
name="INNLab", # Replace with your own username
8-
version="0.0.2",
7+
name="INNLab",
8+
version="0.1.0",
99
author="Yanbo Zhang",
1010
author_email="zhangybspm@gmail.com",
1111
description="A package for invertible neural networks",

src/INN/INN.py

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -159,9 +159,10 @@ def PixelUnshuffle(self, x):
159159

160160

161161
class BatchNorm1d(nn.BatchNorm1d, INNAbstract.INNModule):
162-
def __init__(self, dim):
162+
def __init__(self, dim, requires_grad=False):
163163
INNAbstract.INNModule.__init__(self)
164164
nn.BatchNorm1d.__init__(self, num_features=dim, affine=False)
165+
self.requires_grad = requires_grad
165166

166167
def forward(self, x, log_p=0, log_det_J=0):
167168

@@ -171,7 +172,9 @@ def forward(self, x, log_p=0, log_det_J=0):
171172
var = self.running_var # [dim]
172173
else:
173174
# if in training
174-
var = torch.var(x, dim=0, unbiased=False).detach() # [dim]
175+
var = torch.var(x, dim=0, unbiased=False)#.detach() # [dim]
176+
if not self.requires_grad:
177+
var = var.detach()
175178

176179
x = super(BatchNorm1d, self).forward(x)
177180

@@ -211,11 +214,11 @@ def inverse(self, y, **args):
211214

212215
class RealNVP(INNAbstract.INNModule):
213216

214-
def __init__(self, dim=None, f_log_s=None, f_t=None, k=4, mask=None, clip=1):
217+
def __init__(self, dim=None, f_log_s=None, f_t=None, k=4, mask=None, clip=1, activation_fn=None):
215218
super(RealNVP, self).__init__()
216219
if (f_log_s is None) and (f_t is None):
217-
log_s = utilities.default_net(dim, k)#self.default_net(dim, k)
218-
t = utilities.default_net(dim, k)#self.default_net(dim, k)
220+
log_s = utilities.default_net(dim, k, activation_fn)#self.default_net(dim, k)
221+
t = utilities.default_net(dim, k, activation_fn)#self.default_net(dim, k)
219222
self.net = utilities.combined_real_nvp(dim, log_s, t, mask, clip)
220223
else:
221224
self.net = utilities.combined_real_nvp(dim, f_log_s, f_t, mask, clip)
@@ -234,11 +237,11 @@ def inverse(self, y, **args):
234237

235238
class NICE(INNAbstract.INNModule):
236239

237-
def __init__(self, dim=None, m=None, mask=None, k=4):
240+
def __init__(self, dim=None, m=None, mask=None, k=4, activation_fn=None):
238241
super(NICE, self).__init__()
239242

240243
if m is None:
241-
m_ = utilities.default_net(dim, k)
244+
m_ = utilities.default_net(dim, k, activation_fn)
242245
self.net = utilities.NICE(dim, m=m_, mask=mask)
243246
else:
244247
self.net = utilities.NICE(dim, m=m, mask=mask)
@@ -268,18 +271,18 @@ class Nonlinear(INNAbstract.INNModule):
268271
'''
269272
Nonlinear invertible block
270273
'''
271-
def __init__(self, dim, method='NICE', m=None, mask=None, k=4, **args):
274+
def __init__(self, dim, method='RealNVP', m=None, mask=None, k=4, activation_fn=None, **args):
272275
super(Nonlinear, self).__init__()
273276

274277
self.method = method
275278
if method == 'NICE':
276-
self.block = NICE(dim, m=m, mask=mask, k=k)
279+
self.block = NICE(dim, m=m, mask=mask, k=k, activation_fn=activation_fn)
277280
if method == 'RealNVP':
278281
clip = _default_dict('clip', args, 1)
279282
f_log_s = _default_dict('f_log_s', args, None)
280283
f_t = _default_dict('f_t', args, None)
281284

282-
self.block = RealNVP(dim=dim, f_log_s=f_log_s, f_t=f_t, k=k, mask=mask, clip=clip)
285+
self.block = RealNVP(dim=dim, f_log_s=f_log_s, f_t=f_t, k=k, mask=mask, clip=clip, activation_fn=activation_fn)
283286
if method == 'iResNet':
284287
g = _default_dict('g', args, None)
285288
beta = _default_dict('beta', args, 0.8)
@@ -293,4 +296,70 @@ def forward(self, x, log_p0=0, log_det_J=0):
293296
return self.block(x, log_p0, log_det_J)
294297

295298
def inverse(self, y, **args):
296-
return self.block.inverse(y, **args)
299+
return self.block.inverse(y, **args)
300+
301+
class ResizeFeatures(INNAbstract.INNModule):
    '''
    Split an n-d input into kept features and dropped features.

    Supports linear inputs of shape [feature_in] and batched /
    multi-channel inputs of shape [batch_size, feature_in, *].
    The forward pass keeps the first `feature_out` features and scores the
    remaining `feature_in - feature_out` features under `dist`; the inverse
    pass rebuilds a full-size input by re-sampling the dropped features.
    '''
    def __init__(self, feature_in, feature_out, dist='normal'):
        '''
        feature_in  -- number of features the forward input carries
        feature_out -- number of features kept after resizing
        dist        -- 'normal' or an INNAbstract.Distribution instance used
                       to score / re-sample the dropped features

        Raises ValueError if `dist` is neither 'normal' nor a Distribution.
        '''
        super(ResizeFeatures, self).__init__()
        self.feature_in = feature_in
        self.feature_out = feature_out

        if dist == 'normal':
            self.dist = utilities.NormalDistribution()
        elif isinstance(dist, INNAbstract.Distribution):
            self.dist = dist
        else:
            # Fail loudly here instead of leaving self.dist unset, which
            # would otherwise surface much later as a confusing
            # AttributeError inside forward()/inverse().
            raise ValueError(f"dist must be 'normal' or an INNAbstract.Distribution, got {dist!r}")

    def resize(self, x, feature_in, feature_out):
        '''
        Split x along the feature axis into (kept, dropped).

        x has two kinds of shapes:
            1. [feature_in]
            2. [batch_size, feature_in, *]
        '''
        if len(x.shape) == 0:
            # A scalar has no feature axis; without this guard the method
            # would crash with UnboundLocalError on y, z below.
            raise Exception(f'Expect to get {self.feature_in} features, but got a 0-d input.')
        if len(x.shape) == 1:
            # [feature_in]
            if x.shape[0] != self.feature_in:
                raise Exception(f'Expect to get {self.feature_in} features, but got {x.shape[0]}.')
            y, z = x[:feature_out], x[feature_out:]
        else:
            # [batch_size, feature_in, *]
            if x.shape[1] != self.feature_in:
                raise Exception(f'Expect to get {self.feature_in} features, but got {x.shape[1]}.')
            y, z = x[:, :feature_out], x[:, feature_out:]

        return y, z

    def forward(self, x, log_p0=0, log_det_J=0):
        '''
        Drop the trailing features of x; when probability tracking is on,
        fold their log-probability under self.dist into log_p0.
        The Jacobian term is passed through unchanged.
        '''
        x, z = self.resize(x, self.feature_in, self.feature_out)
        # NOTE(review): self.compute_p is not set in this class -- presumably
        # provided by INNAbstract.INNModule; confirm against the base class.
        if self.compute_p:
            p = self.dist.logp(z)
            return x, log_p0 + p, log_det_J
        else:
            return x

    def inverse(self, y, **args):
        '''
        Rebuild a full-size input by sampling the dropped features from
        self.dist and concatenating them after y.

        y has two kinds of shapes:
            1. [feature_out]
            2. [batch_size, feature_out, *]
        '''
        if len(y.shape) == 1:
            # [feature_out]
            if y.shape[0] != self.feature_out:
                raise Exception(f'Expect to get {self.feature_out} features, but got {y.shape[0]}.')
            z = self.dist.sample(self.feature_in-self.feature_out).to(y.device)
            y = torch.cat([y, z])

        if len(y.shape) >= 2:
            # [batch_size, feature_out, *] -> sample z with the same batch
            # and trailing dims, but feature_in - feature_out channels.
            if y.shape[1] != self.feature_out:
                raise Exception(f'Expect to get {self.feature_out} features, but got {y.shape[1]}.')
            shape = list(y.shape)
            shape[1] = self.feature_in-self.feature_out
            z = self.dist.sample(shape).to(y.device)
            y = torch.cat([y, z], dim=1)

        return y

src/INN/INNAbstract.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,4 +140,21 @@ def forward(self, x, log_p0, log_det_J):
140140
return self.PixelUnshuffle(x)
141141

142142
def inverse(self, y, num_iter=100):
143-
return self.PixelShuffle(y)
143+
return self.PixelShuffle(y)
144+
145+
146+
class Distribution(nn.Module):
    """Abstract base class for distributions used by invertible modules.

    Subclasses must implement `logp` (log-probability of a sample) and
    `sample` (draw from the distribution). Calling the module directly
    forwards to `logp`.
    """

    def __init__(self):
        super(Distribution, self).__init__()

    def logp(self, x):
        # Log-probability of x; concrete distributions must override.
        raise NotImplementedError('logp() not implemented')

    def sample(self, shape):
        # Draw a sample of the given shape; concrete distributions must override.
        raise NotImplementedError('sample() not implemented')

    def forward(self, x):
        # nn.Module entry point: evaluating the module means scoring x.
        return self.logp(x)

src/INN/utilities.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def forward(self, x):
138138
return x
139139

140140

141-
class NormalDistribution(nn.Module):
141+
class NormalDistribution(INNAbstract.Distribution):
142142
'''
143143
Generate normal distribution and compute log probablity
144144
'''
@@ -166,10 +166,6 @@ def logp(self, x):
166166
def sample(self, shape):
167167
return torch.randn(shape)
168168

169-
def forward(self, x):
170-
x = self.logp(x)
171-
172-
return x
173169

174170
def permutation_matrix(dim):
175171
# generate a permuation matrix
@@ -396,21 +392,38 @@ def inverse(self, y):
396392

397393

398394
class default_net(nn.Module):
399-
def __init__(self, dim, k):
395+
def __init__(self, dim, k, activation_fn=None):
400396
super(default_net, self).__init__()
401-
self.net = self.default_net(dim, k)
397+
self.activation_fn = activation_fn
398+
self.net = self.default_net(dim, k, activation_fn)
402399

403-
def default_net(self, dim, k):
404-
block = nn.Sequential(nn.Linear(dim, k * dim), nn.LeakyReLU(),
405-
nn.Linear(k * dim, k * dim), nn.LeakyReLU(),
400+
def default_net(self, dim, k, activation_fn):
401+
if activation_fn == None:
402+
ac = nn.LeakyReLU
403+
else:
404+
ac = activation_fn
405+
406+
block = nn.Sequential(nn.Linear(dim, k * dim), ac(),
407+
nn.Linear(k * dim, k * dim), ac(),
406408
nn.Linear(k * dim, dim))
407409
block.apply(self.init_weights)
408410
return block
409411

410412
def init_weights(self, m):
413+
nonlinearity = 'leaky_relu' # set to leaky_relu by default
414+
415+
if self.activation_fn is nn.ReLU:
416+
nonlinearity = 'leaky_relu'
417+
if self.activation_fn is nn.SELU:
418+
nonlinearity = 'selu'
419+
if self.activation_fn is nn.Tanh:
420+
nonlinearity = 'tanh'
421+
if self.activation_fn is nn.Sigmoid:
422+
nonlinearity = 'sigmoid'
423+
411424
if type(m) == nn.Linear:
412425
# doing Kaiming initialization
413-
torch.nn.init.kaiming_normal_(m.weight.data, nonlinearity='leaky_relu')
426+
torch.nn.init.kaiming_normal_(m.weight.data, nonlinearity=nonlinearity)
414427
torch.nn.init.zeros_(m.bias.data)
415428

416429
def forward(self, x):

tests/quick_tests.ipynb

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,96 @@
461461
"bn(x)"
462462
]
463463
},
464+
{
465+
"cell_type": "code",
466+
"execution_count": 24,
467+
"metadata": {
468+
"ExecuteTime": {
469+
"end_time": "2021-04-27T00:31:48.829977Z",
470+
"start_time": "2021-04-27T00:31:48.826448Z"
471+
}
472+
},
473+
"outputs": [],
474+
"source": [
475+
"x = torch.randn((3,3,3))"
476+
]
477+
},
478+
{
479+
"cell_type": "code",
480+
"execution_count": 29,
481+
"metadata": {
482+
"ExecuteTime": {
483+
"end_time": "2021-04-27T00:32:17.029541Z",
484+
"start_time": "2021-04-27T00:32:17.023938Z"
485+
}
486+
},
487+
"outputs": [
488+
{
489+
"data": {
490+
"text/plain": [
491+
"[3, 3, 3]"
492+
]
493+
},
494+
"execution_count": 29,
495+
"metadata": {},
496+
"output_type": "execute_result"
497+
}
498+
],
499+
"source": [
500+
"list(x.shape)"
501+
]
502+
},
503+
{
504+
"cell_type": "code",
505+
"execution_count": 26,
506+
"metadata": {
507+
"ExecuteTime": {
508+
"end_time": "2021-04-27T00:31:57.709249Z",
509+
"start_time": "2021-04-27T00:31:57.701963Z"
510+
}
511+
},
512+
"outputs": [
513+
{
514+
"ename": "AttributeError",
515+
"evalue": "attribute 'shape' of 'torch._C._TensorBase' objects is not writable",
516+
"output_type": "error",
517+
"traceback": [
518+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
519+
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
520+
"\u001b[0;32m<ipython-input-26-2924dc973659>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m5\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
521+
"\u001b[0;31mAttributeError\u001b[0m: attribute 'shape' of 'torch._C._TensorBase' objects is not writable"
522+
]
523+
}
524+
],
525+
"source": [
526+
"x.shape = 5"
527+
]
528+
},
529+
{
530+
"cell_type": "code",
531+
"execution_count": 31,
532+
"metadata": {
533+
"ExecuteTime": {
534+
"end_time": "2021-04-27T00:39:07.586468Z",
535+
"start_time": "2021-04-27T00:39:07.581351Z"
536+
}
537+
},
538+
"outputs": [
539+
{
540+
"data": {
541+
"text/plain": [
542+
"[5]"
543+
]
544+
},
545+
"execution_count": 31,
546+
"metadata": {},
547+
"output_type": "execute_result"
548+
}
549+
],
550+
"source": [
551+
"[1,2,3,4,5][4:]"
552+
]
553+
},
464554
{
465555
"cell_type": "code",
466556
"execution_count": null,

0 commit comments

Comments
 (0)