Skip to content

Commit b8ec760

Browse files
committed
Fix a bug with custom model inputs in ASE
The unit conversion was using quantity instead of name when looking up the data
1 parent 749ab17 commit b8ec760

File tree

3 files changed

+98
-106
lines changed

3 files changed

+98
-106
lines changed

python/metatomic_torch/metatomic/torch/ase_calculator.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -60,31 +60,32 @@
6060
"unit": "nm/fs",
6161
},
6262
"ase::initial_magmoms": {
63-
"quantity": "",
63+
"quantity": "magnetic_moment",
6464
"getter": ase.Atoms.get_initial_magnetic_moments,
6565
"unit": "",
6666
},
6767
"ase::magnetic_moment": {
68-
"quantity": "",
68+
"quantity": "magnetic_moment",
6969
"getter": ase.Atoms.get_magnetic_moment,
7070
"unit": "",
7171
},
7272
"ase::magnetic_moments": {
73-
"quantity": "",
73+
"quantity": "magnetic_moment",
7474
"getter": ase.Atoms.get_magnetic_moments,
7575
"unit": "",
7676
},
7777
"ase::initial_charges": {
78-
"quantity": "",
78+
"quantity": "charge",
7979
"getter": ase.Atoms.get_initial_charges,
8080
"unit": "",
8181
},
8282
"ase::charges": {
83-
"quantity": "",
83+
"quantity": "charge",
8484
"getter": ase.Atoms.get_charges,
8585
"unit": "",
8686
},
8787
"ase::dipole_moment": {
88+
"quantity": "dipole_moment",
8889
"getter": ase.Atoms.get_dipole_moment,
8990
"unit": "",
9091
},
@@ -972,18 +973,13 @@ def _get_ase_input(
972973
dtype: torch.dtype,
973974
device: torch.device,
974975
) -> "TensorMap":
975-
if name in ARRAY_QUANTITIES:
976-
infos = ARRAY_QUANTITIES[name]
977-
if infos["quantity"] != option.quantity:
978-
raise ValueError(
979-
f"The model requested '{name}' with quantity '{option.quantity}', "
980-
f"but the quantity is '{infos['quantity']}' in `ase`."
981-
)
982-
else:
976+
if name not in ARRAY_QUANTITIES:
983977
raise ValueError(
984978
f"The model requested '{name}', which is not available in `ase`."
985979
)
986980

981+
infos = ARRAY_QUANTITIES[name]
982+
987983
values = infos["getter"](atoms)
988984
if values.shape[0] != len(atoms):
989985
raise NotImplementedError(
@@ -994,7 +990,7 @@ def _get_ase_input(
994990
# for metatensor
995991
values = torch.tensor(values[..., None])
996992

997-
tblock = TensorBlock(
993+
block = TensorBlock(
998994
values,
999995
samples=Labels(
1000996
["system", "atom"],
@@ -1005,21 +1001,16 @@ def _get_ase_input(
10051001
components=[Labels(["xyz"], torch.arange(values.shape[1]).reshape(-1, 1))]
10061002
if values.shape[1] != 1
10071003
else [],
1008-
properties=Labels(
1009-
[
1010-
name if "::" not in name else name.split("::")[1],
1011-
],
1012-
torch.tensor([[0]]),
1013-
),
1014-
)
1015-
tmap = TensorMap(
1016-
Labels(["_"], torch.tensor([[0]])),
1017-
[tblock],
1004+
properties=Labels([option.quantity], torch.tensor([[0]])),
10181005
)
1019-
tmap.set_info("quantity", option.quantity)
1020-
tmap.set_info("unit", option.unit)
1021-
tmap.to(dtype=dtype, device=device)
1022-
return tmap
1006+
1007+
tensor = TensorMap(Labels(["_"], torch.tensor([[0]])), [block])
1008+
1009+
tensor.set_info("quantity", infos["quantity"])
1010+
tensor.set_info("unit", infos["unit"])
1011+
1012+
tensor.to(dtype=dtype, device=device)
1013+
return tensor
10231014

10241015

10251016
def _ase_to_torch_data(atoms, dtype, device):

python/metatomic_torch/metatomic/torch/model.py

Lines changed: 66 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -477,34 +477,12 @@ def forward(
477477

478478
# convert systems from engine to model units
479479
with record_function("AtomisticModel::convert_units_input"):
480-
if self._capabilities.length_unit != options.length_unit:
481-
conversion = unit_conversion_factor(
482-
quantity="length",
483-
from_unit=options.length_unit,
484-
to_unit=self._capabilities.length_unit,
485-
)
486-
487-
systems = _convert_systems_units(
488-
systems,
489-
conversion,
490-
model_length_unit=self._capabilities.length_unit,
491-
system_length_unit=options.length_unit,
492-
)
493-
494-
for name, option in self._requested_inputs.items():
495-
system_unit = str(
496-
systems[0].get_data(name).get_info("unit")
497-
) # For torchscript
498-
to_unit = option.unit
499-
conversion = unit_conversion_factor(
500-
quantity=option.quantity,
501-
from_unit=system_unit,
502-
to_unit=to_unit,
503-
)
504-
505-
_convert_systems_input_units(
506-
systems, option.quantity, conversion, to_unit
507-
)
480+
systems = _convert_systems_units(
481+
systems,
482+
model_length_unit=self._capabilities.length_unit,
483+
system_length_unit=options.length_unit,
484+
requested_inputs=self._requested_inputs,
485+
)
508486

509487
# run the actual calculations
510488
with record_function("Model::forward"):
@@ -948,12 +926,19 @@ def _check_inputs(
948926

949927
def _convert_systems_units(
950928
systems: List[System],
951-
conversion: float,
952929
model_length_unit: str,
953930
system_length_unit: str,
931+
requested_inputs: Dict[str, ModelOutput],
954932
) -> List[System]:
955-
if conversion == 1.0:
956-
return systems
933+
if model_length_unit == "" or system_length_unit == "":
934+
# no conversion for positions/cell/NL
935+
conversion = 1.0
936+
else:
937+
conversion = unit_conversion_factor(
938+
quantity="length",
939+
from_unit=system_length_unit,
940+
to_unit=model_length_unit,
941+
)
957942

958943
new_systems: List[System] = []
959944
for system in systems:
@@ -978,41 +963,58 @@ def _convert_systems_units(
978963
)
979964

980965
known_data = system.known_data()
981-
if len(known_data) != 0:
982-
warnings.warn(
983-
"the model requires a different length unit "
984-
f"({model_length_unit}) than the system ({system_length_unit}), "
985-
f"but we don't know how to convert custom data ({known_data}) "
986-
"accordingly",
987-
stacklevel=2,
988-
)
966+
for name in known_data:
967+
if name not in requested_inputs:
968+
# not a requested input, just copy as is
969+
new_system.add_data(name, system.get_data(name))
989970

990-
for data in known_data:
991-
new_system.add_data(data, system.get_data(data))
971+
else:
972+
requested = requested_inputs[name]
973+
tensor = system.get_data(name)
974+
unit = tensor.get_info("unit")
975+
976+
if requested.quantity != "" and unit is not None:
977+
conversion = unit_conversion_factor(
978+
quantity=requested.quantity,
979+
from_unit=unit,
980+
to_unit=requested.unit,
981+
)
982+
else:
983+
conversion = 1.0
984+
985+
new_blocks: List[TensorBlock] = []
986+
for block in tensor.blocks():
987+
new_values = conversion * block.values
988+
new_block = TensorBlock(
989+
values=new_values,
990+
samples=block.samples,
991+
components=block.components,
992+
properties=block.properties,
993+
)
992994

993-
new_systems.append(new_system)
995+
for parameter, gradient in block.gradients():
996+
if len(gradient.gradients_list()) != 0:
997+
raise NotImplementedError(
998+
"nested gradients are not supported"
999+
)
1000+
1001+
new_gradient = TensorBlock(
1002+
values=conversion * gradient.values,
1003+
samples=gradient.samples,
1004+
components=gradient.components,
1005+
properties=gradient.properties,
1006+
)
1007+
new_block.add_gradient(parameter, new_gradient)
1008+
new_blocks.append(new_block)
9941009

995-
return new_systems
1010+
new_tensor = TensorMap(
1011+
keys=tensor.keys,
1012+
blocks=new_blocks,
1013+
)
1014+
new_tensor.set_info("unit", requested.unit)
1015+
new_tensor.set_info("quantity", requested.quantity)
1016+
new_system.add_data(name, new_tensor)
9961017

1018+
new_systems.append(new_system)
9971019

998-
def _convert_systems_input_units(
999-
systems: List[System], quantity: str, conversion: float, to_unit: str
1000-
) -> None:
1001-
if conversion != 1.0:
1002-
for system in systems:
1003-
tensor = system.get_data(quantity)
1004-
tblock = tensor.block()
1005-
new_tensor = TensorMap(
1006-
Labels("_", torch.tensor([[0]])),
1007-
[
1008-
TensorBlock(
1009-
values=conversion * tblock.values,
1010-
samples=tblock.samples,
1011-
components=tblock.components,
1012-
properties=tblock.properties,
1013-
)
1014-
],
1015-
)
1016-
new_tensor.set_info("unit", to_unit)
1017-
new_tensor.set_info("quantity", quantity)
1018-
system.add_data(quantity, new_tensor, override=True)
1020+
return new_systems

python/metatomic_torch/tests/ase_calculator.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ def test_additional_input(atoms):
837837
inputs = {
838838
"masses": ModelOutput(quantity="mass", unit="u", per_atom=True),
839839
"velocities": ModelOutput(quantity="velocity", unit="A/fs", per_atom=True),
840-
"ase::initial_charges": ModelOutput(quantity="", unit="", per_atom=True),
840+
"ase::initial_charges": ModelOutput(quantity="charge", unit="", per_atom=True),
841841
}
842842
outputs = {("extra::" + prop): inputs[prop] for prop in inputs}
843843
capabilities = ModelCapabilities(
@@ -855,17 +855,16 @@ def test_additional_input(atoms):
855855
atoms.set_initial_charges([0.0] * len(atoms))
856856
calculator = MetatomicCalculator(model, check_consistency=True)
857857
results = calculator.run_model(atoms, outputs)
858-
for k, v in results.items():
859-
head, prop = k.split("::", maxsplit=1)
858+
for name, tensor in results.items():
859+
head, name = name.split("::", maxsplit=1)
860860
assert head == "extra"
861-
assert prop in inputs
862-
assert len(v.keys.names) == 1
863-
assert v.get_info("quantity") == inputs[prop].quantity
864-
values = v[0].values.numpy()
865-
shape = values.shape
866-
assert shape[0] == len(atoms), f"Expected {len(atoms)} values, got {shape[0]}"
867-
assert np.allclose(
868-
values,
869-
ARRAY_QUANTITIES[prop]["getter"](atoms).reshape(shape)
870-
* (10 if prop == "velocity" else 1), # ase velocity is in nm/fs, not A/fs
871-
)
861+
assert name in inputs
862+
863+
assert tensor.get_info("quantity") == inputs[name].quantity
864+
values = tensor[0].values.numpy()
865+
866+
expected = ARRAY_QUANTITIES[name]["getter"](atoms).reshape(values.shape)
867+
if name == "velocities":
868+
expected *= 10.0 # ase velocity is in nm/fs
869+
870+
assert np.allclose(values, expected)

0 commit comments

Comments
 (0)