diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index 5381b08a..09cd0390 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -985,14 +985,24 @@ def _get_ase_input( ) values = infos["getter"](atoms) - values = torch.tensor( - values[:, :, None] if values.ndim == 2 else values[None, :, None] - ) + if values.shape[0] != len(atoms): + raise NotImplementedError( + f"The model requested the '{name}' input, " + f"but the data is not per-atom (shape {values.shape}). " + ) + # Shape: (n_atoms, n_components) -> (n_atoms, n_components, /* n_properties */ 1) + # for metatensor + values = torch.tensor(values[..., None]) tblock = TensorBlock( values, - samples=Labels.range("atoms", values.shape[0]), - components=[Labels.range("components", values.shape[1])] + samples=Labels( + ["system", "atom"], + torch.vstack( + [torch.full((values.shape[0],), 0), torch.arange(values.shape[0])] + ).T, + ), + components=[Labels(["xyz"], torch.arange(values.shape[1]).reshape(-1, 1))] if values.shape[1] != 1 else [], properties=Labels( diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index f1a487e6..084cee9a 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -853,7 +853,7 @@ def test_additional_input(atoms): ) MaxwellBoltzmannDistribution(atoms, temperature_K=300.0) atoms.set_initial_charges([0.0] * len(atoms)) - calculator = MetatomicCalculator(model) + calculator = MetatomicCalculator(model, check_consistency=True) results = calculator.run_model(atoms, outputs) for k, v in results.items(): head, prop = k.split("::", maxsplit=1) @@ -861,9 +861,11 @@ def test_additional_input(atoms): assert prop in inputs assert len(v.keys.names) == 1 assert v.get_info("quantity") == inputs[prop].quantity - shape = v[0].values.numpy().shape + values = v[0].values.numpy() + shape = values.shape + assert shape[0] == len(atoms), f"Expected {len(atoms)} values, got {shape[0]}" assert np.allclose( - v[0].values.numpy(), + values, ARRAY_QUANTITIES[prop]["getter"](atoms).reshape(shape) - * (10 if prop == "velocity" else 1), # ase velocity is in nm/fs + * (10 if prop == "velocity" else 1), # ase velocity is in nm/fs, not A/fs )