Skip to content

Commit 952e4cd

Browse files
committed
Default to computing energy gradients everytime the energy is requested
1 parent 7df5226 commit 952e4cd

File tree

2 files changed

+53
-14
lines changed

2 files changed

+53
-14
lines changed

python/examples/4-profiling.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -134,33 +134,37 @@ def forward(
134134
# </details>
135135
#
136136

137-
138137
# %%
139138
#
140-
# If you are trying to profile your own model, you can start here and create a
139+
# If you are trying to profile your own model, you can start here and create
141140
# ``MetatomicCalculator`` with your own model.
142141

142+
# %%
143+
#
144+
# Profiling energy calculation
145+
# ----------------------------
146+
#
147+
# We will start with an energy-only calculator, which can be enabled with
148+
# ``do_gradients_with_energy=False``.
143149

144-
atoms.calc = MetatomicCalculator("exported-model.pt")
150+
atoms.calc = MetatomicCalculator("exported-model.pt", do_gradients_with_energy=False)
145151

146152
# %%
147153
#
148154
# Before trying to profile the code, it is a good idea to run it a couple of times to
149155
# allow torch to warmup internally.
150156

151-
atoms.get_forces()
152-
atoms.get_potential_energy()
157+
for _ in range(10):
158+
# force the model to re-run everytime, otherwise ASE caches calculation results
159+
atoms.rattle(1e-6)
160+
atoms.get_potential_energy()
153161

154162
# %%
155163
#
156-
# Profiling energy calculation
157-
# ----------------------------
158-
#
159164
# Now we can run code using :py:func:`torch.profiler.profile` to collect statistic on
160-
# how long each function takes to run. We randomize the positions to force ASE to
161-
# recompute the energy of the system
165+
# how long each function takes to run.
162166

163-
atoms.positions += np.random.rand(*atoms.positions.shape)
167+
atoms.rattle(1e-6)
164168
with torch.profiler.profile() as energy_profiler:
165169
atoms.get_potential_energy()
166170

@@ -202,7 +206,14 @@ def forward(
202206
# Let's now do the same, but computing the forces for this system. This mean we should
203207
# now see some time spent in the ``backward()`` function, on top of everything else.
204208

205-
atoms.positions += np.random.rand(*atoms.positions.shape)
209+
atoms.calc = MetatomicCalculator("exported-model.pt")
210+
211+
# warmup
212+
for _ in range(10):
213+
atoms.rattle(1e-6)
214+
atoms.get_forces()
215+
216+
atoms.rattle(1e-6)
206217
with torch.profiler.profile() as forces_profiler:
207218
atoms.get_forces()
208219

python/metatomic_torch/metatomic/torch/ase_calculator.py

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def __init__(
6868
check_consistency=False,
6969
device=None,
7070
non_conservative=False,
71+
do_gradients_with_energy=True,
7172
):
7273
"""
7374
:param model: model to use for the calculation. This can be a file path, a
@@ -87,10 +88,18 @@ def __init__(
8788
running, defaults to False.
8889
:param device: torch device to use for the calculation. If ``None``, we will try
8990
the options in the model's ``supported_device`` in order.
90-
:param non_conservative: if ``True``, the model will be asked to
91-
compute non-conservative forces and stresses. This can afford a speed-up,
91+
:param non_conservative: if ``True``, the model will be asked to compute
92+
non-conservative forces and stresses. This can afford a speed-up,
9293
potentially at the expense of physical correctness (especially in molecular
9394
dynamics simulations).
95+
:param do_gradients_with_energy: if ``True``, this calculator will always
96+
compute the energy gradients (forces and stress) when the energy is
97+
requested (e.g. through ``atoms.get_potential_energy()``). Because the
98+
results of a calculation are cached by ASE, this means future calls to
99+
``atom.get_forces()`` will return immediately, without needing to execute
100+
the model again. If you are mainly interested in the energy, you can set
101+
this to ``False`` and enjoy a faster model. Forces will still be calculated
102+
if requested with ``atoms.get_forces()``.
94103
"""
95104
super().__init__()
96105

@@ -175,6 +184,7 @@ def __init__(
175184
self._device = device
176185
self._model = model.to(device=self._device)
177186
self._non_conservative = non_conservative
187+
self._do_gradients_with_energy = do_gradients_with_energy
178188

179189
# We do our own check to verify if a property is implemented in `calculate()`,
180190
# so we pretend to be able to compute all properties ASE knows about.
@@ -327,6 +337,11 @@ def calculate(
327337
if "stresses" in properties:
328338
raise NotImplementedError("'stresses' are not implemented yet")
329339

340+
if self._do_gradients_with_energy:
341+
if calculate_energies or calculate_energy:
342+
calculate_forces = True
343+
calculate_stress = True
344+
330345
with record_function("MetatomicCalculator::prepare_inputs"):
331346
outputs = self._ase_properties_to_metatensor_outputs(
332347
properties,
@@ -409,6 +424,19 @@ def calculate(
409424
assert len(energy.block().gradients_list()) == 0
410425
assert energy.block().values.shape == (1, 1)
411426

427+
if do_backward:
428+
if energy.block().values.grad_fn is None:
429+
# did the user actually request a gradient, or are we trying to
430+
# compute one just for efficiency?
431+
if "forces" in properties or "stress" in properties:
432+
# the user asked for it, let it fail below
433+
pass
434+
else:
435+
# we added the calculation, let's remove it
436+
do_backward = False
437+
calculate_forces = False
438+
calculate_stress = False
439+
412440
with record_function("MetatomicCalculator::run_backward"):
413441
if do_backward:
414442
energy.block().values.backward()

0 commit comments

Comments
 (0)