Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion bgflow/distribution/energy/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _evaluate_single(
assert not np.isnan(force).any()
except AssertionError as e:
force[np.isnan(force)] = 0.
energy = np.infty
energy = np.inf
if self.err_handling == "warning":
warnings.warn("Found nan in ase force or energy. Returning infinite energy and zero force.")
elif self.err_handling == "error":
Expand Down
47 changes: 31 additions & 16 deletions bgflow/distribution/energy/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,17 +216,32 @@ def force(

class _BridgeEnergyWrapper(torch.autograd.Function):
@staticmethod
def forward(ctx, input, bridge):
energy, force, *_ = bridge.evaluate(input)
ctx.save_for_backward(-force)
return energy
def forward(x, bridge):
energy, force, *_ = bridge.evaluate(x)
return energy, -force

@staticmethod
def backward(ctx, grad_output):
def backward(ctx, grad_output, _0):
neg_force, = ctx.saved_tensors
grad_input = grad_output * neg_force
grad_input = grad_output[:, None] * neg_force
return grad_input, None

@staticmethod
def setup_context(ctx, inputs, output):
x, bridge = inputs
energy, neg_force = output
ctx.save_for_backward(neg_force)

@staticmethod
def vmap(info, in_dims, x , bridge):
x_shape = x.shape
x = x.reshape(-1, x_shape[-1])
energy, force, *_ = bridge.evaluate(x)
energy = energy.reshape(x_shape[:-1])
force = force.reshape(x_shape)
return (energy, -force), (0, 0)



_evaluate_bridge_energy = _BridgeEnergyWrapper.apply

Expand Down Expand Up @@ -308,16 +323,16 @@ def bridge(self):

def _energy(self, batch, no_grads=False):
# check if we have already computed this energy (hash of string representation should be sufficient)
if hash(str(batch)) == self._last_batch:
return self._bridge.last_energies
else:
self._last_batch = hash(str(batch))
return _evaluate_bridge_energy(batch, self._bridge)
# if hash(str(batch)) == self._last_batch:
# return self._bridge.last_energies
# else:
self._last_batch = hash(str(batch))
return _evaluate_bridge_energy(batch, self._bridge)[0]

def force(self, batch, temperature=None):
# check if we have already computed this energy
if hash(str(batch)) == self.last_batch:
return self.bridge.last_forces
else:
self._last_batch = hash(str(batch))
return self._bridge.evaluate(batch)[1]
# if hash(str(batch)) == self.last_batch:
# return self.bridge.last_forces
# else:
self._last_batch = hash(str(batch))
return self._bridge.evaluate(batch)[1]
4 changes: 3 additions & 1 deletion bgflow/distribution/energy/openmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
openmm_system,
openmm_integrator,
platform_name='CPU',
platform_properties=None,
err_handling="warning",
n_workers=mp.cpu_count(),
n_simulation_steps=0
Expand All @@ -50,7 +51,8 @@ def __init__(
from openmm import unit
except ImportError: # fall back to older version < 7.6
from simtk import unit
platform_properties = {'Threads': str(max(1, mp.cpu_count()//n_workers))} if platform_name == "CPU" else {}
if platform_properties is None:
platform_properties = {'Threads': str(max(1, mp.cpu_count()//n_workers))} if platform_name == "CPU" else {}

# Compute all energies in child processes due to a bug in the OpenMM's PME code.
# This might be problematic if an energy has already been computed in the same program on the parent thread,
Expand Down
6 changes: 3 additions & 3 deletions bgflow/distribution/energy/xtb.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ def _evaluate_single(
f"Original exception: {e}"
)
force = np.zeros_like(positions)
energy = np.infty
energy = np.inf
elif self.err_handling == "ignore":
force = np.zeros_like(positions)
energy = np.infty
energy = np.inf
except AssertionError:
force[np.isnan(force)] = 0.
energy = np.infty
energy = np.inf
if self.err_handling in ["error", "warning"]:
warnings.warn("Found nan in xtb force or energy. Returning infinite energy and zero force.")

Expand Down
10 changes: 5 additions & 5 deletions bgflow/distribution/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ class TruncatedNormalDistribution(Energy, Sampler):
Mean of the untruncated normal distribution.
sigma : float or tensor of floats of shape (dim, )
Standard deviation of the untruncated normal distribution.
lower_bound : float, -np.infty, or tensor of floats of shape (dim, )
lower_bound : float, -np.inf, or tensor of floats of shape (dim, )
Lower truncation bound.
upper_bound : float, np.infty, or tensor of floats of shape (dim, )
upper_bound : float, np.inf, or tensor of floats of shape (dim, )
Upper truncation bound.
assert_range : bool
Whether to raise an error when `energy` is called on input that falls out of bounds.
Expand All @@ -123,7 +123,7 @@ def __init__(
mu,
sigma=torch.tensor(1.0),
lower_bound=torch.tensor(0.0),
upper_bound=torch.tensor(np.infty),
upper_bound=torch.tensor(np.inf),
assert_range=True,
sampling_method="icdf",
is_learnable=False
Expand Down Expand Up @@ -208,8 +208,8 @@ def _energy(self, x):
if (x < self._lower_bound).any() or (x > self._upper_bound).any():
raise ValueError("input out of bounds")
else:
energies[x < self._lower_bound] = np.infty
energies[x > self._upper_bound] = np.infty
energies[x < self._lower_bound] = np.inf
energies[x > self._upper_bound] = np.inf
return 0.5 * energies.sum(dim=-1, keepdim=True)

def icdf(self, x):
Expand Down
2 changes: 1 addition & 1 deletion bgflow/factory/icmarginals.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(
bond_mu=1.0,
bond_sigma=1.0,
bond_lower=1e-5,
bond_upper=np.infty,
bond_upper=np.inf,
angle_mu=0.5,
angle_sigma=1.0,
angle_lower=1e-5,
Expand Down
2 changes: 1 addition & 1 deletion bgflow/nn/flow/cdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def __init__(
mu,
sigma=torch.tensor(1.0),
lower_bound=0.0,
upper_bound=np.infty,
upper_bound=np.inf,
assert_range=True,
mu_out=None,
sigma_out=None,
Expand Down
2 changes: 1 addition & 1 deletion bgflow/utils/free_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ def _bar(forward_work, reverse_work, compute_uncertainty=True, maximum_iteration
f_upper_bound = _bar_zero(forward_work, reverse_work, upper_bound)
f_lower_bound = _bar_zero(forward_work, reverse_work, lower_bound)

delta_f_old = np.infty
delta_f_old = np.inf
for iterations in range(maximum_iterations):
delta_f = upper_bound - f_upper_bound * (upper_bound - lower_bound) / (f_upper_bound - f_lower_bound)
f_new = _bar_zero(forward_work, reverse_work, delta_f)
Expand Down
2 changes: 1 addition & 1 deletion tests/factory/test_generator_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def test_builder_add_layer_and_param_groups(ctx):
# transform some fields
builder.add_layer(
CDFTransform(
TruncatedNormalDistribution(torch.zeros(10, **ctx), lower_bound=-torch.tensor(np.infty)),
TruncatedNormalDistribution(torch.zeros(10, **ctx), lower_bound=-torch.tensor(np.inf)),
),
what=[BONDS],
inverse=True,
Expand Down