Skip to content

Commit 1a1bb57

Browse files
authored
improve: ase try to get virials from different sources (#660)
1 parent b91c598 commit 1a1bb57

File tree

1 file changed

+15
-8
lines changed

1 file changed

+15
-8
lines changed

dpdata/plugins/ase.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict:
6363
"""Convert ase.Atoms to a LabeledSystem. Energies and forces
6464
are calculated by the calculator.
6565
66+
Note that this method will try to load virials from either virial field or converted from stress tensor.
67+
6668
Parameters
6769
----------
6870
atoms : ase.Atoms
@@ -94,13 +96,19 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict:
9496
"energies": np.array([energies]),
9597
"forces": np.array([forces]),
9698
}
97-
try:
98-
stress = atoms.get_stress(voigt=False)
99-
except PropertyNotImplementedError:
100-
pass
101-
else:
102-
virials = np.array([-atoms.get_volume() * stress])
103-
info_dict["virials"] = virials
99+
100+
# try to get virials from different sources
101+
virials = atoms.info.get("virial")
102+
if virials is None:
103+
try:
104+
stress = atoms.get_stress(voigt=False)
105+
except PropertyNotImplementedError:
106+
pass
107+
else:
108+
virials = -atoms.get_volume() * stress
109+
if virials is not None:
110+
info_dict["virials"] = np.array([virials])
111+
104112
return info_dict
105113

106114
def from_multi_systems(
@@ -166,7 +174,6 @@ def to_labeled_system(self, data, *args, **kwargs) -> list[ase.Atoms]:
166174

167175
structures = []
168176
species = [data["atom_names"][tt] for tt in data["atom_types"]]
169-
170177
for ii in range(data["coords"].shape[0]):
171178
structure = Atoms(
172179
symbols=species,

0 commit comments

Comments
 (0)