Skip to content

Commit 8eef708

Browse files
committed
Merge branch 'main' into vp-flow
2 parents 8f436c8 + fcc50de commit 8eef708

File tree

25 files changed

+2648
-63
lines changed

25 files changed

+2648
-63
lines changed

.github/workflows/CI.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ jobs:
4141
- name: Install dependencies
4242
run: |
4343
python -m pip install --upgrade pip
44-
pip install pytest "${{ matrix.cfg.torch-version }}" numpy nflows torchdiffeq einops
44+
pip install pytest "${{ matrix.cfg.torch-version }}" numpy nflows torchdiffeq einops netCDF4
4545
- name: Install package
4646
run: |
4747
python setup.py install

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ pytest
139139
* [OpenMM](https://github.com/openmm/openmm) (for molecular examples)
140140
* [torchdiffeq](https://github.com/rtqichen/torchdiffeq) (for neural ODEs)
141141
* [ANODE](https://github.com/amirgholami/anode) (for neural ODEs)
142+
* [netCDF4](https://unidata.github.io/netcdf4-python/) (for the `ReplayBufferReporter`)
142143
* [jax](https://github.com/google/jax) (for smooth flows / implicit backprop)
143144
* [jax2torch](https://github.com/lucidrains/jax2torch) (for smooth flows / implicit backprop)
144145

bgflow/bg.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
from .distribution.energy import Energy
44
from .distribution.sampling import Sampler
5+
from .utils.types import pack_tensor_in_tuple
6+
7+
__all__ = [
8+
"BoltzmannGenerator", "unnormalized_kl_div", "unormalized_nll",
9+
"sampling_efficiency", "effective_sample_size", "log_weights",
10+
"log_weights_given_latent"
11+
]
512

613
__all__ = [
714
"BoltzmannGenerator", "unnormalized_kl_div", "unormalized_nll",
@@ -12,8 +19,7 @@
1219

1320
def unnormalized_kl_div(prior, flow, target, n_samples, temperature=1.0):
1421
z = prior.sample(n_samples, temperature=temperature)
15-
if isinstance(z, torch.Tensor):
16-
z = (z,)
22+
z = pack_tensor_in_tuple(z)
1723
*x, dlogp = flow(*z, temperature=temperature)
1824
return target.energy(*x, temperature=temperature) - dlogp
1925

@@ -31,10 +37,8 @@ def log_weights(*x, prior, flow, target, temperature=1.0, normalize=True):
3137

3238

3339
def log_weights_given_latent(x, z, dlogp, prior, target, temperature=1.0, normalize=True):
34-
if isinstance(x, torch.Tensor):
35-
x = (x,)
36-
if isinstance(z, torch.Tensor):
37-
z = (z,)
40+
x = pack_tensor_in_tuple(x)
41+
z = pack_tensor_in_tuple(z)
3842
logw = (
3943
prior.energy(*z, temperature=temperature)
4044
+ dlogp
@@ -94,8 +98,7 @@ def sample(
9498
with_weights=False,
9599
):
96100
z = self._prior.sample(n_samples, temperature=temperature)
97-
if isinstance(z, torch.Tensor):
98-
z = (z,)
101+
z = pack_tensor_in_tuple(z)
99102
*x, dlogp = self._flow(*z, temperature=temperature)
100103
results = list(x)
101104

bgflow/distribution/energy/double_well.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ def __init__(self, dim, a=0, b=-4.0, c=1.0):
1414
self._c = c
1515

1616
def _energy(self, x):
17-
d = x[:, [0]]
18-
v = x[:, 1:]
17+
d = x[..., [0]]
18+
v = x[..., 1:]
1919
e1 = self._a * d + self._b * d.pow(2) + self._c * d.pow(4)
2020
e2 = 0.5 * v.pow(2).sum(dim=-1, keepdim=True)
2121
return e1 + e2
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
11
from .base import *
22
from .mcmc import *
3-
from .dataset import *
3+
from .dataset import *
4+
from .buffer import *
5+
from .iterative import *
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
"""Helper classes and functions for iterative samplers."""
2+
3+
import torch
4+
5+
6+
__all__ = ["AbstractSamplerState", "default_set_samples_hook", "default_extract_sample_hook"]
7+
8+
9+
class AbstractSamplerState:
10+
"""Defines the interface for implementations of the internal state of iterative samplers."""
11+
12+
def as_dict(self):
13+
"""Return a dictionary representing this instance. The dictionary has to define the
14+
keys that are used within `SamplerStep`s of an `IterativeSampler`, such as "samples", "energies", ...
15+
"""
16+
raise NotImplementedError()
17+
18+
def _replace(self, **kwargs):
19+
"""Return a new object with changed fields.
20+
This function has to support all the keys that are used
21+
within `SamplerStep`s of an `IterativeSampler` as well as the keys "energies_up_to_date" and
22+
"forces_up_to_date"
23+
"""
24+
raise NotImplementedError()
25+
26+
def evaluate_energy_force(self, energy_model, evaluate_energies=True, evaluate_forces=True):
27+
"""Return a new state with updated energies/forces."""
28+
state = self.as_dict()
29+
evaluate_energies = evaluate_energies and not state["energies_up_to_date"]
30+
energies = energy_model.energy(*state["samples"])[..., 0] if evaluate_energies else state["energies"]
31+
32+
evaluate_forces = evaluate_forces and not state["forces_up_to_date"]
33+
forces = energy_model.force(*state["samples"]) if evaluate_forces else state["forces"]
34+
return self.replace(energies=energies, forces=forces)
35+
36+
def replace(self, **kwargs):
37+
"""Return a new state with updated fields."""
38+
39+
# keep track of energies and forces
40+
state_dict = self.as_dict()
41+
if "energies" in kwargs:
42+
kwargs = {**kwargs, "energies_up_to_date": True}
43+
elif "samples" in kwargs:
44+
kwargs = {**kwargs, "energies_up_to_date": False}
45+
if "forces" in kwargs:
46+
kwargs = {**kwargs, "forces_up_to_date": True}
47+
elif "samples" in kwargs:
48+
kwargs = {**kwargs, "forces_up_to_date": False}
49+
50+
# map to primary unit cell
51+
box_vectors = None
52+
if "box_vectors" in kwargs:
53+
box_vectors = kwargs["box_vectors"]
54+
elif "box_vectors" in state_dict:
55+
box_vectors = state_dict["box_vectors"]
56+
if "samples" in kwargs and box_vectors is not None:
57+
kwargs = {
58+
**kwargs,
59+
"samples": tuple(
60+
_map_to_primary_cell(x, cell)
61+
for x, cell in zip(kwargs["samples"], box_vectors)
62+
)
63+
}
64+
return self._replace(**kwargs)
65+
66+
67+
def default_set_samples_hook(x):
68+
"""by default, use samples as is"""
69+
return x
70+
71+
72+
def default_extract_sample_hook(state: AbstractSamplerState):
73+
"""Default extraction of samples from a SamplerState."""
74+
return state.as_dict()["samples"]
75+
76+
77+
def _bmv(m, bv):
78+
"""Batched matrix-vector multiply."""
79+
return torch.einsum("ij,...j->...i", m, bv)
80+
81+
82+
def _map_to_primary_cell(x, cell):
83+
"""Map coordinates to the primary unit cell of a periodic lattice.
84+
85+
Parameters
86+
----------
87+
x : torch.Tensor
88+
n-dimensional coordinates of shape (..., n), where n is the spatial dimension and ... denote an
89+
arbitrary number of batch dimensions.
90+
cell : torch.Tensor
91+
Lattice vectors (column-wise). Has to be upper triangular.
92+
"""
93+
if cell is None:
94+
return x
95+
n = _bmv(torch.inverse(cell), x)
96+
n = torch.floor(n)
97+
return x - _bmv(cell, n)
Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,27 @@
1-
import torch
21

2+
from typing import Tuple
3+
import torch
4+
from ...utils.types import unpack_tensor_tuple, pack_tensor_in_list
35

46
__all__ = ["Sampler"]
57

68

79
class Sampler(torch.nn.Module):
8-
9-
def __init__(self, **kwargs):
10+
"""Abstract base class for samplers.
11+
12+
Parameters
13+
----------
14+
return_hook : Callable, optional
15+
A function to postprocess the samples. This can (for example) be used to
16+
only return samples at a selected thermodynamic state of a replica exchange sampler
17+
or to combine the batch and sample dimension.
18+
The function takes a list of tensors and should return a list of tensors.
19+
Each tensor contains a batch of samples.
20+
"""
21+
22+
def __init__(self, return_hook=lambda x: x, **kwargs):
1023
super().__init__(**kwargs)
24+
self.return_hook = return_hook
1125

1226
def _sample_with_temperature(self, n_samples, temperature, *args, **kwargs):
1327
raise NotImplementedError()
@@ -16,7 +30,39 @@ def _sample(self, n_samples, *args, **kwargs):
1630
raise NotImplementedError()
1731

1832
def sample(self, n_samples, temperature=1.0, *args, **kwargs):
33+
"""Create a number of samples.
34+
35+
Parameters
36+
----------
37+
n_samples : int
38+
The number of samples to be created.
39+
temperature : float, optional
40+
The relative temperature at which to create samples.
41+
Only available for sampler that implement `_sample_with_temperature`.
42+
43+
Returns
44+
-------
45+
samples : Union[torch.Tensor, Tuple[torch.Tensor, ...]]
46+
If this sampler reflects a joint distribution of multiple tensors,
47+
it returns a tuple of tensors, each of which have length n_samples.
48+
Otherwise it returns a single tensor of length n_samples.
49+
"""
1950
if isinstance(temperature, float) and temperature == 1.0:
20-
return self._sample(n_samples, *args, **kwargs)
51+
samples = self._sample(n_samples, *args, **kwargs)
2152
else:
22-
return self._sample_with_temperature(n_samples, temperature, *args, **kwargs)
53+
samples = self._sample_with_temperature(n_samples, temperature, *args, **kwargs)
54+
samples = pack_tensor_in_list(samples)
55+
return unpack_tensor_tuple(self.return_hook(samples))
56+
57+
def sample_to_cpu(self, n_samples, batch_size=64, *args, **kwargs):
58+
"""A utility method for creating many samples that might not fit into GPU memory."""
59+
with torch.no_grad():
60+
samples = self.sample(min(n_samples, batch_size), *args, **kwargs)
61+
samples = pack_tensor_in_list(samples)
62+
samples = [tensor.detach().cpu() for tensor in samples]
63+
while len(samples[0]) < n_samples:
64+
new_samples = self.sample(min(n_samples-len(samples[0]), batch_size), *args, **kwargs)
65+
new_samples = pack_tensor_in_list(new_samples)
66+
for i, new in enumerate(new_samples):
67+
samples[i] = torch.cat([samples[i], new.detach().cpu()], dim=0)
68+
return unpack_tensor_tuple(samples)

0 commit comments

Comments
 (0)