Skip to content

Commit 7df5226

Browse files
committed
Use MetatomicCalculator as the class name in torch profiler
1 parent 5de7922 commit 7df5226

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

python/examples/4-profiling.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,15 +183,15 @@ def forward(
183183
# with the same function name as the corresponding torch functions (e.g.
184184
# ``aten::arange`` is :py:func:`torch.arange`). We can also see some internal functions
185185
# from metatomic, with the name staring with ``AtomisticModel::`` for
186-
# :py:class:`AtomisticModel`; and ``ASECalculator::`` for
186+
# :py:class:`AtomisticModel`; and ``MetatomicCalculator::`` for
187187
# :py:class:`ase_calculator.MetatomicCalculator`.
188188
#
189189
# If you want to see more details on the internal steps taken by your model, you can add
190190
# :py:func:`torch.profiler.record_function`
191191
# (https://pytorch.org/docs/stable/generated/torch.autograd.profiler.record_function.html)
192192
# inside your model code to give names to different steps in the calculation. This is
193193
# how we are internally adding names such as ``Model::forward`` or
194-
# ``ASECalculator::prepare_inputs`` above.
194+
# ``MetatomicCalculator::prepare_inputs`` above.
195195
#
196196

197197
# %%

python/metatomic_torch/metatomic/torch/ase_calculator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ def calculate(
327327
if "stresses" in properties:
328328
raise NotImplementedError("'stresses' are not implemented yet")
329329

330-
with record_function("ASECalculator::prepare_inputs"):
330+
with record_function("MetatomicCalculator::prepare_inputs"):
331331
outputs = self._ase_properties_to_metatensor_outputs(
332332
properties,
333333
calculate_forces=calculate_forces,
@@ -371,7 +371,7 @@ def calculate(
371371
selected_atoms=None,
372372
)
373373

374-
with record_function("ASECalculator::compute_neighbors"):
374+
with record_function("MetatomicCalculator::compute_neighbors"):
375375
# convert from ase.Atoms to metatomic.torch.System
376376
system = System(types, positions, cell, pbc)
377377

@@ -394,7 +394,7 @@ def calculate(
394394
)
395395
energy = outputs["energy"]
396396

397-
with record_function("ASECalculator::sum_energies"):
397+
with record_function("MetatomicCalculator::sum_energies"):
398398
if run_options.outputs["energy"].per_atom:
399399
assert len(energy) == 1
400400
assert energy.sample_names == ["system", "atom"]
@@ -409,11 +409,11 @@ def calculate(
409409
assert len(energy.block().gradients_list()) == 0
410410
assert energy.block().values.shape == (1, 1)
411411

412-
with record_function("ASECalculator::run_backward"):
412+
with record_function("MetatomicCalculator::run_backward"):
413413
if do_backward:
414414
energy.block().values.backward()
415415

416-
with record_function("ASECalculator::convert_outputs"):
416+
with record_function("MetatomicCalculator::convert_outputs"):
417417
self.results = {}
418418

419419
if calculate_energies:

0 commit comments

Comments
 (0)