Skip to content

Commit b132ce6

Browse files
committed
Check the requested inputs with _check_outputs
1 parent 3edf931 commit b132ce6

File tree

3 files changed

+30
-20
lines changed

3 files changed

+30
-20
lines changed

python/metatomic_torch/metatomic/torch/ase_calculator.py

Lines changed: 11 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,17 @@
4444
"float64": torch.float64,
4545
}
4646

47-
ARRAY_PROPERTIES = {
48-
"momenta": {
47+
ARRAY_QUANTITIES = {
48+
"momentum": {
4949
"getter": lambda atoms: atoms.get_momenta(),
50-
"quantity": "momentum",
5150
"unit": "(eV*u)^(1/2)",
5251
},
53-
"masses": {
52+
"mass": {
5453
"getter": lambda atoms: atoms.get_masses(),
55-
"quantity": "mass",
5654
"unit": "u",
5755
},
58-
"velocities": {
56+
"velocity": {
5957
"getter": lambda atoms: atoms.get_velocities(),
60-
"quantity": "velocity",
6158
"unit": "nm/fs",
6259
},
6360
"initial_magmoms": {},
@@ -368,7 +365,7 @@ def run_model(
368365
# Get the additional inputs requested by the model
369366
for quantity, option in self._model.requested_inputs().items():
370367
input_tensormap = _get_ase_input(
371-
atoms, quantity, option, dtype=self._dtype, device=self._device
368+
atoms, option, dtype=self._dtype, device=self._device
372369
)
373370
system.add_data(quantity, input_tensormap)
374371
systems.append(system)
@@ -519,7 +516,7 @@ def calculate(
519516
system.add_neighbor_list(options, neighbors)
520517
for quantity, option in self._model.requested_inputs().items():
521518
input_tensormap = _get_ase_input(
522-
atoms, quantity, option, dtype=self._dtype, device=self._device
519+
atoms, option, dtype=self._dtype, device=self._device
523520
)
524521
system.add_data(quantity, input_tensormap)
525522

@@ -944,21 +941,20 @@ def _compute_ase_neighbors(atoms, options, dtype, device):
944941

945942
def _get_ase_input(
946943
atoms: ase.Atoms,
947-
quantity: str,
948944
option: ModelOutput,
949945
dtype: torch.dtype,
950946
device: torch.device,
951947
) -> "TensorMap":
952-
if quantity in ARRAY_PROPERTIES:
953-
if len(ARRAY_PROPERTIES[quantity]) == 0:
948+
if option.quantity in ARRAY_QUANTITIES:
949+
if len(ARRAY_QUANTITIES[option.quantity]) == 0:
954950
raise NotImplementedError(
955-
f"Though the property {quantity} is available in `ase`, it is "
951+
f"Though the quantity {option.quantity} is available in `ase`, it is "
956952
"currently not supported by metatomic."
957953
)
958-
infos = ARRAY_PROPERTIES[quantity]
954+
infos = ARRAY_QUANTITIES[option.quantity]
959955
else:
960956
raise ValueError(
961-
f"The model requested '{quantity}', which is not available in `ase`."
957+
f"The model requested '{option.quantity}', which is not available in `ase`."
962958
)
963959

964960
values = infos["getter"](atoms)

python/metatomic_torch/metatomic/torch/model.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -438,6 +438,18 @@ def forward(
438438
options=options,
439439
expected_dtype=self._model_dtype,
440440
)
441+
# check the requested inputs stored in the `systems`
442+
for system in systems:
443+
system_inputs: Dict[str, TensorMap] = {}
444+
for name in system.known_data():
445+
system_inputs[name] = system.get_data(name)
446+
_check_outputs(
447+
systems=[system],
448+
requested=self._requested_inputs,
449+
selected_atoms=options.selected_atoms,
450+
outputs=system_inputs,
451+
model_dtype=self._capabilities.dtype,
452+
)
441453

442454
with record_function("AtomisticModel::check_atomic_types"):
443455
# always (i.e. even if check_consistency=False) check that the atomic types
@@ -898,6 +910,8 @@ def _check_inputs(
898910
)
899911

900912
# Check additional inputs
913+
# Might be problematic, this requires that only requested inputs are stored as
914+
# the data pf the system
901915
known_additional_inputs = system.known_data()
902916
for request in requested_inputs:
903917
found = False

python/metatomic_torch/tests/ase_calculator.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
System,
2727
)
2828
from metatomic.torch.ase_calculator import (
29-
ARRAY_PROPERTIES,
29+
ARRAY_QUANTITIES,
3030
MetatomicCalculator,
3131
_compute_ase_neighbors,
3232
_full_3x3_to_voigt_6_stress,
@@ -835,8 +835,8 @@ def forward(
835835

836836
def test_additional_input(atoms):
837837
inputs = {
838-
"masses": ModelOutput(quantity="mass", unit="u", per_atom=True),
839-
"velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True),
838+
"mass": ModelOutput(quantity="mass", unit="u", per_atom=True),
839+
"velocity": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True),
840840
}
841841
outputs = {("extra::" + prop): inputs[prop] for prop in inputs}
842842
capabilities = ModelCapabilities(
@@ -862,6 +862,6 @@ def test_additional_input(atoms):
862862
shape = v[0].values.numpy().shape
863863
assert np.allclose(
864864
v[0].values.numpy(),
865-
ARRAY_PROPERTIES[prop]["getter"](atoms).reshape(shape)
866-
* (10 if prop == "velocities" else 1),
865+
ARRAY_QUANTITIES[prop]["getter"](atoms).reshape(shape)
866+
* (10 if prop == "velocity" else 1),
867867
)

0 commit comments

Comments
 (0)