@@ -3,6 +3,7 @@
 
 import numpy as np
 import torch
+from torch.nn.functional import softplus
 
 from nflows.transforms import splines
 from nflows.transforms.base import Transform
@@ -213,16 +214,27 @@ class AffineCouplingTransform(CouplingTransform):
 
     Reference:
     > L. Dinh et al., Density estimation using Real NVP, ICLR 2017.
+
+    The user should supply `scale_activation`, the final activation function in the neural network producing the scale tensor.
+    Two options are predefined in the class.
+    `DEFAULT_SCALE_ACTIVATION` preserves backwards compatibility but only produces scales <= 1.001.
+    `GENERAL_SCALE_ACTIVATION` produces scales <= 3, which is more useful in general applications.
     """
 
+    DEFAULT_SCALE_ACTIVATION = lambda x: torch.sigmoid(x + 2) + 1e-3
+    GENERAL_SCALE_ACTIVATION = lambda x: (softplus(x) + 1e-3).clamp(0, 3)
+
+    def __init__(self, mask, transform_net_create_fn, unconditional_transform=None, scale_activation=DEFAULT_SCALE_ACTIVATION):
+        self.scale_activation = scale_activation
+        super().__init__(mask, transform_net_create_fn, unconditional_transform)
+
     def _transform_dim_multiplier(self):
         return 2
 
     def _scale_and_shift(self, transform_params):
         unconstrained_scale = transform_params[:, self.num_transform_features:, ...]
         shift = transform_params[:, : self.num_transform_features, ...]
-        # scale = (F.softplus(unconstrained_scale) + 1e-3).clamp(0, 3)
-        scale = torch.sigmoid(unconstrained_scale + 2) + 1e-3
+        scale = self.scale_activation(unconstrained_scale)
         return scale, shift
 
     def _coupling_transform_forward(self, inputs, transform_params):
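Below is a minimal usage sketch (not part of the commit) of how the new `scale_activation` argument could be exercised. The `create_net` factory, mask values, hidden width, and batch shape are illustrative assumptions; only `AffineCouplingTransform`, its two predefined activations, and `ResidualNet` come from nflows itself.

```python
import torch

from nflows.nn.nets import ResidualNet
from nflows.transforms.coupling import AffineCouplingTransform

# Hypothetical transform-net factory: the coupling layer calls it with the
# number of identity features and the required number of output channels
# (2 * num_transform_features for affine coupling: shift and scale).
def create_net(in_features, out_features):
    return ResidualNet(in_features, out_features, hidden_features=32)

# Assumed mask convention: entries > 0 are transformed, entries <= 0 pass
# through unchanged.
mask = torch.tensor([1, 1, 0, 0])

# Backwards-compatible default: sigmoid-based activation, scales <= 1.001.
transform_default = AffineCouplingTransform(mask, create_net)

# Opt in to the wider softplus-based activation, scales <= 3.
transform_general = AffineCouplingTransform(
    mask,
    create_net,
    scale_activation=AffineCouplingTransform.GENERAL_SCALE_ACTIVATION,
)

# Any callable mapping unconstrained net outputs to positive scales works too.
transform_custom = AffineCouplingTransform(
    mask,
    create_net,
    scale_activation=lambda x: torch.nn.functional.softplus(x) + 1e-3,
)

x = torch.randn(8, 4)
y, logabsdet = transform_general(x)  # Transform.forward returns (outputs, logabsdet)
```

Exposing the activation as a constructor argument keeps existing models valid, since the default reproduces the previously hard-coded sigmoid, while letting new models opt into the wider softplus range that was formerly only present as a commented-out line.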