Skip to content

Commit 616c3ee

Browse files
committed
Pull request adding UMNN to this repository. Implemented coupling and autoregressive versions. Simple unit tests pass and the adapted notebook examples work too.
1 parent 75048ff commit 616c3ee

File tree

8 files changed

+287
-1
lines changed

8 files changed

+287
-1
lines changed

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies:
2121
- pip:
2222
- torchtestcase
2323
- -e . # install package in development mode
24+
- umnn
2425
- pytest
2526
- python
2627
- pytorch
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import torch
2+
from UMNN import NeuralIntegral, ParallelNeuralIntegral
3+
import torch.nn as nn
4+
5+
6+
def _flatten(sequence):
7+
flat = [p.contiguous().view(-1) for p in sequence]
8+
return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
9+
10+
11+
class ELUPlus(nn.Module):
    """ELU activation shifted up by one, so its output is strictly positive."""

    def __init__(self):
        super().__init__()
        self.elu = nn.ELU()

    def forward(self, x):
        # ELU maps to (-1, inf); the +1 shift yields (0, inf).
        return 1.0 + self.elu(x)
18+
19+
20+
class IntegrandNet(nn.Module):
    """MLP computing a strictly positive integrand per input dimension.

    For each of the ``in_d`` dimensions the network maps the pair
    ``(x_d, h_d)`` (one scalar plus ``cond_in`` conditioning values) to a
    positive scalar; the final ELUPlus keeps the output in (0, inf) so the
    resulting integral is monotonic in x.
    """

    def __init__(self, hidden, cond_in):
        super().__init__()
        widths = [1 + cond_in] + hidden + [1]
        layers = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.append(nn.Linear(n_in, n_out))
            layers.append(nn.ReLU())
        # Replace the trailing ReLU with a strictly-positive activation.
        layers[-1] = ELUPlus()
        self.net = nn.Sequential(*layers)

    def forward(self, x, h):
        # x: (batch, in_d); h: (batch, cond_in * in_d), conditioning-major
        # layout — TODO confirm against MonotonicNormalizer.forward's reshape.
        batch, dim = x.shape
        joint = torch.cat((x, h), 1)
        # Regroup so each row fed to the net is (x_d, h_{d,0}, ..., h_{d,cond-1}).
        per_dim = joint.view(batch, -1, dim).transpose(1, 2).contiguous().view(batch * dim, -1)
        return self.net(per_dim).view(batch, -1)
38+
39+
40+
class MonotonicNormalizer(nn.Module):
    """Monotonic transform ``z = z0 + \\int_0^x f(t; h) dt`` with ``f > 0``.

    Args:
        integrand_net: Either an ``nn.Module`` computing the positive
            integrand, or a list of hidden-layer sizes from which an
            ``IntegrandNet`` is built.
        cond_size: Embedding size of the conditioning factors (per dimension).
        nb_steps: Number of quadrature steps.
        solver: "CC" (sequential Clenshaw-Curtis quadrature) or "CCParallel"
            (evaluates all quadrature points at once: faster, more memory).

    Raises:
        ValueError: If ``solver`` is not one of the supported names.
    """

    def __init__(self, integrand_net, cond_size, nb_steps=20, solver="CC"):
        super().__init__()
        if isinstance(integrand_net, list):
            self.integrand_net = IntegrandNet(integrand_net, cond_size)
        else:
            self.integrand_net = integrand_net
        if solver not in ("CC", "CCParallel"):
            # Fail fast: the original code silently returned None from
            # forward() for an unknown solver, crashing callers later.
            raise ValueError(
                f"Unknown solver '{solver}'; expected 'CC' or 'CCParallel'."
            )
        self.solver = solver
        self.nb_steps = nb_steps

    def forward(self, x, h, context=None):
        """Return ``(z, f(x; h))`` where ``f`` is the integrand (i.e. dz/dx).

        ``h`` is expected to have shape (batch, features, cond_size); its
        first conditioning channel serves as the integration offset z0.
        """
        x0 = torch.zeros_like(x)
        xT = x
        z0 = h[:, :, 0]
        h = h.permute(0, 2, 1).contiguous().view(x.shape[0], -1)
        integral = NeuralIntegral if self.solver == "CC" else ParallelNeuralIntegral
        z = integral.apply(x0, xT, self.integrand_net,
                           _flatten(self.integrand_net.parameters()),
                           h, self.nb_steps) + z0
        return z, self.integrand_net(x, h)

    def inverse_transform(self, z, h, context=None):
        """Invert forward() by bisection over x in [-20, 20] (25 iterations)."""
        x_max = torch.ones_like(z) * 20
        x_min = -torch.ones_like(z) * 20
        z_max, _ = self.forward(x_max, h, context)
        z_min, _ = self.forward(x_min, h, context)
        for _ in range(25):
            x_middle = (x_max + x_min) / 2
            z_middle, _ = self.forward(x_middle, h, context)
            # left == 1 where the midpoint overshoots the target z.
            left = (z_middle > z).float()
            right = 1 - left
            x_max = left * x_middle + right * x_max
            x_min = right * x_middle + left * x_min
            z_max = left * z_middle + right * z_max
            z_min = right * z_middle + left * z_min
        return (x_max + x_min) / 2
82+
83+

nflows/transforms/UMNN/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from nflows.transforms.UMNN.MonotonicNormalizer import MonotonicNormalizer, IntegrandNet

nflows/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
MaskedPiecewiseLinearAutoregressiveTransform,
55
MaskedPiecewiseQuadraticAutoregressiveTransform,
66
MaskedPiecewiseRationalQuadraticAutoregressiveTransform,
7+
MaskedUMNNAutoregressiveTransform,
78
)
89
from nflows.transforms.base import (
910
CompositeTransform,
@@ -21,6 +22,7 @@
2122
PiecewiseLinearCouplingTransform,
2223
PiecewiseQuadraticCouplingTransform,
2324
PiecewiseRationalQuadraticCouplingTransform,
25+
UMNNCouplingTransform,
2426
)
2527
from nflows.transforms.linear import NaiveLinear
2628
from nflows.transforms.lu import LULinear

nflows/transforms/autoregressive.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
unconstrained_rational_quadratic_spline,
1919
)
2020
from nflows.utils import torchutils
21+
from nflows.transforms.UMNN import *
2122

2223

2324
class AutoregressiveTransform(Transform):
@@ -127,6 +128,71 @@ def _unconstrained_scale_and_shift(self, autoregressive_params):
127128
return unconstrained_scale, shift
128129

129130

131+
class MaskedUMNNAutoregressiveTransform(AutoregressiveTransform):
    """An unconstrained monotonic neural network (UMNN) autoregressive layer.

    Reference:
    > A. Wehenkel and G. Louppe, Unconstrained Monotonic Neural Networks, NeurIPS 2019.

    ---- Specific arguments ----
        integrand_net_layers: hidden-layer sizes of the integrand network
            (defaults to [50, 50, 50]).
        cond_size: embedding size for the conditioning factors.
        nb_steps: number of integration steps.
        solver: quadrature algorithm - "CC" or "CCParallel". Both implement
            Clenshaw-Curtis quadrature with the Leibniz rule for the backward
            pass. "CCParallel" evaluates all nb_steps points at once; it is
            faster but requires more memory.
    """

    def __init__(
        self,
        features,
        hidden_features,
        context_features=None,
        num_blocks=2,
        use_residual_blocks=True,
        random_mask=False,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
        integrand_net_layers=None,
        cond_size=20,
        nb_steps=20,
        solver="CCParallel",
    ):
        # None sentinel avoids the shared-mutable-default-argument pitfall;
        # the effective default is unchanged.
        if integrand_net_layers is None:
            integrand_net_layers = [50, 50, 50]
        self.features = features
        # Must be set before building MADE: _output_dim_multiplier() reads it.
        self.cond_size = cond_size
        made = made_module.MADE(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=num_blocks,
            output_multiplier=self._output_dim_multiplier(),
            use_residual_blocks=use_residual_blocks,
            random_mask=random_mask,
            activation=activation,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
        self._epsilon = 1e-3
        super().__init__(made)
        self.transformer = MonotonicNormalizer(integrand_net_layers, cond_size, nb_steps, solver)

    def _output_dim_multiplier(self):
        # The conditioner outputs cond_size parameters per feature.
        return self.cond_size

    def _elementwise_forward(self, inputs, autoregressive_params):
        z, jac = self.transformer(
            inputs, autoregressive_params.reshape(inputs.shape[0], inputs.shape[1], -1)
        )
        log_det_jac = jac.log().sum(1)
        return z, log_det_jac

    def _elementwise_inverse(self, inputs, autoregressive_params):
        h = autoregressive_params.reshape(inputs.shape[0], inputs.shape[1], -1)
        x = self.transformer.inverse_transform(inputs, h)
        # Re-run forward at the recovered x to get the Jacobian diagonal.
        z, jac = self.transformer(x, h)
        log_det_jac = -jac.log().sum(1)
        return x, log_det_jac
193+
194+
195+
130196
class MaskedPiecewiseLinearAutoregressiveTransform(AutoregressiveTransform):
131197
def __init__(
132198
self,

nflows/transforms/coupling.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
PiecewiseRationalQuadraticCDF,
1414
)
1515
from nflows.utils import torchutils
16+
from nflows.transforms.UMNN import *
1617

1718

1819
class CouplingTransform(Transform):
@@ -140,6 +141,73 @@ def _coupling_transform_inverse(self, inputs, transform_params):
140141
raise NotImplementedError()
141142

142143

144+
class UMNNCouplingTransform(CouplingTransform):
    """An unconstrained monotonic neural network (UMNN) coupling layer.

    Reference:
    > A. Wehenkel and G. Louppe, Unconstrained Monotonic Neural Networks, NeurIPS 2019.

    ---- Specific arguments ----
        integrand_net_layers: hidden-layer sizes of the integrand network
            (defaults to [50, 50, 50]).
        cond_size: embedding size for the conditioning factors.
        nb_steps: number of integration steps.
        solver: quadrature algorithm - "CC" or "CCParallel". Both implement
            Clenshaw-Curtis quadrature with the Leibniz rule for the backward
            pass. "CCParallel" evaluates all nb_steps points at once; it is
            faster but requires more memory.
    """

    def __init__(
        self,
        mask,
        transform_net_create_fn,
        integrand_net_layers=None,
        cond_size=20,
        nb_steps=20,
        solver="CCParallel",
        apply_unconditional_transform=False
    ):
        # None sentinel avoids the shared-mutable-default-argument pitfall;
        # the effective default is unchanged.
        if integrand_net_layers is None:
            integrand_net_layers = [50, 50, 50]

        if apply_unconditional_transform:
            # Identity-conditioned (cond_size=0) normalizer for features the
            # mask leaves untransformed.
            unconditional_transform = lambda features: MonotonicNormalizer(
                integrand_net_layers, 0, nb_steps, solver
            )
        else:
            unconditional_transform = None
        self.cond_size = cond_size
        super().__init__(
            mask,
            transform_net_create_fn,
            unconditional_transform=unconditional_transform,
        )

        self.transformer = MonotonicNormalizer(integrand_net_layers, cond_size, nb_steps, solver)

    def _transform_dim_multiplier(self):
        # cond_size parameters are produced per transformed feature.
        return self.cond_size

    def _coupling_transform_forward(self, inputs, transform_params):
        if len(inputs.shape) == 2:
            z, jac = self.transformer(
                inputs, transform_params.reshape(inputs.shape[0], inputs.shape[1], -1)
            )
            log_det_jac = jac.log().sum(1)
            return z, log_det_jac
        else:
            # Image input: fold spatial positions into the batch so each
            # pixel's channel vector is normalized independently.
            B, C, H, W = inputs.shape
            z, jac = self.transformer(
                inputs.permute(0, 2, 3, 1).reshape(-1, C),
                transform_params.permute(0, 2, 3, 1).reshape(-1, 1, transform_params.shape[1]),
            )
            log_det_jac = jac.log().reshape(B, -1).sum(1)
            return z.reshape(B, H, W, C).permute(0, 3, 1, 2), log_det_jac

    def _coupling_transform_inverse(self, inputs, transform_params):
        if len(inputs.shape) == 2:
            h = transform_params.reshape(inputs.shape[0], inputs.shape[1], -1)
            x = self.transformer.inverse_transform(inputs, h)
            # Re-run forward at the recovered x to get the Jacobian diagonal.
            z, jac = self.transformer(x, h)
            log_det_jac = -jac.log().sum(1)
            return x, log_det_jac
        else:
            B, C, H, W = inputs.shape
            h = transform_params.permute(0, 2, 3, 1).reshape(-1, 1, transform_params.shape[1])
            x = self.transformer.inverse_transform(
                inputs.permute(0, 2, 3, 1).reshape(-1, C), h
            )
            z, jac = self.transformer(x, h)
            log_det_jac = -jac.log().reshape(B, -1).sum(1)
            return x.reshape(B, H, W, C).permute(0, 3, 1, 2), log_det_jac
209+
210+
143211
class AffineCouplingTransform(CouplingTransform):
144212
"""An affine coupling layer that scales and shifts part of the variables.
145213
@@ -151,7 +219,7 @@ def _transform_dim_multiplier(self):
151219
return 2
152220

153221
def _scale_and_shift(self, transform_params):
154-
unconstrained_scale = transform_params[:, self.num_transform_features :, ...]
222+
unconstrained_scale = transform_params[:, self.num_transform_features:, ...]
155223
shift = transform_params[:, : self.num_transform_features, ...]
156224
# scale = (F.softplus(unconstrained_scale) + 1e-3).clamp(0, 3)
157225
scale = torch.sigmoid(unconstrained_scale + 2) + 1e-3

tests/transforms/autoregressive_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,24 @@ def test_forward_inverse_are_consistent(self):
114114
self.assert_forward_inverse_are_consistent(transform, inputs)
115115

116116

117+
class MaskedUMNNAutoregressiveTranformTest(TransformTest):
    """Round-trip check for the UMNN autoregressive transform."""

    def test_forward_inverse_are_consistent(self):
        batch_size = 10
        features = 20
        # Bisection inversion is approximate, so loosen the tolerance.
        self.eps = 1e-4
        inputs = torch.rand(batch_size, features)

        transform = autoregressive.MaskedUMNNAutoregressiveTransform(
            features=features,
            hidden_features=30,
            num_blocks=5,
            use_residual_blocks=True,
            cond_size=10,
        )

        self.assert_forward_inverse_are_consistent(transform, inputs)
133+
134+
117135
class MaskedPiecewiseCubicAutoregressiveTranformTest(TransformTest):
118136
def test_forward_inverse_are_consistent(self):
119137
batch_size = 10

tests/transforms/coupling_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,53 @@ def test_forward_inverse_are_consistent(self):
112112
self.assert_forward_inverse_are_consistent(transform, inputs)
113113

114114

115+
class UMNNTransformTest(TransformTest):
    """Tests for UMNNCouplingTransform on vector and image-shaped inputs."""

    shapes = [[20], [2, 4, 4]]

    @staticmethod
    def _create(shape):
        # All tests share one configuration; keep it in a single place.
        return create_coupling_transform(
            coupling.UMNNCouplingTransform,
            shape,
            integrand_net_layers=[50, 50, 50],
            cond_size=20,
            nb_steps=20,
            solver="CC",
        )

    def test_forward(self):
        for shape in self.shapes:
            inputs = torch.randn(batch_size, *shape)
            transform, mask = self._create(shape)
            outputs, logabsdet = transform(inputs)
            with self.subTest(shape=shape):
                self.assert_tensor_is_good(outputs, [batch_size] + shape)
                self.assert_tensor_is_good(logabsdet, [batch_size])
                # Masked-out features must pass through unchanged.
                self.assertEqual(outputs[:, mask <= 0, ...], inputs[:, mask <= 0, ...])

    def test_inverse(self):
        for shape in self.shapes:
            inputs = torch.randn(batch_size, *shape)
            transform, mask = self._create(shape)
            # Bug fix: exercise the inverse pass; the original called
            # transform(inputs) and merely duplicated test_forward.
            outputs, logabsdet = transform.inverse(inputs)
            with self.subTest(shape=shape):
                self.assert_tensor_is_good(outputs, [batch_size] + shape)
                self.assert_tensor_is_good(logabsdet, [batch_size])
                self.assertEqual(outputs[:, mask <= 0, ...], inputs[:, mask <= 0, ...])

    def test_forward_inverse_are_consistent(self):
        self.eps = 1e-6
        for shape in self.shapes:
            inputs = torch.randn(batch_size, *shape)
            transform, _ = self._create(shape)
            with self.subTest(shape=shape):
                self.assert_forward_inverse_are_consistent(transform, inputs)
161+
115162
class PiecewiseCouplingTransformTest(TransformTest):
116163
classes = [
117164
coupling.PiecewiseLinearCouplingTransform,

0 commit comments

Comments
 (0)