@@ -68,6 +68,7 @@ def __init__(
6868 check_consistency = False ,
6969 device = None ,
7070 non_conservative = False ,
71+ do_gradients_with_energy = True ,
7172 ):
7273 """
7374 :param model: model to use for the calculation. This can be a file path, a
@@ -87,10 +88,18 @@ def __init__(
8788 running, defaults to False.
8889 :param device: torch device to use for the calculation. If ``None``, we will try
8990 the options in the model's ``supported_device`` in order.
90- :param non_conservative: if ``True``, the model will be asked to
91- compute non-conservative forces and stresses. This can afford a speed-up,
91+ :param non_conservative: if ``True``, the model will be asked to compute
92+ non-conservative forces and stresses. This can afford a speed-up,
9293 potentially at the expense of physical correctness (especially in molecular
9394 dynamics simulations).
95+ :param do_gradients_with_energy: if ``True``, this calculator will always
96+ compute the energy gradients (forces and stress) when the energy is
97+ requested (e.g. through ``atoms.get_potential_energy()``). Because the
98+ results of a calculation are cached by ASE, this means future calls to
99+ ``atom.get_forces()`` will return immediately, without needing to execute
100+ the model again. If you are mainly interested in the energy, you can set
101+ this to ``False`` and enjoy a faster model. Forces will still be calculated
102+ if requested with ``atoms.get_forces()``.
94103 """
95104 super ().__init__ ()
96105
@@ -175,6 +184,7 @@ def __init__(
175184 self ._device = device
176185 self ._model = model .to (device = self ._device )
177186 self ._non_conservative = non_conservative
187+ self ._do_gradients_with_energy = do_gradients_with_energy
178188
179189 # We do our own check to verify if a property is implemented in `calculate()`,
180190 # so we pretend to be able to compute all properties ASE knows about.
@@ -327,6 +337,11 @@ def calculate(
327337 if "stresses" in properties :
328338 raise NotImplementedError ("'stresses' are not implemented yet" )
329339
340+ if self ._do_gradients_with_energy :
341+ if calculate_energies or calculate_energy :
342+ calculate_forces = True
343+ calculate_stress = True
344+
330345 with record_function ("MetatomicCalculator::prepare_inputs" ):
331346 outputs = self ._ase_properties_to_metatensor_outputs (
332347 properties ,
@@ -409,6 +424,19 @@ def calculate(
409424 assert len (energy .block ().gradients_list ()) == 0
410425 assert energy .block ().values .shape == (1 , 1 )
411426
427+ if do_backward :
428+ if energy .block ().values .grad_fn is None :
429+ # did the user actually request a gradient, or are we trying to
430+ # compute one just for efficiency?
431+ if "forces" in properties or "stress" in properties :
432+ # the user asked for it, let it fail below
433+ pass
434+ else :
435+ # we added the calculation, let's remove it
436+ do_backward = False
437+ calculate_forces = False
438+ calculate_stress = False
439+
412440 with record_function ("MetatomicCalculator::run_backward" ):
413441 if do_backward :
414442 energy .block ().values .backward ()
0 commit comments