Skip to content

Commit 22ad802

Browse files
committed
change initialization control of linear RealNVP from a True/False flag to a scale factor
1 parent 129f796 commit 22ad802

File tree

2 files changed

+7
-9
lines changed

2 files changed

+7
-9
lines changed

INN/CouplingModels/RealNVP/linear.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44

55

66
class NonlinearRealNVP(INNModule):
7-
def __init__(self, dim=None, f_log_s=None, f_t=None, k=4, mask=None, clip=1, activation_fn=None):
7+
def __init__(self, dim=None, f_log_s=None, f_t=None, k=4, mask=None, clip=1, activation_fn=None, scale=0.01):
88
super(NonlinearRealNVP, self).__init__()
99
self.dim = dim
1010

1111
if f_log_s is None:
12-
f_log_s = coupling_utils.default_nonlinear_net(dim, k, activation_fn, zero=True)
12+
f_log_s = coupling_utils.default_nonlinear_net(dim, k, activation_fn, scale=scale)
1313
if f_t is None:
14-
f_t = coupling_utils.default_nonlinear_net(dim, k, activation_fn)
14+
f_t = coupling_utils.default_nonlinear_net(dim, k, activation_fn, scale=scale*10)
1515

1616
self.net = utils.combined_real_nvp(dim, f_log_s, f_t, mask, clip)
1717

INN/CouplingModels/utils.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33

44

55
class default_nonlinear_net(nn.Module):
6-
def __init__(self, dim, k, activation_fn=None, zero=False):
6+
def __init__(self, dim, k, activation_fn=None, scale=0.01):
77
super(default_nonlinear_net, self).__init__()
8-
self._zero_init = zero
8+
self.scale = scale
99
self.activation_fn = activation_fn
1010
self.net = self.default_net(dim, k, activation_fn)
1111

@@ -25,10 +25,8 @@ def init_weights(self, m):
2525
if type(m) == nn.Linear:
2626
# doing xavier initialization
2727
# NOTE: Kaiming initialization will make the output too high, which leads to nan
28-
if self._zero_init:
29-
torch.nn.init.zeros_(m.weight.data)
30-
else:
31-
torch.nn.init.xavier_normal_(m.weight.data)
28+
torch.nn.init.xavier_normal_(m.weight.data)
29+
m.weight.data *= self.scale
3230
torch.nn.init.zeros_(m.bias.data)
3331

3432
def forward(self, x):

0 commit comments

Comments (0)