Skip to content

Commit 16b62a2

Browse files
Remove the barostat force from the system during minimization (#798)
* Remove the barostat force from the system during minimization * Apply suggestion from @hannahbaumann * Apply suggestion from @hannahbaumann * Add test * Add MonteCarloAnisotropicBarostat
1 parent 7b6e86c commit 16b62a2

File tree

2 files changed

+112
-24
lines changed

2 files changed

+112
-24
lines changed

openmmtools/multistate/multistatesampler.py

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1357,29 +1357,52 @@ def _minimize_replica(self, replica_id, tolerance, max_iterations):
13571357
thermodynamic_state_id = self._replica_thermodynamic_states[replica_id]
13581358
thermodynamic_state = self._thermodynamic_states[thermodynamic_state_id]
13591359

1360-
# Temporarily disable the barostat during minimization.
1361-
# Otherwise, the minimizer will modify the box
1362-
# vectors and may cause instabilities.
1363-
pressure = thermodynamic_state.pressure
1364-
thermodynamic_state.pressure = None
13651360
sampler_state = self._sampler_states[replica_id]
1361+
1362+
# Determine whether we need a temporary NVT state
1363+
barostat_types = (
1364+
openmm.MonteCarloBarostat,
1365+
openmm.MonteCarloMembraneBarostat,
1366+
openmm.MonteCarloAnisotropicBarostat,
1367+
)
1368+
1369+
has_barostat = any(
1370+
isinstance(thermodynamic_state.system.getForce(i), barostat_types)
1371+
for i in range(thermodynamic_state.system.getNumForces())
1372+
)
13661373

1374+
if has_barostat:
1375+
# Deep copy system and remove all barostats
1376+
min_system = copy.deepcopy(thermodynamic_state.system)
1377+
for i in reversed(range(min_system.getNumForces())):
1378+
if isinstance(min_system.getForce(i), barostat_types):
1379+
min_system.removeForce(i)
1380+
1381+
# Temporary NVT ThermodynamicState for minimization
1382+
minimization_state = states.ThermodynamicState(
1383+
system=min_system,
1384+
temperature=thermodynamic_state.temperature,
1385+
pressure=None
1386+
)
1387+
else:
1388+
# Use original state if no barostat
1389+
minimization_state = thermodynamic_state
1390+
13671391
# Use the FIRE minimizer
13681392
integrator = FIREMinimizationIntegrator(tolerance=tolerance)
13691393

13701394
# Get context and bound integrator from energy_context_cache
1371-
context, integrator = self.energy_context_cache.get_context(thermodynamic_state, integrator)
1395+
context, integrator = self.energy_context_cache.get_context(minimization_state, integrator)
13721396
# inform of platform used in current context
13731397
logger.debug(f"{type(integrator).__name__}: Minimize using {context.getPlatform().getName()} platform.")
13741398

13751399
# Set initial positions and box vectors.
13761400
sampler_state.apply_to_context(context)
13771401

13781402
# Compute the initial energy of the system for logging.
1379-
initial_energy = thermodynamic_state.reduced_potential(context)
1403+
initial_energy = minimization_state.reduced_potential(context)
13801404
logger.debug('Replica {}/{}: initial energy {:8.3f}kT'.format(
13811405
replica_id + 1, self.n_replicas, initial_energy))
1382-
13831406
# Minimize energy.
13841407
try:
13851408
if max_iterations == 0:
@@ -1400,18 +1423,14 @@ def _minimize_replica(self, replica_id, tolerance, max_iterations):
14001423
# Get the minimized positions.
14011424
sampler_state.update_from_context(context)
14021425

1403-
# Restore the barostat
1404-
thermodynamic_state.pressure = pressure
1405-
14061426
# Compute the final energy of the system for logging.
1407-
final_energy = thermodynamic_state.reduced_potential(sampler_state)
1427+
final_energy = minimization_state.reduced_potential(sampler_state)
14081428
logger.debug('Replica {}/{}: final energy {:8.3f}kT'.format(
14091429
replica_id + 1, self.n_replicas, final_energy))
14101430
# TODO if energy > 0, use slower openmm minimizer
14111431

14121432
# Clean up the integrator
14131433
del context
1414-
14151434
# Return minimized positions.
14161435
return sampler_state.positions
14171436

openmmtools/tests/test_sampling.py

Lines changed: 80 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1716,23 +1716,51 @@ def test_compute_energies(self):
17161716
sampler._energy_unsampled_states, energy_unsampled_states
17171717
)
17181718

1719-
def test_minimize(self):
1720-
"""Test MultiStateSampler minimize method.
1721-
1722-
The purpose of this test is mainly to make sure that MPI doesn't mix
1723-
the information of the minimized StateSamplers when it communicates
1724-
the new positions. It also checks that the energies are effectively
1725-
decreased.
1726-
1719+
@pytest.mark.parametrize("barostat_type", [openmm.MonteCarloBarostat, openmm.MonteCarloMembraneBarostat, openmm.MonteCarloAnisotropicBarostat])
1720+
def test_minimize(self, barostat_type):
17271721
"""
1728-
thermodynamic_states, sampler_states, unsampled_states = copy.deepcopy(
1729-
self.alanine_test
1730-
)
1722+
Test MultiStateSampler minimize method.
1723+
1724+
The purpose of this test is:
1725+
- Ensure that MPI doesn't mix the information of the minimized
1726+
StateSamplers when it communicates the new positions
1727+
- Checks that energies decrease
1728+
- Barostats are temporarily disabled during minimization
1729+
- Barostats are restored afterward
1730+
"""
1731+
# Use periodic alanine system
1732+
alanine_test = testsystems.AlanineDipeptideExplicit(constraints=None)
1733+
1734+
# Create thermodynamic states and sampler states
1735+
thermodynamic_states = [states.ThermodynamicState(system=alanine_test.system,
1736+
temperature=300*unit.kelvin,
1737+
pressure=1.0*unit.atmosphere)]
1738+
sampler_states = [states.SamplerState(positions=alanine_test.positions)]
1739+
unsampled_states = []
1740+
17311741
n_states = len(thermodynamic_states)
17321742
n_replicas = len(sampler_states)
17331743
if n_replicas == 1:
17341744
# This test is intended for use with more than one replica
17351745
return
1746+
1747+
# Add the specified barostat to each thermodynamic state
1748+
for ts in thermodynamic_states:
1749+
system = ts.system
1750+
if barostat_type is openmm.MonteCarloBarostat:
1751+
system.addForce(openmm.MonteCarloBarostat(1.0*unit.atmosphere, 300*unit.kelvin, 25))
1752+
elif barostat_type is openmm.MonteCarloMembraneBarostat:
1753+
system.addForce(openmm.MonteCarloMembraneBarostat(
1754+
1.0*unit.atmosphere,
1755+
0,
1756+
300*unit.kelvin,
1757+
openmm.MonteCarloMembraneBarostat.XYIsotropic,
1758+
openmm.MonteCarloMembraneBarostat.ZFree,
1759+
25))
1760+
else:
1761+
system.addForce(openmm.MonteCarloAnisotropicBarostat(
1762+
1.0 * unit.atmosphere, 300 * unit.kelvin
1763+
))
17361764

17371765
with self.temporary_storage_path() as storage_path:
17381766
sampler = self.SAMPLER()
@@ -1763,10 +1791,31 @@ def test_minimize(self):
17631791
sampler._energy_thermodynamic_states[i, j]
17641792
for i, j in enumerate(state_indices)
17651793
]
1794+
1795+
# Wrap _minimize_replica to track temporary systems
1796+
original_minimize = sampler._minimize_replica
1797+
systems_used_in_minimization = []
1798+
1799+
def tracking_minimize(replica_id, tolerance, max_iterations):
1800+
thermodynamic_state = sampler._thermodynamic_states[replica_id]
1801+
# Temporary NVT system as in minimization
1802+
min_system = copy.deepcopy(thermodynamic_state.system)
1803+
# Remove any barostats
1804+
for i in reversed(range(min_system.getNumForces())):
1805+
f = min_system.getForce(i)
1806+
if isinstance(f, (openmm.MonteCarloBarostat, openmm.MonteCarloMembraneBarostat)):
1807+
min_system.removeForce(i)
1808+
systems_used_in_minimization.append(min_system)
1809+
return original_minimize(replica_id, tolerance, max_iterations)
1810+
1811+
sampler._minimize_replica = tracking_minimize
17661812

17671813
# Minimize.
17681814
sampler.minimize()
17691815

1816+
# Restore original method
1817+
sampler._minimize_replica = original_minimize
1818+
17701819
# The relative positions between the new sampler states should
17711820
# be still translated the same way (i.e. we are not assigning
17721821
# the minimized positions to the incorrect sampler states).
@@ -1805,6 +1854,26 @@ def test_minimize(self):
18051854
new_sampler_states, stored_sampler_states
18061855
):
18071856
assert np.allclose(new_state.positions, stored_state.positions)
1857+
1858+
# Check that the barostat was removed during minimization
1859+
for system in systems_used_in_minimization:
1860+
forces = system.getForces()
1861+
assert not any(
1862+
isinstance(f, (openmm.MonteCarloBarostat, openmm.MonteCarloMembraneBarostat))
1863+
for f in forces
1864+
), "Barostat should be disabled during minimization"
1865+
1866+
# Check that the barostat is present after the minimization
1867+
for thermodynamic_state in sampler._thermodynamic_states:
1868+
# Get all forces
1869+
forces = thermodynamic_state.system.getForces()
1870+
# Check if the system originally had any volume-changing barostats
1871+
barostat_present = any(
1872+
isinstance(f, (openmm.MonteCarloBarostat, openmm.MonteCarloMembraneBarostat))
1873+
for f in forces
1874+
)
1875+
# Assert that at least one barostat is present
1876+
assert barostat_present, "Barostat should be restored after minimization"
18081877

18091878
def test_equilibrate(self):
18101879
"""Test equilibration of MultiStateSampler simulation.

0 commit comments

Comments
 (0)