2727
2828logger = logging .getLogger ("mlipaudit" )
2929
30- DEFAULT_ASE_MAX_FORCE_CONV_THRESH = 0.01
31-
3230
3331class ASESimulationEngineWithCalculator (ASESimulationEngine ):
3432 """Class derived from mlip's ASE simulation engine but allowing for a passed
@@ -73,9 +71,15 @@ def __init__(
7371
7472def get_simulation_engine (
7573 atoms : ase .Atoms , force_field : ForceField | ASECalculator , ** kwargs
76- ) -> JaxMDSimulationEngine | ASESimulationEngineWithCalculator :
74+ ) -> JaxMDSimulationEngine | ASESimulationEngineWithCalculator | ASESimulationEngine :
7775 """Returns the correct simulation engine based on the input force field type.
7876
77+ For MD simulations with `mlip.models.ForceField` objects, we return a
78+ `JaxMDSimulationEngine`. For energy minimizations with those objects, we return
79+ a `ASESimulationEngine`. For any type of simulations with ASE calculator objects,
80+ we return a `ASESimulationEngineWithCalculator`, which is a custom class of
81+ the `mlipaudit` library.
82+
7983 Args:
8084 atoms: The ASE atoms.
8185 force_field: The force field, either an `mlip.models.ForceField`
@@ -89,21 +93,26 @@ def get_simulation_engine(
8993 Raises:
9094 ValueError: If force field type is not compatible.
9195 """
92- if isinstance (force_field , ForceField ):
96+ # Case 1: MD simulations with ForceField objects -> use JAX-MD
97+ if (
98+ isinstance (force_field , ForceField )
99+ and kwargs .get ("simulation_type" , "md" ) == "md"
100+ ):
93101 md_config = JaxMDSimulationEngine .Config (** kwargs )
94102 return JaxMDSimulationEngine (atoms , force_field , md_config )
95103
96- elif isinstance (force_field , ASECalculator ):
97- kwargs_copy = deepcopy (kwargs )
98- kwargs_copy .pop ("num_episodes" , None ) # remove this if exists
104+ kwargs_copy = deepcopy (kwargs )
105+ kwargs_copy .pop ("num_episodes" , None ) # remove this if exists
99106
100- # for minimization:
101- kwargs_copy [ "max_force_convergence_threshold" ] = (
102- DEFAULT_ASE_MAX_FORCE_CONV_THRESH
103- )
107+ # Case 2: Minimization with ForceField objects -> use ASE
108+ if isinstance ( force_field , ForceField ):
109+ minimization_config = ASESimulationEngine . Config ( ** kwargs_copy )
110+ return ASESimulationEngine ( atoms , force_field , minimization_config )
104111
105- md_config = ASESimulationEngine .Config (** kwargs_copy )
106- return ASESimulationEngineWithCalculator (atoms , force_field , md_config )
112+ # Case 3: MD or minimization with ASECalculator objects -> use ASE
113+ if isinstance (force_field , ASECalculator ):
114+ sim_config = ASESimulationEngine .Config (** kwargs_copy )
115+ return ASESimulationEngineWithCalculator (atoms , force_field , sim_config )
107116
108117 raise ValueError (
109118 "Provided force field must be either a mlip-compatible "
0 commit comments