Skip to content

Commit a81f28a

Browse files
Allow for any barostat in SystemGenerator (openmm#414)
* Allow for any barostat in SystemGenerator * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix doc strings * Add test for MonteCarloMembraneBarostat * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 839f225 commit a81f28a

File tree

2 files changed

+55
-28
lines changed

2 files changed

+55
-28
lines changed

openmmforcefields/generators/system_generators.py

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -303,25 +303,15 @@ def _modify_forces(self, system):
303303
"""
304304
Add barostat and modify forces if requested.
305305
"""
306-
307306
# Add barostat if requested and the system uses periodic boundary conditions
308307
if (self.barostat is not None) and system.usesPeriodicBoundaryConditions():
309308
import numpy as np
310-
import openmm
309+
import copy
311310

312311
MAXINT = np.iinfo(np.int32).max
313312

314-
# Determine pressure, temperature, and frequency
315-
pressure = self.barostat.getDefaultPressure()
316-
if hasattr(self.barostat, "getDefaultTemperature"):
317-
temperature = self.barostat.getDefaultTemperature()
318-
else:
319-
temperature = self.barostat.getTemperature()
320-
frequency = self.barostat.getFrequency()
321-
322-
# Create the barostat
323-
# TODO: Make sure we can support other kinds of barostats?
324-
barostat = openmm.MonteCarloBarostat(pressure, temperature, frequency)
313+
# Get the barostat
314+
barostat = copy.deepcopy(self.barostat)
325315
seed = np.random.randint(MAXINT)
326316
barostat.setRandomNumberSeed(seed)
327317
system.addForce(barostat)

openmmforcefields/tests/test_system_generator.py

Lines changed: 52 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88
from openff.toolkit.topology import Molecule
99
from openmm.app import LJPME, PME, CutoffNonPeriodic, Modeller, PDBFile
10+
from openmm import unit, MonteCarloBarostat, MonteCarloMembraneBarostat
1011

1112
from openmmforcefields.generators import SystemGenerator
1213
from openmmforcefields.utils import Timer, get_data_filename
@@ -139,21 +140,52 @@ def test_create(self):
139140
# Create an empty system generator
140141
SystemGenerator()
141142

142-
def test_barostat(self):
143-
"""Test that barostat addition works correctly"""
143+
@pytest.mark.parametrize(
144+
"barostat_class, args",
145+
[
146+
# MonteCarloBarostat
147+
(
148+
MonteCarloBarostat,
149+
[0.95 * unit.atmospheres, 301.0 * unit.kelvin, 23],
150+
),
151+
# MonteCarloMembraneBarostat
152+
(
153+
MonteCarloMembraneBarostat,
154+
[
155+
1.0 * unit.atmospheres,
156+
10.0 * unit.millinewton / unit.meter,
157+
301.0 * unit.kelvin,
158+
MonteCarloMembraneBarostat.XYIsotropic,
159+
MonteCarloMembraneBarostat.ZFree,
160+
23,
161+
],
162+
),
163+
],
164+
)
165+
def test_barostat(self, barostat_class, args):
166+
"""Test that different barostats are correctly applied to the system"""
144167
# Create a protein SystemGenerator
145168
generator = SystemGenerator(forcefields=self.amber_forcefields)
146169

147-
# Create a template barostat
148-
from openmm import MonteCarloBarostat, unit
149-
150-
pressure = 0.95 * unit.atmospheres
151-
temperature = 301.0 * unit.kelvin
152-
frequency = 23
153-
generator.barostat = MonteCarloBarostat(pressure, temperature, frequency)
170+
# Create the barostat
171+
generator.barostat = barostat_class(*args)
172+
173+
# Derive expected values based on barostat type
174+
if barostat_class is MonteCarloBarostat:
175+
expected = {
176+
"pressure": args[0],
177+
"temperature": args[1],
178+
"frequency": args[2],
179+
}
180+
else: # MonteCarloMembraneBarostat
181+
expected = {
182+
"pressure": args[0],
183+
"surface_tension": args[1],
184+
"temperature": args[2],
185+
"frequency": args[-1],
186+
}
154187

155188
# Load a PDB file
156-
157189
pdb_filename = get_data_filename(os.path.join("perses_jacs_systems", "mcl1", "MCL1_protein.pdb"))
158190
pdbfile = PDBFile(pdb_filename)
159191

@@ -176,13 +208,18 @@ def test_barostat(self):
176208

177209
# Check barostat is present
178210
forces = {force.__class__.__name__: force for force in system.getForces()}
179-
assert "MonteCarloBarostat" in forces.keys()
211+
name = barostat_class.__name__
212+
assert name in forces, f"{name} not found in system forces"
180213

181214
# Check barostat parameters
182-
force = forces["MonteCarloBarostat"]
183-
assert force.getDefaultPressure() == pressure
184-
assert force.getDefaultTemperature() == temperature
185-
assert force.getFrequency() == frequency
215+
force = forces[name]
216+
assert force.getDefaultTemperature() == expected["temperature"]
217+
assert force.getDefaultPressure() == expected["pressure"]
218+
assert force.getFrequency() == expected["frequency"]
219+
220+
# Conditional check
221+
if hasattr(force, "getDefaultSurfaceTension"):
222+
assert force.getDefaultSurfaceTension() == expected["surface_tension"]
186223

187224
@pytest.mark.parametrize(
188225
"small_molecule_forcefield",

0 commit comments

Comments
 (0)