Skip to content

Commit 173eb24

Browse files
authored
Stricter tests for model inputs (#142)
1 parent 40b1661 commit 173eb24

File tree

2 files changed

+21
-9
lines changed

2 files changed

+21
-9
lines changed

python/metatomic_torch/metatomic/torch/ase_calculator.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -985,14 +985,24 @@ def _get_ase_input(
985985
)
986986

987987
values = infos["getter"](atoms)
988-
values = torch.tensor(
989-
values[:, :, None] if values.ndim == 2 else values[None, :, None]
990-
)
988+
if values.shape[0] != len(atoms):
989+
raise NotImplementedError(
990+
f"The model requested the '{name}' input, "
991+
f"but the data is not per-atom (shape {values.shape}). "
992+
)
993+
# Shape: (n_atoms, n_components) -> (n_atoms, n_components, /* n_properties */ 1)
994+
# for metatensor
995+
values = torch.tensor(values[..., None])
991996

992997
tblock = TensorBlock(
993998
values,
994-
samples=Labels.range("atoms", values.shape[0]),
995-
components=[Labels.range("components", values.shape[1])]
999+
samples=Labels(
1000+
["system", "atom"],
1001+
torch.vstack(
1002+
[torch.full((values.shape[0],), 0), torch.arange(values.shape[0])]
1003+
).T,
1004+
),
1005+
components=[Labels(["xyz"], torch.arange(values.shape[1]).reshape(-1, 1))]
9961006
if values.shape[1] != 1
9971007
else [],
9981008
properties=Labels(

python/metatomic_torch/tests/ase_calculator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -853,17 +853,19 @@ def test_additional_input(atoms):
853853
)
854854
MaxwellBoltzmannDistribution(atoms, temperature_K=300.0)
855855
atoms.set_initial_charges([0.0] * len(atoms))
856-
calculator = MetatomicCalculator(model)
856+
calculator = MetatomicCalculator(model, check_consistency=True)
857857
results = calculator.run_model(atoms, outputs)
858858
for k, v in results.items():
859859
head, prop = k.split("::", maxsplit=1)
860860
assert head == "extra"
861861
assert prop in inputs
862862
assert len(v.keys.names) == 1
863863
assert v.get_info("quantity") == inputs[prop].quantity
864-
shape = v[0].values.numpy().shape
864+
values = v[0].values.numpy()
865+
shape = values.shape
866+
assert shape[0] == len(atoms), f"Expected {len(atoms)} values, got {shape[0]}"
865867
assert np.allclose(
866-
v[0].values.numpy(),
868+
values,
867869
ARRAY_QUANTITIES[prop]["getter"](atoms).reshape(shape)
868-
* (10 if prop == "velocity" else 1), # ase velocity is in nm/fs
870+
* (10 if prop == "velocity" else 1), # ase velocity is in nm/fs, not A/fs
869871
)

0 commit comments

Comments
 (0)