Skip to content

Commit 639c3a7

Browse files
authored
Merge pull request #29 from AWehenkel/UMNN
UMNNs implementation
2 parents 75048ff + 616c3ee commit 639c3a7

File tree

8 files changed

+287
-1
lines changed

8 files changed

+287
-1
lines changed

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ dependencies:
2121
- pip:
2222
- torchtestcase
2323
- -e . # install package in development mode
24+
- umnn
2425
- pytest
2526
- python
2627
- pytorch
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
import torch
2+
from UMNN import NeuralIntegral, ParallelNeuralIntegral
3+
import torch.nn as nn
4+
5+
6+
def _flatten(sequence):
7+
flat = [p.contiguous().view(-1) for p in sequence]
8+
return torch.cat(flat) if len(flat) > 0 else torch.tensor([])
9+
10+
11+
class ELUPlus(nn.Module):
    """Strictly positive activation: ``ELU(x) + 1``.

    ELU is bounded below by -1, so the shifted output is always > 0,
    which makes this suitable as the last layer of an integrand network
    whose output must stay positive.
    """

    def __init__(self):
        super().__init__()
        self.elu = nn.ELU()

    def forward(self, x):
        """Apply the shifted ELU elementwise."""
        shifted = self.elu(x)
        return shifted + 1.
18+
19+
20+
class IntegrandNet(nn.Module):
    """MLP evaluated independently for every input dimension.

    Each scalar input x_i is concatenated with its conditioning vector
    h_i and mapped to one strictly positive scalar (via a final ELUPlus),
    so the network can serve as the derivative of a monotone map.

    Args:
        hidden: list of hidden-layer widths.
        cond_in: size of the conditioning vector attached to each dimension.
    """

    def __init__(self, hidden, cond_in):
        super().__init__()
        widths = [1 + cond_in] + hidden + [1]
        modules = []
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            modules.append(nn.Linear(n_in, n_out))
            modules.append(nn.ReLU())
        # Replace the trailing ReLU with a strictly positive activation.
        modules[-1] = ELUPlus()
        self.net = nn.Sequential(*modules)

    def forward(self, x, h):
        """Evaluate the integrand for every dimension of *x*.

        x: (batch, d); h: (batch, d * cond_in), laid out so that
        view/transpose below pairs each x_i with its own h_i —
        assumed from the reshape pattern; TODO confirm against callers.
        Returns a (batch, d) tensor of positive values.
        """
        batch, dim = x.shape
        joint = torch.cat((x, h), 1)
        # Regroup so each (x_i, h_i) pair becomes one row of the MLP input.
        per_dim = joint.view(batch, -1, dim).transpose(1, 2).contiguous()
        out = self.net(per_dim.view(batch * dim, -1))
        return out.view(batch, -1)
38+
39+
40+
class MonotonicNormalizer(nn.Module):
    """Monotone elementwise map z = F(x; h) built by integrating a strictly
    positive integrand network (UMNN, Wehenkel & Louppe, NeurIPS 2019).

    Args:
        integrand_net: a list of hidden-layer sizes (an IntegrandNet is
            then built from it) or an already-constructed integrand module.
        cond_size: embedding size of the conditioning factors.
        nb_steps: number of Clenshaw-Curtis quadrature steps.
        solver: "CC" (sequential quadrature) or "CCParallel" (all
            quadrature points evaluated at once; faster, more memory).
    """

    def __init__(self, integrand_net, cond_size, nb_steps=20, solver="CC"):
        super(MonotonicNormalizer, self).__init__()
        # isinstance instead of `type(...) is list` also accepts list subclasses.
        if isinstance(integrand_net, list):
            self.integrand_net = IntegrandNet(integrand_net, cond_size)
        else:
            self.integrand_net = integrand_net
        self.solver = solver
        self.nb_steps = nb_steps

    def forward(self, x, h, context=None):
        """Transform x monotonically, conditioned on h.

        Args:
            x: inputs of shape (batch, d).
            h: conditioning tensor of shape (batch, d, cond_size) --
               assumed from the indexing/permute below; TODO confirm.
            context: unused, kept for interface compatibility.

        Returns:
            (z, jac): transformed values and the integrand evaluated at x,
            i.e. the strictly positive elementwise derivative dz/dx.

        Raises:
            ValueError: if ``self.solver`` is not "CC" or "CCParallel".
        """
        x0 = torch.zeros_like(x)  # integrate from 0 up to x
        xT = x
        z0 = h[:, :, 0]  # additive offset taken from the first conditioning channel
        h = h.permute(0, 2, 1).contiguous().view(x.shape[0], -1)
        if self.solver == "CC":
            z = NeuralIntegral.apply(x0, xT, self.integrand_net, _flatten(self.integrand_net.parameters()),
                                     h, self.nb_steps) + z0
        elif self.solver == "CCParallel":
            z = ParallelNeuralIntegral.apply(x0, xT, self.integrand_net,
                                             _flatten(self.integrand_net.parameters()),
                                             h, self.nb_steps) + z0
        else:
            # Fail loudly: the original silently returned None here, which
            # callers only noticed later as an opaque unpacking error.
            raise ValueError(
                "Unknown solver '{}'; expected 'CC' or 'CCParallel'.".format(self.solver)
            )
        return z, self.integrand_net(x, h)

    def inverse_transform(self, z, h, context=None):
        """Invert the monotone map by bisection on the interval [-20, 20].

        25 halvings shrink the bracket to ~40 * 2**-25 ≈ 1.2e-6, so the
        returned midpoint is accurate to roughly that tolerance, provided
        the true preimage lies inside [-20, 20].
        """
        # Old inversion by binary search
        x_max = torch.ones_like(z) * 20
        x_min = -torch.ones_like(z) * 20
        z_max, _ = self.forward(x_max, h, context)
        z_min, _ = self.forward(x_min, h, context)
        for i in range(25):
            x_middle = (x_max + x_min) / 2
            z_middle, _ = self.forward(x_middle, h, context)
            # left == 1 where the midpoint overshoots the target z.
            left = (z_middle > z).float()
            right = 1 - left
            x_max = left * x_middle + right * x_max
            x_min = right * x_middle + left * x_min
            z_max = left * z_middle + right * z_max
            z_min = right * z_middle + left * z_min
        return (x_max + x_min) / 2
82+
83+

nflows/transforms/UMNN/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from nflows.transforms.UMNN.MonotonicNormalizer import MonotonicNormalizer, IntegrandNet

nflows/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
MaskedPiecewiseLinearAutoregressiveTransform,
55
MaskedPiecewiseQuadraticAutoregressiveTransform,
66
MaskedPiecewiseRationalQuadraticAutoregressiveTransform,
7+
MaskedUMNNAutoregressiveTransform,
78
)
89
from nflows.transforms.base import (
910
CompositeTransform,
@@ -21,6 +22,7 @@
2122
PiecewiseLinearCouplingTransform,
2223
PiecewiseQuadraticCouplingTransform,
2324
PiecewiseRationalQuadraticCouplingTransform,
25+
UMNNCouplingTransform,
2426
)
2527
from nflows.transforms.linear import NaiveLinear
2628
from nflows.transforms.lu import LULinear

nflows/transforms/autoregressive.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
unconstrained_rational_quadratic_spline,
1919
)
2020
from nflows.utils import torchutils
21+
from nflows.transforms.UMNN import *
2122

2223

2324
class AutoregressiveTransform(Transform):
@@ -127,6 +128,71 @@ def _unconstrained_scale_and_shift(self, autoregressive_params):
127128
return unconstrained_scale, shift
128129

129130

131+
class MaskedUMNNAutoregressiveTransform(AutoregressiveTransform):
    """An unconstrained monotonic neural networks autoregressive layer that transforms the variables.

    Reference:
    > A. Wehenkel and G. Louppe, Unconstrained Monotonic Neural Networks, NeurIPS2019.

    ---- Specific arguments ----
        integrand_net_layers: the layers dimension to put in the integrand network.
        cond_size: The embedding size for the conditioning factors.
        nb_steps: The number of integration steps.
        solver: The quadrature algorithm - CC or CCParallel. Both implements Clenshaw-Curtis quadrature with
        Leibniz rule for backward computation. CCParallel pass all the evaluation points (nb_steps) at once, it is faster
        but requires more memory.
    """
    def __init__(
        self,
        features,
        hidden_features,
        context_features=None,
        num_blocks=2,
        use_residual_blocks=True,
        random_mask=False,
        activation=F.relu,
        dropout_probability=0.0,
        use_batch_norm=False,
        integrand_net_layers=[50, 50, 50],  # NOTE(review): mutable default; never mutated here, but a None/tuple default would be safer
        cond_size=20,
        nb_steps=20,
        solver="CCParallel",
    ):
        # cond_size must be assigned before _output_dim_multiplier() is
        # consulted by the MADE constructor below.
        self.features = features
        self.cond_size = cond_size
        # MADE conditioner: emits cond_size parameters per input feature;
        # these become the conditioning embedding h of the normalizer.
        made = made_module.MADE(
            features=features,
            hidden_features=hidden_features,
            context_features=context_features,
            num_blocks=num_blocks,
            output_multiplier=self._output_dim_multiplier(),
            use_residual_blocks=use_residual_blocks,
            random_mask=random_mask,
            activation=activation,
            dropout_probability=dropout_probability,
            use_batch_norm=use_batch_norm,
        )
        self._epsilon = 1e-3  # NOTE(review): appears unused in this class — TODO confirm before removing
        super().__init__(made)
        # Elementwise monotone transformer shared across all features.
        self.transformer = MonotonicNormalizer(integrand_net_layers, cond_size, nb_steps, solver)


    def _output_dim_multiplier(self):
        # The autoregressive network outputs cond_size values per feature.
        return self.cond_size

    def _elementwise_forward(self, inputs, autoregressive_params):
        """Apply the monotone transform; returns (z, log|det J|)."""
        z, jac = self.transformer(inputs, autoregressive_params.reshape(inputs.shape[0], inputs.shape[1], -1))
        # The Jacobian is diagonal with strictly positive entries, so
        # log|det J| is the sum of elementwise log-derivatives.
        log_det_jac = jac.log().sum(1)
        return z, log_det_jac

    def _elementwise_inverse(self, inputs, autoregressive_params):
        """Invert the monotone transform; returns (x, -log|det J(x)|)."""
        x = self.transformer.inverse_transform(inputs, autoregressive_params.reshape(inputs.shape[0], inputs.shape[1], -1))
        # Re-run the forward pass at x solely to obtain the derivative there.
        z, jac = self.transformer(x, autoregressive_params.reshape(inputs.shape[0], inputs.shape[1], -1))
        log_det_jac = -jac.log().sum(1)
        return x, log_det_jac
193+
194+
195+
130196
class MaskedPiecewiseLinearAutoregressiveTransform(AutoregressiveTransform):
131197
def __init__(
132198
self,

nflows/transforms/coupling.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
PiecewiseRationalQuadraticCDF,
1414
)
1515
from nflows.utils import torchutils
16+
from nflows.transforms.UMNN import *
1617

1718

1819
class CouplingTransform(Transform):
@@ -140,6 +141,73 @@ def _coupling_transform_inverse(self, inputs, transform_params):
140141
raise NotImplementedError()
141142

142143

144+
class UMNNCouplingTransform(CouplingTransform):
    """An unconstrained monotonic neural networks coupling layer that transforms the variables.

    Reference:
    > A. Wehenkel and G. Louppe, Unconstrained Monotonic Neural Networks, NeurIPS2019.

    ---- Specific arguments ----
        integrand_net_layers: the layers dimension to put in the integrand network.
        cond_size: The embedding size for the conditioning factors.
        nb_steps: The number of integration steps.
        solver: The quadrature algorithm - CC or CCParallel. Both implements Clenshaw-Curtis quadrature with
        Leibniz rule for backward computation. CCParallel pass all the evaluation points (nb_steps) at once, it is faster
        but requires more memory.

    """
    def __init__(
        self,
        mask,
        transform_net_create_fn,
        integrand_net_layers=[50, 50, 50],  # NOTE(review): mutable default; never mutated here, but a None/tuple default would be safer
        cond_size=20,
        nb_steps=20,
        solver="CCParallel",
        apply_unconditional_transform=False
    ):

        if apply_unconditional_transform:
            # The `features` argument is deliberately ignored: the identity
            # half gets an unconditioned normalizer (cond_size = 0).
            unconditional_transform = lambda features: MonotonicNormalizer(integrand_net_layers, 0, nb_steps, solver)
        else:
            unconditional_transform = None
        # cond_size must be assigned before _transform_dim_multiplier() is
        # consulted by the CouplingTransform constructor below.
        self.cond_size = cond_size
        super().__init__(
            mask,
            transform_net_create_fn,
            unconditional_transform=unconditional_transform,
        )

        # Elementwise monotone transformer for the transformed half.
        self.transformer = MonotonicNormalizer(integrand_net_layers, cond_size, nb_steps, solver)

    def _transform_dim_multiplier(self):
        # The coupling network outputs cond_size parameters per feature.
        return self.cond_size

    def _coupling_transform_forward(self, inputs, transform_params):
        """Forward transform; returns (z, log|det J|) for 2-D or 4-D inputs."""
        if len(inputs.shape) == 2:
            z, jac = self.transformer(inputs, transform_params.reshape(inputs.shape[0], inputs.shape[1], -1))
            log_det_jac = jac.log().sum(1)
            return z, log_det_jac
        else:
            # Image input: flatten (B, C, H, W) into B*H*W rows of C features
            # so each pixel is transformed independently, then restore layout.
            B, C, H, W = inputs.shape
            z, jac = self.transformer(inputs.permute(0, 2, 3, 1).reshape(-1, inputs.shape[1]), transform_params.permute(0, 2, 3, 1).reshape(-1, 1, transform_params.shape[1]))
            log_det_jac = jac.log().reshape(B, -1).sum(1)
            return z.reshape(B, H, W, C).permute(0, 3, 1, 2), log_det_jac

    def _coupling_transform_inverse(self, inputs, transform_params):
        """Inverse transform; returns (x, -log|det J(x)|) for 2-D or 4-D inputs."""
        if len(inputs.shape) == 2:
            x = self.transformer.inverse_transform(inputs, transform_params.reshape(inputs.shape[0], inputs.shape[1], -1))
            # Forward pass at x is re-run only to obtain the derivative there.
            z, jac = self.transformer(x, transform_params.reshape(inputs.shape[0], inputs.shape[1], -1))
            log_det_jac = -jac.log().sum(1)
            return x, log_det_jac
        else:
            # Image input: same pixel-wise flattening as the forward pass.
            B, C, H, W = inputs.shape
            x = self.transformer.inverse_transform(inputs.permute(0, 2, 3, 1).reshape(-1, inputs.shape[1]), transform_params.permute(0, 2, 3, 1).reshape(-1, 1, transform_params.shape[1]))
            z, jac = self.transformer(x, transform_params.permute(0, 2, 3, 1).reshape(-1, 1, transform_params.shape[1]))
            log_det_jac = -jac.log().reshape(B, -1).sum(1)
            return x.reshape(B, H, W, C).permute(0, 3, 1, 2), log_det_jac
209+
210+
143211
class AffineCouplingTransform(CouplingTransform):
144212
"""An affine coupling layer that scales and shifts part of the variables.
145213
@@ -151,7 +219,7 @@ def _transform_dim_multiplier(self):
151219
return 2
152220

153221
def _scale_and_shift(self, transform_params):
154-
unconstrained_scale = transform_params[:, self.num_transform_features :, ...]
222+
unconstrained_scale = transform_params[:, self.num_transform_features:, ...]
155223
shift = transform_params[:, : self.num_transform_features, ...]
156224
# scale = (F.softplus(unconstrained_scale) + 1e-3).clamp(0, 3)
157225
scale = torch.sigmoid(unconstrained_scale + 2) + 1e-3

tests/transforms/autoregressive_test.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,24 @@ def test_forward_inverse_are_consistent(self):
114114
self.assert_forward_inverse_are_consistent(transform, inputs)
115115

116116

117+
class MaskedUMNNAutoregressiveTranformTest(TransformTest):
    # NOTE(review): "Tranform" misspells "Transform"; the name is kept so
    # existing test selections keep working.

    def test_forward_inverse_are_consistent(self):
        """forward followed by inverse should reproduce the inputs."""
        batch_size, features = 10, 20
        inputs = torch.rand(batch_size, features)
        # Bisection-based inversion is only accurate to ~1e-4.
        self.eps = 1e-4

        transform = autoregressive.MaskedUMNNAutoregressiveTransform(
            features=features,
            hidden_features=30,
            cond_size=10,
            num_blocks=5,
            use_residual_blocks=True,
        )

        self.assert_forward_inverse_are_consistent(transform, inputs)
133+
134+
117135
class MaskedPiecewiseCubicAutoregressiveTranformTest(TransformTest):
118136
def test_forward_inverse_are_consistent(self):
119137
batch_size = 10

tests/transforms/coupling_test.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,53 @@ def test_forward_inverse_are_consistent(self):
112112
self.assert_forward_inverse_are_consistent(transform, inputs)
113113

114114

115+
class UMNNTransformTest(TransformTest):
    """Tests for UMNNCouplingTransform on vector and image-shaped inputs."""

    shapes = [[20], [2, 4, 4]]

    def _create(self, shape):
        # Shared construction used by all three tests below.
        return create_coupling_transform(
            coupling.UMNNCouplingTransform,
            shape,
            integrand_net_layers=[50, 50, 50],
            cond_size=20,
            nb_steps=20,
            solver="CC",
        )

    def test_forward(self):
        for shape in self.shapes:
            inputs = torch.randn(batch_size, *shape)
            transform, mask = self._create(shape)
            outputs, logabsdet = transform(inputs)
            with self.subTest(shape=shape):
                self.assert_tensor_is_good(outputs, [batch_size] + shape)
                self.assert_tensor_is_good(logabsdet, [batch_size])
                # Identity features (mask <= 0) must pass through unchanged.
                self.assertEqual(outputs[:, mask <= 0, ...], inputs[:, mask <= 0, ...])

    def test_inverse(self):
        for shape in self.shapes:
            inputs = torch.randn(batch_size, *shape)
            transform, mask = self._create(shape)
            # Bug fix: this previously called transform(inputs), duplicating
            # test_forward and never exercising the inverse path at all.
            outputs, logabsdet = transform.inverse(inputs)
            with self.subTest(shape=shape):
                self.assert_tensor_is_good(outputs, [batch_size] + shape)
                self.assert_tensor_is_good(logabsdet, [batch_size])
                # Identity features (mask <= 0) must pass through unchanged.
                self.assertEqual(outputs[:, mask <= 0, ...], inputs[:, mask <= 0, ...])

    def test_forward_inverse_are_consistent(self):
        self.eps = 1e-6
        for shape in self.shapes:
            inputs = torch.randn(batch_size, *shape)
            transform, mask = self._create(shape)
            with self.subTest(shape=shape):
                self.assert_forward_inverse_are_consistent(transform, inputs)
160+
161+
115162
class PiecewiseCouplingTransformTest(TransformTest):
116163
classes = [
117164
coupling.PiecewiseLinearCouplingTransform,

0 commit comments

Comments
 (0)