@@ -63,6 +63,8 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict:
6363 """Convert ase.Atoms to a LabeledSystem. Energies and forces
6464 are calculated by the calculator.
6565
66+ Note that this method will try to load virials from either virial field or converted from stress tensor.
67+
6668 Parameters
6769 ----------
6870 atoms : ase.Atoms
@@ -94,13 +96,19 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict:
9496 "energies" : np .array ([energies ]),
9597 "forces" : np .array ([forces ]),
9698 }
97- try :
98- stress = atoms .get_stress (voigt = False )
99- except PropertyNotImplementedError :
100- pass
101- else :
102- virials = np .array ([- atoms .get_volume () * stress ])
103- info_dict ["virials" ] = virials
99+
100+ # try to get virials from different sources
101+ virials = atoms .info .get ("virial" )
102+ if virials is None :
103+ try :
104+ stress = atoms .get_stress (voigt = False )
105+ except PropertyNotImplementedError :
106+ pass
107+ else :
108+ virials = - atoms .get_volume () * stress
109+ if virials is not None :
110+ info_dict ["virials" ] = np .array ([virials ])
111+
104112 return info_dict
105113
106114 def from_multi_systems (
@@ -166,7 +174,6 @@ def to_labeled_system(self, data, *args, **kwargs) -> list[ase.Atoms]:
166174
167175 structures = []
168176 species = [data ["atom_names" ][tt ] for tt in data ["atom_types" ]]
169-
170177 for ii in range (data ["coords" ].shape [0 ]):
171178 structure = Atoms (
172179 symbols = species ,
0 commit comments