Commit ac0bf43

Merge pull request #52 from dennisprangle/custom_activation
Allow scales >1.001 in AffineCouplingTransform (fixes #49)
2 parents 639c3a7 + cd6e3b7 commit ac0bf43

File tree: 2 files changed (+26, -2)


nflows/transforms/coupling.py

Lines changed: 14 additions & 2 deletions
@@ -3,6 +3,7 @@
 
 import numpy as np
 import torch
+from torch.nn.functional import softplus
 
 from nflows.transforms import splines
 from nflows.transforms.base import Transform
@@ -213,16 +214,27 @@ class AffineCouplingTransform(CouplingTransform):
 
     Reference:
     > L. Dinh et al., Density estimation using Real NVP, ICLR 2017.
+
+    The user should supply `scale_activation`, the final activation function in the neural network producing the scale tensor.
+    Two options are predefined in the class.
+    `DEFAULT_SCALE_ACTIVATION` preserves backwards compatibility but only produces scales <= 1.001.
+    `GENERAL_SCALE_ACTIVATION` produces scales <= 3, which is more useful in general applications.
     """
 
+    DEFAULT_SCALE_ACTIVATION = lambda x : torch.sigmoid(x + 2) + 1e-3
+    GENERAL_SCALE_ACTIVATION = lambda x : (softplus(x) + 1e-3).clamp(0, 3)
+
+    def __init__(self, mask, transform_net_create_fn, unconditional_transform=None, scale_activation=DEFAULT_SCALE_ACTIVATION):
+        self.scale_activation = scale_activation
+        super().__init__(mask, transform_net_create_fn, unconditional_transform)
+
     def _transform_dim_multiplier(self):
         return 2
 
     def _scale_and_shift(self, transform_params):
         unconstrained_scale = transform_params[:, self.num_transform_features:, ...]
         shift = transform_params[:, : self.num_transform_features, ...]
-        # scale = (F.softplus(unconstrained_scale) + 1e-3).clamp(0, 3)
-        scale = torch.sigmoid(unconstrained_scale + 2) + 1e-3
+        scale = self.scale_activation(unconstrained_scale)
         return scale, shift
 
     def _coupling_transform_forward(self, inputs, transform_params):
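
Usage illustration (not part of the commit): the sketch below constructs an AffineCouplingTransform with the new scale_activation keyword set to GENERAL_SCALE_ACTIVATION. The mask, the ResidualNet-based coupling net, and the tensor shapes are arbitrary example choices, not values taken from the diff.

    import torch
    from nflows.nn.nets import ResidualNet
    from nflows.transforms.coupling import AffineCouplingTransform

    # Illustrative sketch only; mask, network size, and shapes are example choices.
    features = 4
    mask = torch.tensor([1, 0, 1, 0])  # alternating binary mask splitting features into two halves

    def create_net(in_features, out_features):
        # Any module mapping in_features -> out_features can serve as the coupling net.
        return ResidualNet(in_features, out_features, hidden_features=32)

    transform = AffineCouplingTransform(
        mask,
        create_net,
        scale_activation=AffineCouplingTransform.GENERAL_SCALE_ACTIVATION,
    )

    inputs = torch.randn(16, features)
    outputs, logabsdet = transform(inputs)  # scales can now reach up to 3 rather than ~1.001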

tests/transforms/coupling_test.py

Lines changed: 12 additions & 0 deletions
@@ -71,6 +71,18 @@ def test_forward_inverse_are_consistent(self):
             with self.subTest(shape=shape):
                 self.assert_forward_inverse_are_consistent(transform, inputs)
 
+    def test_scale_activation_has_an_effect(self):
+        for shape in self.shapes:
+            inputs = torch.randn(batch_size, *shape)
+            transform, mask = create_coupling_transform(
+                coupling.AffineCouplingTransform, shape
+            )
+            outputs_default, logabsdet_default = transform(inputs)
+            transform.scale_activation = coupling.AffineCouplingTransform.GENERAL_SCALE_ACTIVATION
+            outputs_general, logabsdet_general = transform(inputs)
+            with self.subTest(shape=shape):
+                self.assertNotEqual(outputs_default, outputs_general)
+                self.assertNotEqual(logabsdet_default, logabsdet_general)
 
 class AdditiveTransformTest(TransformTest):
     shapes = [[20], [2, 4, 4]]
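
The new test relies on the two predefined activations producing genuinely different scales from the same pre-activations. As a standalone check (illustration only, not part of the commit), the snippet below evaluates both on a large positive input, where the default saturates near 1.001 and the general option saturates at 3:

    import torch
    from torch.nn.functional import softplus

    # Illustration only: compare the saturation behaviour of the two predefined activations.
    x = torch.tensor([10.0])
    default_scale = torch.sigmoid(x + 2) + 1e-3        # sigmoid is bounded by 1, so scales stay below 1.001
    general_scale = (softplus(x) + 1e-3).clamp(0, 3)   # clamped softplus allows scales up to 3
    print(default_scale.item(), general_scale.item())  # approximately 1.001 and 3.0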
