@@ -63,6 +63,8 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict:
63
63
"""Convert ase.Atoms to a LabeledSystem. Energies and forces
64
64
are calculated by the calculator.
65
65
66
+ Note that this method will try to load virials from either virial field or converted from stress tensor.
67
+
66
68
Parameters
67
69
----------
68
70
atoms : ase.Atoms
@@ -94,13 +96,19 @@ def from_labeled_system(self, atoms: ase.Atoms, **kwargs) -> dict:
94
96
"energies" : np .array ([energies ]),
95
97
"forces" : np .array ([forces ]),
96
98
}
97
- try :
98
- stress = atoms .get_stress (voigt = False )
99
- except PropertyNotImplementedError :
100
- pass
101
- else :
102
- virials = np .array ([- atoms .get_volume () * stress ])
103
- info_dict ["virials" ] = virials
99
+
100
+ # try to get virials from different sources
101
+ virials = atoms .info .get ("virial" )
102
+ if virials is None :
103
+ try :
104
+ stress = atoms .get_stress (voigt = False )
105
+ except PropertyNotImplementedError :
106
+ pass
107
+ else :
108
+ virials = - atoms .get_volume () * stress
109
+ if virials is not None :
110
+ info_dict ["virials" ] = np .array ([virials ])
111
+
104
112
return info_dict
105
113
106
114
def from_multi_systems (
@@ -166,7 +174,6 @@ def to_labeled_system(self, data, *args, **kwargs) -> list[ase.Atoms]:
166
174
167
175
structures = []
168
176
species = [data ["atom_names" ][tt ] for tt in data ["atom_types" ]]
169
-
170
177
for ii in range (data ["coords" ].shape [0 ]):
171
178
structure = Atoms (
172
179
symbols = species ,
0 commit comments