77import pytest
88from openff .toolkit .topology import Molecule
99from openmm .app import LJPME , PME , CutoffNonPeriodic , Modeller , PDBFile
10+ from openmm import unit , MonteCarloBarostat , MonteCarloMembraneBarostat
1011
1112from openmmforcefields .generators import SystemGenerator
1213from openmmforcefields .utils import Timer , get_data_filename
@@ -139,21 +140,52 @@ def test_create(self):
139140 # Create an empty system generator
140141 SystemGenerator ()
141142
142- def test_barostat (self ):
143- """Test that barostat addition works correctly"""
143+ @pytest .mark .parametrize (
144+ "barostat_class, args" ,
145+ [
146+ # MonteCarloBarostat
147+ (
148+ MonteCarloBarostat ,
149+ [0.95 * unit .atmospheres , 301.0 * unit .kelvin , 23 ],
150+ ),
151+ # MonteCarloMembraneBarostat
152+ (
153+ MonteCarloMembraneBarostat ,
154+ [
155+ 1.0 * unit .atmospheres ,
156+ 10.0 * unit .millinewton / unit .meter ,
157+ 301.0 * unit .kelvin ,
158+ MonteCarloMembraneBarostat .XYIsotropic ,
159+ MonteCarloMembraneBarostat .ZFree ,
160+ 23 ,
161+ ],
162+ ),
163+ ],
164+ )
165+ def test_barostat (self , barostat_class , args ):
166+ """Test that different barostats are correctly applied to the system"""
144167 # Create a protein SystemGenerator
145168 generator = SystemGenerator (forcefields = self .amber_forcefields )
146169
147- # Create a template barostat
148- from openmm import MonteCarloBarostat , unit
149-
150- pressure = 0.95 * unit .atmospheres
151- temperature = 301.0 * unit .kelvin
152- frequency = 23
153- generator .barostat = MonteCarloBarostat (pressure , temperature , frequency )
170+ # Create the barostat
171+ generator .barostat = barostat_class (* args )
172+
173+ # Derive expected values based on barostat type
174+ if barostat_class is MonteCarloBarostat :
175+ expected = {
176+ "pressure" : args [0 ],
177+ "temperature" : args [1 ],
178+ "frequency" : args [2 ],
179+ }
180+ else : # MonteCarloMembraneBarostat
181+ expected = {
182+ "pressure" : args [0 ],
183+ "surface_tension" : args [1 ],
184+ "temperature" : args [2 ],
185+ "frequency" : args [- 1 ],
186+ }
154187
155188 # Load a PDB file
156-
157189 pdb_filename = get_data_filename (os .path .join ("perses_jacs_systems" , "mcl1" , "MCL1_protein.pdb" ))
158190 pdbfile = PDBFile (pdb_filename )
159191
@@ -176,13 +208,18 @@ def test_barostat(self):
176208
177209 # Check barostat is present
178210 forces = {force .__class__ .__name__ : force for force in system .getForces ()}
179- assert "MonteCarloBarostat" in forces .keys ()
211+ name = barostat_class .__name__
212+ assert name in forces , f"{ name } not found in system forces"
180213
181214 # Check barostat parameters
182- force = forces ["MonteCarloBarostat" ]
183- assert force .getDefaultPressure () == pressure
184- assert force .getDefaultTemperature () == temperature
185- assert force .getFrequency () == frequency
215+ force = forces [name ]
216+ assert force .getDefaultTemperature () == expected ["temperature" ]
217+ assert force .getDefaultPressure () == expected ["pressure" ]
218+ assert force .getFrequency () == expected ["frequency" ]
219+
220+ # Conditional check
221+ if hasattr (force , "getDefaultSurfaceTension" ):
222+ assert force .getDefaultSurfaceTension () == expected ["surface_tension" ]
186223
187224 @pytest .mark .parametrize (
188225 "small_molecule_forcefield" ,
0 commit comments