
Commit 6a3bece

volume-preserving co-layers
1 parent ae1dc6a commit 6a3bece

9 files changed: +330 -14 lines

bgflow/bg.py

Lines changed: 1 addition & 1 deletion

@@ -97,7 +97,7 @@ def sample(
         results = list(x)
 
         if with_latent:
-            results.append(*z)
+            results.append(z)
         if with_dlogp:
            results.append(dlogp)
         if with_energy:
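Why this one-liner matters: `list.append` takes exactly one argument, so unpacking a multi-tensor latent with `append(*z)` raises a `TypeError`. A minimal stand-alone sketch (the tuple contents are hypothetical stand-ins for latent tensors):

```python
# Hypothetical stand-ins for a latent consisting of two tensors.
z = ("bonds_latent", "angles_latent")

results = []
results.append(z)       # OK: the whole tuple becomes a single list entry
try:
    results.append(*z)  # the old code path
except TypeError as e:
    print(e)            # append() takes exactly one argument (2 given)
```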

bgflow/factory/conditioner_factory.py

Lines changed: 3 additions & 2 deletions

@@ -1,4 +1,5 @@
 
+from typing import Mapping
 import torch
 import bgflow as bg
 from ..nn.periodic import WrapPeriodic
@@ -17,7 +18,7 @@ def make_conditioners(
     transformer_kwargs={},
     conditioner_type="dense",
     **kwargs
-):
+) -> Mapping[str, torch.nn.Module]:
     """Create coupling layer conditioners for a given transformer type,
     taking care of circular and non-circular tensors.
 
@@ -43,7 +44,7 @@ def make_conditioners(
 
     Returns
     -------
-    transformer : bg.Transformer
+    conditioners : Mapping[str, torch.nn.Module]
     """
     net_factory = CONDITIONER_FACTORIES[conditioner_type]
     dim_out_factory = CONDITIONER_OUT_DIMS[transformer_type]
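The corrected annotation documents what the factory really returns: a mapping from conditioner names to networks, not a transformer. A hedged usage sketch, assuming the import paths implied by this commit's file layout and illustrative shapes; for an affine transformer, the keys should be `shift_transformation` and `scale_transformation`, matching how the result is splatted into `VolumePreservingWrapFlow` below:

```python
import torch
from bgflow.factory.conditioner_factory import make_conditioners
from bgflow.factory.tensor_info import BONDS, ANGLES, ShapeDictionary
from bgflow.nn.flow.transformer.affine import AffineTransformer

# Illustrative shapes: 10 bonds conditioned on 20 angles.
shape_info = ShapeDictionary()
shape_info[BONDS] = (10,)
shape_info[ANGLES] = (20,)

conditioners = make_conditioners(
    transformer_type=AffineTransformer,
    what=(BONDS,),
    on=(ANGLES,),
    shape_info=shape_info,
)
# Each value is a plain torch module that can be moved to a device/dtype,
# exactly as the generator builder does with `net.to(**self.ctx)`.
assert all(isinstance(net, torch.nn.Module) for net in conditioners.values())
```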

bgflow/factory/generator_builder.py

Lines changed: 94 additions & 2 deletions

@@ -1,5 +1,7 @@
 """High-level Builder API for Boltzmann generators."""
 
+import contextlib
+import copy
 import warnings
 from typing import Mapping, Sequence
 
@@ -9,7 +11,8 @@
 from ..nn.flow.sequential import SequentialFlow
 from ..nn.flow.coupling import SetConstantFlow
 from ..nn.flow.transformer.spline import ConditionalSplineTransformer
-from ..nn.flow.coupling import CouplingFlow, SplitFlow, WrapFlow, MergeFlow
+from ..nn.flow.transformer.affine import AffineTransformer
+from ..nn.flow.coupling import CouplingFlow, SplitFlow, WrapFlow, MergeFlow, VolumePreservingWrapFlow
 from ..nn.flow.crd_transform.ic import GlobalInternalCoordinateTransformation
 from ..nn.flow.inverted import InverseFlow
 from ..nn.flow.cdf import CDFTransform
@@ -20,7 +23,8 @@
 from ..distribution.product import ProductDistribution, ProductEnergy
 from ..bg import BoltzmannGenerator
 from .tensor_info import (
-    TensorInfo, BONDS, ANGLES, TORSIONS, FIXED, ORIGIN, ROTATION, AUGMENTED, TARGET
+    TensorInfo, BONDS, ANGLES, TORSIONS, FIXED, ORIGIN, ROTATION, AUGMENTED, TARGET,
+    ShapeDictionary
 )
 from .conditioner_factory import make_conditioners
 from .transformer_factory import make_transformer
@@ -509,6 +513,94 @@ def add_constrain_chirality(self, halpha_torsion_indices, right_handed=False, to
         affine = TorchTransform(torch.distributions.AffineTransform(loc=loc, scale=scale), 1)
         return self.add_layer(affine, what=(torsions, ))
 
+    @contextlib.contextmanager
+    def volume_preserving_block(
+        self,
+        volume_sink: TensorInfo,
+        condition_on_dlogp: bool = True,
+        exclude_inputs_from_conditioner: Sequence[TensorInfo] = tuple(),
+        exclude_outputs_from_conditioner: Sequence[TensorInfo] = tuple(),
+        **conditioner_kwargs
+    ):
+        """Context manager for volume-preserving co-transforms.
+
+        A volume-preserving block can contain arbitrary (primary) transforms.
+        Every volume change (`dlogp != 0`) in this block is absorbed by
+        a volume sink tensor so that the total `dlogp` vanishes.
+
+        Parameters
+        ----------
+        volume_sink
+            The field that acts as a volume sink.
+        condition_on_dlogp
+            Whether to condition the affine co-transform on the `dlogp`
+            of the primary transform.
+        exclude_inputs_from_conditioner
+            Input tensors that are not passed to the conditioner of the affine co-layer.
+        exclude_outputs_from_conditioner
+            Output tensors that are not passed to the conditioner of the affine co-layer.
+
+        Notes
+        -----
+        It is paramount that the volume sink field is not used by any transform in the block.
+
+        Examples
+        --------
+        >>> from bgflow import BoltzmannGeneratorBuilder
+        >>> builder = BoltzmannGeneratorBuilder(...)
+        >>> with builder.volume_preserving_block(volume_sink=AUGMENTED):
+        ...     builder.add_condition(BONDS, on=(ANGLES, TORSIONS))
+
+        No matter the transform used in the coupling layer, this block will
+        have vanishing `dlogp` in total.
+        """
+        previous_layer = len(self.layers)
+        input_shape_dict = copy.copy(self.current_dims)
+        volume_sink_index_before = self.current_dims.index(volume_sink)
+        yield
+        # wrap layers that have been added in the context
+        volume_sink_index_after = self.current_dims.index(volume_sink)
+        wrapped_flow = SequentialFlow(self.layers[previous_layer:])
+        self.layers = self.layers[:previous_layer]
+
+        # make conditioner inputs
+        cond_indices = []
+        cond_names = []
+        coflow_input_shapes = ShapeDictionary()
+
+        dlogp_info = TensorInfo("dlogp", is_circular=False)
+        coflow_input_shapes[dlogp_info] = (1,)
+        if condition_on_dlogp:
+            cond_indices.append(0)
+            cond_names.append(dlogp_info)
+        for i, (info, shape) in enumerate(input_shape_dict.items(), start=1):
+            coflow_input_shapes[info] = shape
+            if info not in (*exclude_inputs_from_conditioner, volume_sink):
+                cond_indices.append(i)
+                cond_names.append(info)
+        for i, (info, shape) in enumerate(self.current_dims.items(), start=1 + len(input_shape_dict)):
+            info_out = info._replace(name=info.name + "_out")
+            coflow_input_shapes[info_out] = shape
+            if info not in (*exclude_outputs_from_conditioner, volume_sink):
+                cond_indices.append(i)
+                cond_names.append(info_out)
+
+        affine_conditioners = make_conditioners(
+            transformer_type=AffineTransformer,
+            what=_tuple(volume_sink),
+            on=cond_names,
+            shape_info=coflow_input_shapes,
+            **conditioner_kwargs
+        )
+        affine_conditioners = {name: net.to(**self.ctx) for name, net in affine_conditioners.items()}
+        volume_preserver = VolumePreservingWrapFlow(
+            flow=wrapped_flow,
+            volume_sink_index=volume_sink_index_before,
+            out_volume_sink_index=volume_sink_index_after,
+            cond_indices=cond_indices,
+            **affine_conditioners
+        )
+
+        self.add_layer(volume_preserver)
+
     def _add_to_param_groups(self, parameters, param_groups):
         parameters = list(parameters)
         for group in param_groups:
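A note on the mechanics: everything added between the `yield` and the end of the `with` block is popped off `self.layers` and re-added as a single wrapped layer. A simplified stand-alone sketch of that pattern (`TinyBuilder` and `wrap_block` are illustrative names, not bgflow API):

```python
import contextlib

class TinyBuilder:
    """Miniature stand-in for the builder's layer list."""
    def __init__(self):
        self.layers = []

    @contextlib.contextmanager
    def wrap_block(self, wrapper):
        start = len(self.layers)        # remember where the block begins
        yield                           # the caller appends layers here
        block = self.layers[start:]     # collect everything added inside
        self.layers = self.layers[:start]
        self.layers.append(wrapper(block))  # re-add as one wrapped layer

builder = TinyBuilder()
with builder.wrap_block(wrapper=tuple):
    builder.layers.append("layer_a")
    builder.layers.append("layer_b")
assert builder.layers == [("layer_a", "layer_b")]
```

In `volume_preserving_block`, the wrapper is a `VolumePreservingWrapFlow` around a `SequentialFlow` of the collected layers, with an affine co-transform acting on the sink field.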

bgflow/nn/flow/base.py

Lines changed: 3 additions & 0 deletions

@@ -1,3 +1,5 @@
+
+
 import torch
 
 
@@ -31,3 +33,4 @@ def forward(self, *xs, inverse=False, **kwargs):
             return self._inverse(*xs, **kwargs)
         else:
             return self._forward(*xs, **kwargs)
+

bgflow/nn/flow/coupling.py

Lines changed: 104 additions & 1 deletion

@@ -6,8 +6,12 @@
 
 from .base import Flow
 from .inverted import InverseFlow
+from .transformer.affine import AffineTransformer
 
-__all__ = ["SplitFlow", "MergeFlow", "SwapFlow", "CouplingFlow", "WrapFlow", "SetConstantFlow"]
+__all__ = [
+    "SplitFlow", "MergeFlow", "SwapFlow", "CouplingFlow",
+    "WrapFlow", "SetConstantFlow", "VolumePreservingWrapFlow"
+]
 
 
 class SplitFlow(Flow):
@@ -207,6 +211,7 @@ def _forward(self, *xs, **kwargs):
         inp = (xs[i] for i in self._indices)
         output = [xs[i] for i in range(len(xs)) if i not in self._indices]
         *yi, dlogp = self._flow(*inp, **kwargs)
+        assert len(yi) == len(self._out_indices)
         for i in self._argsort_out_indices:
             index = self._out_indices[i]
             output.insert(index, yi[i])
@@ -216,11 +221,109 @@ def _inverse(self, *xs, **kwargs):
         inp = (xs[i] for i in self._out_indices)
         output = [xs[i] for i in range(len(xs)) if i not in self._out_indices]
         *yi, dlogp = self._flow(*inp, inverse=True, **kwargs)
+        assert len(yi) == len(self._indices)
         for i in self._argsort_indices:
             index = self._indices[i]
             output.insert(index, yi[i])
         return (*tuple(output), dlogp)
 
+    def output_index(self, input_index):
+        """Output index of a non-transformed input."""
+        if input_index in self._indices:
+            raise ValueError("output_index is only defined for non-transformed inputs")
+        n_flow_non_inputs_before_sink = sum(i not in self._indices for i in range(input_index))
+        output_index = n_flow_non_inputs_before_sink
+        for i_out in sorted(self._out_indices):
+            if i_out <= output_index:
+                # "insert before"
+                output_index += 1
+            else:
+                # everything else is inserted after
+                break
+        return output_index
+
+
+class VolumePreservingWrapFlow(Flow):
+    def __init__(
+        self,
+        flow: Flow,
+        volume_sink_index: int,
+        out_volume_sink_index: int,
+        cond_indices: Sequence[int],
+        shift_transformation: torch.nn.Module = None,
+        scale_transformation: torch.nn.Module = None
+    ):
+        """Volume-preserving wrap layer.
+
+        This layer operates on two or more input tensors.
+
+        One of these tensors (as indexed by `volume_sink_index` and `out_volume_sink_index`)
+        acts as a volume sink, while the others are transformed by a flow.
+        Concretely, after applying the flow, an affine layer is applied to the volume sink
+        in such a way that the volume change of this affine "co-transform" (`co_dlogp`) counteracts
+        the volume change (`dlogp`) of the primary flow, `dlogp + co_dlogp = 0`.
+
+        The parameters of the co-transform (shift and scale)
+        can be conditioned on dlogp as well as the inputs and outputs
+        of the primary flow.
+
+        It is important that the primary transform does not use the volume sink in any way,
+        neither transform it nor condition on it.
+
+        Parameters
+        ----------
+        flow
+            The primary transform.
+        volume_sink_index
+            Input index of the volume sink tensor in the forward pass.
+        out_volume_sink_index
+            Input index of the volume sink tensor in the inverse pass.
+        cond_indices : Sequence[int]
+            This is a bit tricky. These indices refer to elements of the list
+            `[dlogp, *inputs, *outputs]`.
+        shift_transformation : torch.nn.Module, optional
+        scale_transformation : torch.nn.Module, optional
+        """
+        super().__init__()
+        self.flow = flow
+        self.volume_sink_index = volume_sink_index
+        self.out_volume_sink_index = out_volume_sink_index
+        co_transform = AffineTransformer(
+            shift_transformation=shift_transformation,
+            scale_transformation=scale_transformation,
+            preserve_volume=True,
+        )
+        self.co_flow = CouplingFlow(
+            transformer=co_transform,
+            transformed_indices=(1 + self.volume_sink_index, ),
+            cond_indices=cond_indices,
+            cat_dim=-1
+        )
+        assert all(i != 1 + self.volume_sink_index for i in cond_indices)
+
+    def _forward(self, *xs, **kwargs):
+        *ys, dlogp = self.flow.forward(*xs, **kwargs)
+        co_out, co_dlogp = self._apply_coflow(dlogp, xs, ys, inverse=False)
+        ys[self.out_volume_sink_index] = co_out
+        return (*ys, dlogp + co_dlogp)
+
+    def _inverse(self, *ys, **kwargs):
+        *xs, dlogp = self.flow.forward(*ys, inverse=True, **kwargs)
+        co_out, co_dlogp = self._apply_coflow(forward_dlogp=-dlogp, xs=xs, ys=ys, inverse=True)
+        xs[self.volume_sink_index] = co_out
+        return (*xs, dlogp + co_dlogp)
+
+    def _apply_coflow(self, forward_dlogp, xs, ys, inverse):
+        assert torch.allclose(xs[self.volume_sink_index], ys[self.out_volume_sink_index])
+        coflow_in = [
+            forward_dlogp,
+            *[x for i, x in enumerate(xs)],
+            *[y for i, y in enumerate(ys)]
+        ]
+        target_dlogp = forward_dlogp if inverse else -forward_dlogp
+        *co_out, co_dlogp = self.co_flow.forward(*coflow_in, target_dlogp=target_dlogp, inverse=inverse)
+        return co_out[1 + self.volume_sink_index], co_dlogp
+
 
 class SetConstantFlow(Flow):
     """A flow that sets some inputs constant in the forward direction and removes them in the inverse.

bgflow/nn/flow/transformer/affine.py

Lines changed: 20 additions & 6 deletions

@@ -1,4 +1,9 @@
+
+from typing import Union
+
+import warnings
 import torch
+import numpy as np
 
 from .base import Transformer
 
@@ -32,7 +37,7 @@ def __init__(
         self._preserve_volume = preserve_volume
         self._is_circular = is_circular
 
-    def _get_mu_and_log_sigma(self, x, y, *cond):
+    def _get_mu_and_log_sigma(self, x, y, *cond, target_dlogp: Union[float, torch.Tensor] = None):
         if self._shift_transformation is not None:
             mu = self._shift_transformation(x, *cond)
         else:
@@ -42,13 +47,22 @@ def _get_mu_and_log_sigma(self, x, y, *cond):
             log_sigma = torch.tanh(self._scale_transformation(x, *cond))
             log_sigma = log_sigma * alpha
             if self._preserve_volume:
-                log_sigma = log_sigma - log_sigma.mean(dim=-1, keepdim=True)
+                target_dlogp = 0.0 if target_dlogp is None else target_dlogp
+                target_scale = target_dlogp / np.prod(log_sigma[0].shape)
+                log_sigma = (
+                    log_sigma
+                    - log_sigma.mean(dim=-1, keepdim=True)
+                    + target_scale * torch.ones_like(log_sigma)
+                )
+            else:
+                if target_dlogp is not None:
+                    warnings.warn("target_dlogp is only effective if preserve_volume is enabled.")
         else:
             log_sigma = torch.zeros_like(y).to(x)
         return mu, log_sigma
 
-    def _forward(self, x, y, *cond, **kwargs):
-        mu, log_sigma = self._get_mu_and_log_sigma(x, y, *cond)
+    def _forward(self, x, y, *cond, target_dlogp=None, **kwargs):
+        mu, log_sigma = self._get_mu_and_log_sigma(x, y, *cond, target_dlogp=target_dlogp)
         assert mu.shape[-1] == y.shape[-1]
         assert log_sigma.shape[-1] == y.shape[-1]
         sigma = torch.exp(log_sigma)
@@ -58,8 +72,8 @@ def _forward(self, x, y, *cond, **kwargs):
             y = y % 1.0
         return y, dlogp
 
-    def _inverse(self, x, y, *cond, **kwargs):
-        mu, log_sigma = self._get_mu_and_log_sigma(x, y, *cond)
+    def _inverse(self, x, y, *cond, target_dlogp=None, **kwargs):
+        mu, log_sigma = self._get_mu_and_log_sigma(x, y, *cond, target_dlogp=None if target_dlogp is None else -target_dlogp)
         assert mu.shape[-1] == y.shape[-1]
         assert log_sigma.shape[-1] == y.shape[-1]
         sigma_inv = torch.exp(-log_sigma)
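End-to-end, the new keyword fixes the total log-determinant of the affine map exactly, and the sign flip in `_inverse` makes round trips consistent. A hedged sketch, assuming the constructor arguments used elsewhere in this commit; the `Linear` conditioners, shapes, and values are illustrative:

```python
import torch
from bgflow.nn.flow.transformer.affine import AffineTransformer

cond_dim, out_dim = 5, 3
transformer = AffineTransformer(
    shift_transformation=torch.nn.Linear(cond_dim, out_dim),
    scale_transformation=torch.nn.Linear(cond_dim, out_dim),
    preserve_volume=True,
)
x = torch.randn(2, cond_dim)        # conditioner input
y = torch.randn(2, out_dim)         # tensor being transformed
target = torch.full((2, 1), 0.25)   # requested volume change

# Forward: log-scales are shifted so that dlogp hits the target exactly.
z, dlogp = transformer._forward(x, y, target_dlogp=target)
assert torch.allclose(dlogp.flatten(), target.flatten(), atol=1e-5)

# Inverting with the negated target reproduces the same log-scales,
# so the round trip recovers y and the two dlogp values cancel.
y_back, dlogp_inv = transformer._inverse(x, z, target_dlogp=-target)
assert torch.allclose(y_back, y, atol=1e-5)
assert torch.allclose((dlogp + dlogp_inv).flatten(), torch.zeros(2), atol=1e-5)
```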

tests/factory/test_generator_builder.py

Lines changed: 30 additions & 0 deletions

@@ -225,3 +225,33 @@ def test_constrain_chirality(ala2, ctx):
     b, a, t, *_ = crd_transform.forward(samples)
     assert torch.all(t[:, chiral_torsions] >= 0.5)
     assert torch.all(t[:, chiral_torsions] <= 1.0)
+
+
+def test_volume_preserving_context(ctx):
+    shape_info = ShapeDictionary()
+    shape_info[BONDS] = (10, )
+    shape_info[ANGLES] = (20, )
+    builder = BoltzmannGeneratorBuilder(
+        shape_info,
+        **ctx
+    )
+    builder.targets[BONDS] = NormalDistribution(10, torch.zeros(10, **ctx))
+    builder.targets[ANGLES] = NormalDistribution(20, torch.zeros(20, **ctx))
+    # transform some fields
+    with builder.volume_preserving_block(volume_sink=ANGLES):
+        builder.add_layer(
+            CDFTransform(
+                TruncatedNormalDistribution(
+                    torch.zeros(10, **ctx),
+                    lower_bound=-torch.tensor(np.infty, **ctx)
+                ),
+            ),
+            what=[BONDS],
+            inverse=True,
+            param_groups=("group1", )
+        )
+    generator = builder.build_generator()
+    results = generator.sample(10, with_latent=True, with_dlogp=True, with_energy=True)
+    x, z, dlogp, energy = results[:2], results[2], results[3], results[4]
+    assert torch.allclose(dlogp, torch.zeros_like(dlogp), atol=1e-5)
+    assert torch.allclose(generator.energy(*x), generator.prior.energy(*z), atol=1e-5)
