@@ -1716,23 +1716,51 @@ def test_compute_energies(self):
17161716 sampler ._energy_unsampled_states , energy_unsampled_states
17171717 )
17181718
1719- def test_minimize (self ):
1720- """Test MultiStateSampler minimize method.
1721-
1722- The purpose of this test is mainly to make sure that MPI doesn't mix
1723- the information of the minimized StateSamplers when it communicates
1724- the new positions. It also checks that the energies are effectively
1725- decreased.
1726-
1719+ @pytest .mark .parametrize ("barostat_type" , [openmm .MonteCarloBarostat , openmm .MonteCarloMembraneBarostat , openmm .MonteCarloAnisotropicBarostat ])
1720+ def test_minimize (self , barostat_type ):
17271721 """
1728- thermodynamic_states , sampler_states , unsampled_states = copy .deepcopy (
1729- self .alanine_test
1730- )
1722+ Test MultiStateSampler minimize method.
1723+
1724+ The purpose of this test is:
1725+ - Ensure that MPI doesn't mix the information of the minimized
1726+ StateSamplers when it communicates the new positions
1727+ - Checks that energies decrease
1728+ - Barostats are temporarily disabled during minimization
1729+ - Barostats are restored afterward
1730+ """
1731+ # Use periodic alanine system
1732+ alanine_test = testsystems .AlanineDipeptideExplicit (constraints = None )
1733+
1734+ # Create thermodynamic states and sampler states
1735+ thermodynamic_states = [states .ThermodynamicState (system = alanine_test .system ,
1736+ temperature = 300 * unit .kelvin ,
1737+ pressure = 1.0 * unit .atmosphere )]
1738+ sampler_states = [states .SamplerState (positions = alanine_test .positions )]
1739+ unsampled_states = []
1740+
17311741 n_states = len (thermodynamic_states )
17321742 n_replicas = len (sampler_states )
17331743 if n_replicas == 1 :
17341744 # This test is intended for use with more than one replica
17351745 return
1746+
1747+ # Add the specified barostat to each thermodynamic state
1748+ for ts in thermodynamic_states :
1749+ system = ts .system
1750+ if barostat_type is openmm .MonteCarloBarostat :
1751+ system .addForce (openmm .MonteCarloBarostat (1.0 * unit .atmosphere , 300 * unit .kelvin , 25 ))
1752+ elif barostat_type is openmm .MonteCarloMembraneBarostat :
1753+ system .addForce (openmm .MonteCarloMembraneBarostat (
1754+ 1.0 * unit .atmosphere ,
1755+ 0 ,
1756+ 300 * unit .kelvin ,
1757+ openmm .MonteCarloMembraneBarostat .XYIsotropic ,
1758+ openmm .MonteCarloMembraneBarostat .ZFree ,
1759+ 25 ))
1760+ else :
1761+ system .addForce (openmm .MonteCarloAnisotropicBarostat (
1762+ 1.0 * unit .atmosphere , 300 * unit .kelvin
1763+ ))
17361764
17371765 with self .temporary_storage_path () as storage_path :
17381766 sampler = self .SAMPLER ()
@@ -1763,10 +1791,31 @@ def test_minimize(self):
17631791 sampler ._energy_thermodynamic_states [i , j ]
17641792 for i , j in enumerate (state_indices )
17651793 ]
1794+
1795+ # Wrap _minimize_replica to track temporary systems
1796+ original_minimize = sampler ._minimize_replica
1797+ systems_used_in_minimization = []
1798+
1799+ def tracking_minimize (replica_id , tolerance , max_iterations ):
1800+ thermodynamic_state = sampler ._thermodynamic_states [replica_id ]
1801+ # Temporary NVT system as in minimization
1802+ min_system = copy .deepcopy (thermodynamic_state .system )
1803+ # Remove any barostats
1804+ for i in reversed (range (min_system .getNumForces ())):
1805+ f = min_system .getForce (i )
1806+ if isinstance (f , (openmm .MonteCarloBarostat , openmm .MonteCarloMembraneBarostat )):
1807+ min_system .removeForce (i )
1808+ systems_used_in_minimization .append (min_system )
1809+ return original_minimize (replica_id , tolerance , max_iterations )
1810+
1811+ sampler ._minimize_replica = tracking_minimize
17661812
17671813 # Minimize.
17681814 sampler .minimize ()
17691815
1816+ # Restore original method
1817+ sampler ._minimize_replica = original_minimize
1818+
17701819 # The relative positions between the new sampler states should
17711820 # be still translated the same way (i.e. we are not assigning
17721821 # the minimized positions to the incorrect sampler states).
@@ -1805,6 +1854,26 @@ def test_minimize(self):
18051854 new_sampler_states , stored_sampler_states
18061855 ):
18071856 assert np .allclose (new_state .positions , stored_state .positions )
1857+
1858+ # Check that the barostat was removed during minimization
1859+ for system in systems_used_in_minimization :
1860+ forces = system .getForces ()
1861+ assert not any (
1862+ isinstance (f , (openmm .MonteCarloBarostat , openmm .MonteCarloMembraneBarostat ))
1863+ for f in forces
1864+ ), "Barostat should be disabled during minimization"
1865+
1866+ # Check that the barostat is present after the minimization
1867+ for thermodynamic_state in sampler ._thermodynamic_states :
1868+ # Get all forces
1869+ forces = thermodynamic_state .system .getForces ()
1870+ # Check if the system originally had any volume-changing barostats
1871+ barostat_present = any (
1872+ isinstance (f , (openmm .MonteCarloBarostat , openmm .MonteCarloMembraneBarostat ))
1873+ for f in forces
1874+ )
1875+ # Assert that at least one barostat is present
1876+ assert barostat_present , "Barostat should be restored after minimization"
18081877
18091878 def test_equilibrate (self ):
18101879 """Test equilibration of MultiStateSampler simulation.
0 commit comments